From 874b108ada62c150b8e8fbc49fb3204ca341c3ba Mon Sep 17 00:00:00 2001 From: Zach Mullen Date: Sat, 12 Aug 2023 20:09:14 -0400 Subject: [PATCH] WIP django-ninja cursor pagination implementation This is essentially a direct port of DRF's cursor-based pagination impl to django-ninja. --- isic/conftest.py | 1 + isic/core/api/collection.py | 3 + isic/core/pagination.py | 257 +++++++++++++++++++++++++ isic/core/tests/test_api_collection.py | 14 +- 4 files changed, 268 insertions(+), 7 deletions(-) diff --git a/isic/conftest.py b/isic/conftest.py index fae2ed56..cdb2a5db 100644 --- a/isic/conftest.py +++ b/isic/conftest.py @@ -40,6 +40,7 @@ def setup_groups(request): @pytest.fixture def api_client() -> APIClient: + # TODO have this return a django.test.Client instead of DRF's APIClient return APIClient() diff --git a/isic/core/api/collection.py b/isic/core/api/collection.py index 2c5ecfd7..03efe299 100644 --- a/isic/core/api/collection.py +++ b/isic/core/api/collection.py @@ -3,11 +3,13 @@ from django.http.response import JsonResponse from django.shortcuts import get_object_or_404 from ninja import ModelSchema, Router, Schema +from ninja.pagination import paginate from pydantic.types import conlist, constr from isic.core.constants import ISIC_ID_REGEX from isic.core.models.collection import Collection from isic.core.models.image import Image +from isic.core.pagination import CursorPagination from isic.core.permissions import get_visible_objects from isic.core.serializers import SearchQueryIn from isic.core.services.collection.image import ( @@ -31,6 +33,7 @@ def resolve_doi_url(self, obj: Collection): @router.get("/", response=list[CollectionOut]) +@paginate(CursorPagination) def collection_list(request, pinned: bool | None = None) -> list[CollectionOut]: queryset = Collection.objects.all() if pinned is not None: diff --git a/isic/core/pagination.py b/isic/core/pagination.py index a4def6ce..53146f80 100644 --- a/isic/core/pagination.py +++ b/isic/core/pagination.py @@ -1,5 +1,13 @@ +from base64 import b64decode, b64encode from collections import OrderedDict +from dataclasses import dataclass +from typing import Any +from urllib import parse +from django.db.models.query import QuerySet +from django.http.request import HttpRequest +from ninja import Schema +from ninja.pagination import PaginationBase from rest_framework.pagination import CursorPagination from rest_framework.response import Response @@ -23,3 +31,252 @@ def get_paginated_response(self, data): ] ) ) + +################## + +@dataclass +class Cursor: + offset: int = 0 + reverse: bool = False + position: str | None = None + + +def _clamp(val: int, min_: int, max_: int) -> int: + return max(min_, min(val, max_)) + + +def _reverse_order(order: tuple): + """ + Given an order_by tuple such as `('-created', 'uuid')` reverse the + ordering and return a new tuple, eg. `('created', '-uuid')`. + """ + def invert(x): + return x[1:] if x.startswith('-') else f'-{x}' + + return tuple(invert(item) for item in order) + + +def _replace_query_param(url: str, key: str, val: str): + scheme, netloc, path, query, fragment = parse.urlsplit(url) + query_dict = parse.parse_qs(query, keep_blank_values=True) + query_dict[key] = [val] + query = parse.urlencode(sorted(query_dict.items()), doseq=True) + return parse.urlunsplit((scheme, netloc, path, query, fragment)) + + +class CursorPagination(PaginationBase): + class Input(Schema): + limit: int | None + cursor: str | None + + class Output(Schema): + results: list[Any] + count: int + next: str | None + previous: str | None + + items_attribute = "results" + max_page_size = 100 + _offset_cutoff = 100 # limit to protect against possibly malicious queries + + def paginate_queryset(self, queryset: QuerySet, pagination: Input, request: HttpRequest, **params) -> dict: + limit = _clamp(pagination.limit or self.max_page_size, 0, self.max_page_size) + order = queryset.query.order_by # e.g. ('-created',) + total_count = queryset.count() + + base_url = request.build_absolute_uri() + cursor = self._decode_cursor(pagination.cursor) + + if cursor.reverse: + queryset = queryset.order_by(*_reverse_order(order)) + + if cursor.position is not None: + is_reversed = order[0].startswith('-') + order_attr = order[0].lstrip('-') + + if cursor.reverse != is_reversed: + queryset = queryset.filter(**{f"{order_attr}__lt": cursor.position}) + else: + queryset = queryset.filter(**{f"{order_attr}__gt": cursor.position}) + + # If we have an offset cursor then offset the entire page by that amount. + # We also always fetch an extra item in order to determine if there is a + # page following on from this one. + results = list(queryset[cursor.offset:cursor.offset + limit + 1]) + page = list(results[:limit]) + + # Determine the position of the final item following the page. + if len(results) > len(page): + has_following_position = True + following_position = self._get_position_from_instance(results[-1], order) + else: + has_following_position = False + following_position = None + + if cursor.reverse: + # If we have a reverse queryset, then the query ordering was in reverse + # so we need to reverse the items again before returning them to the user. + page = list(reversed(page)) + + has_next = (cursor.position is not None) or (cursor.offset > 0) + has_previous = has_following_position + if has_next: + next_position = cursor.position + if has_previous: + previous_position = following_position + else: + has_next = has_following_position + has_previous = (cursor.position is not None) or (cursor.offset > 0) + if has_next: + next_position = following_position + if has_previous: + previous_position = cursor.position + + return { + "results": page, + "count": total_count, + "next": self.next_link(base_url, page, cursor, order, has_previous, limit, next_position, previous_position) if has_next else None, + "previous": self.previous_link(base_url, page, cursor, order, has_next, limit, next_position, previous_position) if has_previous else None, + } + + def _decode_cursor(self, encoded_cursor: str | None) -> Cursor: + if encoded_cursor is None: + return Cursor() + + try: + querystring = b64decode(encoded_cursor) + tokens = parse.parse_qs(querystring, keep_blank_values=True) + + offset = int(tokens.get('o', ['0'])[0]) + offset = _clamp(offset, 0, self._offset_cutoff) + + reverse = tokens.get('r', ['0'])[0] + reverse = bool(int(reverse)) + + position = tokens.get('p', [None])[0] + except (TypeError, ValueError): + # TODO what should be done here if someone passed an invalid cursor string? + raise + + return Cursor(offset=offset, reverse=reverse, position=position) + + def _encode_cursor(self, cursor: Cursor, base_url: str) -> str: + tokens = {} + if cursor.offset != 0: + tokens['o'] = str(cursor.offset) + if cursor.reverse: + tokens['r'] = '1' + if cursor.position is not None: + tokens['p'] = cursor.position + + querystring = parse.urlencode(tokens, doseq=True) + encoded = b64encode(querystring) + return _replace_query_param(base_url, "cursor", encoded) + + def next_link(self, base_url: str, page: list, cursor: Cursor, order: tuple, has_previous: bool, limit: int, next_position: str, previous_position: str) -> str: + if page and cursor.reverse and cursor.offset: + # If we're reversing direction and we have an offset cursor + # then we cannot use the first position we find as a marker. + compare = self._get_position_from_instance(page[-1], order) + else: + compare = next_position + offset = 0 + + has_item_with_unique_position = False + for item in reversed(page): + position = self._get_position_from_instance(item, order) + if position != compare: + # The item in this position and the item following it + # have different positions. We can use this position as + # our marker. + has_item_with_unique_position = True + break + + # The item in this position has the same position as the item + # following it, we can't use it as a marker position, so increment + # the offset and keep seeking to the previous item. + compare = position + offset += 1 + + if page and not has_item_with_unique_position: + # There were no unique positions in the page. + if not has_previous: + # We are on the first page. + # Our cursor will have an offset equal to the page size, + # but no position to filter against yet. + offset = limit + position = None + elif cursor.reverse: + # The change in direction will introduce a paging artifact, + # where we end up skipping forward a few extra items. + offset = 0 + position = previous_position + else: + # Use the position from the existing cursor and increment + # it's offset by the page size. + offset = cursor.offset + limit + position = previous_position + + if not page: + position = next_position + + next_cursor = Cursor(offset=offset, reverse=False, position=position) + return self._encode_cursor(next_cursor, base_url) + + def previous_link(self, base_url: str, page: list, cursor: Cursor, order: tuple, has_next: bool, limit: int, next_position: str, previous_position: str): + if page and not cursor.reverse and cursor.offset: + # If we're reversing direction and we have an offset cursor + # then we cannot use the first position we find as a marker. + compare = self._get_position_from_instance(page[0], order) + else: + compare = previous_position + offset = 0 + + has_item_with_unique_position = False + for item in page: + position = self._get_position_from_instance(item, order) + if position != compare: + # The item in this position and the item following it + # have different positions. We can use this position as + # our marker. + has_item_with_unique_position = True + break + + # The item in this position has the same position as the item + # following it, we can't use it as a marker position, so increment + # the offset and keep seeking to the previous item. + compare = position + offset += 1 + + if page and not has_item_with_unique_position: + # There were no unique positions in the page. + if not has_next: + # We are on the final page. + # Our cursor will have an offset equal to the page size, + # but no position to filter against yet. + offset = limit + position = None + elif cursor.reverse: + # Use the position from the existing cursor and increment + # it's offset by the page size. + offset = cursor.offset + limit + position = next_position + else: + # The change in direction will introduce a paging artifact, + # where we end up skipping back a few extra items. + offset = 0 + position = next_position + + if not page: + position = previous_position + + cursor = Cursor(offset=offset, reverse=True, position=position) + return self._encode_cursor(cursor, base_url) + + def _get_position_from_instance(self, instance, ordering): + field_name = ordering[0].lstrip('-') + if isinstance(instance, dict): + attr = instance[field_name] + else: + attr = getattr(instance, field_name) + return str(attr) diff --git a/isic/core/tests/test_api_collection.py b/isic/core/tests/test_api_collection.py index 13e00692..d35af674 100644 --- a/isic/core/tests/test_api_collection.py +++ b/isic/core/tests/test_api_collection.py @@ -27,8 +27,8 @@ def collections(public_collection, private_collection): def test_core_api_collection_list_permissions(client, colls, num_visible): r = client.get("/api/v2/collections/") - assert r.status_code == 200, r.data - assert r.data["count"] == num_visible + assert r.status_code == 200, r.json() + assert r.json()["count"] == num_visible @pytest.mark.django_db @@ -63,10 +63,10 @@ def test_core_api_collection_detail_permissions(client, collection, visible): r = client.get(f"/api/v2/collections/{collection.pk}/") if visible: - assert r.status_code == 200, r.data - assert r.data["id"] == collection.id + assert r.status_code == 200, r.json() + assert r.json()["id"] == collection.id else: - assert r.status_code == 404, r.data + assert r.status_code == 404, r.json() @pytest.mark.django_db @@ -87,7 +87,7 @@ def test_core_api_collection_populate_from_search( f"/api/v2/collections/{collection.pk}/populate-from-search/", {"query": "sex:male"} ) - assert r.status_code == 202, r.data + assert r.status_code == 202, r.json() assert collection.images.count() == 1 assert collection.images.first().accession.metadata["sex"] == "male" @@ -106,7 +106,7 @@ def test_core_api_collection_modify_locked(endpoint, data, staff_client, collect r = staff_client.post(f"/api/v2/collections/{collection.pk}/{endpoint}/", data) - assert r.status_code == 409, r.data + assert r.status_code == 409, r.json() @pytest.mark.django_db