Skip to content

Commit

Permalink
WIP django-ninja cursor pagination implementation
Browse files Browse the repository at this point in the history
This is essentially a direct port of DRF's cursor-based pagination impl to django-ninja.
  • Loading branch information
zachmullen committed Aug 13, 2023
1 parent 9a3ac7b commit 874b108
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 7 deletions.
1 change: 1 addition & 0 deletions isic/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def setup_groups(request):

@pytest.fixture
def api_client() -> APIClient:
    """Pytest fixture that returns a newly constructed ``APIClient``."""
    # TODO have this return a django.test.Client instead of DRF's APIClient
    return APIClient()


Expand Down
3 changes: 3 additions & 0 deletions isic/core/api/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from django.http.response import JsonResponse
from django.shortcuts import get_object_or_404
from ninja import ModelSchema, Router, Schema
from ninja.pagination import paginate
from pydantic.types import conlist, constr

from isic.core.constants import ISIC_ID_REGEX
from isic.core.models.collection import Collection
from isic.core.models.image import Image
from isic.core.pagination import CursorPagination
from isic.core.permissions import get_visible_objects
from isic.core.serializers import SearchQueryIn
from isic.core.services.collection.image import (
Expand All @@ -31,6 +33,7 @@ def resolve_doi_url(self, obj: Collection):


@router.get("/", response=list[CollectionOut])
@paginate(CursorPagination)
def collection_list(request, pinned: bool | None = None) -> list[CollectionOut]:
queryset = Collection.objects.all()
if pinned is not None:
Expand Down
257 changes: 257 additions & 0 deletions isic/core/pagination.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from base64 import b64decode, b64encode
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any
from urllib import parse

from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from ninja import Schema
from ninja.pagination import PaginationBase
from rest_framework.pagination import CursorPagination
from rest_framework.response import Response

Expand All @@ -23,3 +31,252 @@ def get_paginated_response(self, data):
]
)
)

##################

@dataclass
class Cursor:
    """
    Decoded form of the opaque ``cursor`` query parameter.

    With no arguments this describes the first page: no offset, forward
    direction, and no position to filter against.
    """

    # Number of rows to skip from the start of the (position-filtered) queryset.
    offset: int = 0
    # When True the queryset is walked with its ordering reversed.
    reverse: bool = False
    # String value of the leading ordering field to filter against, or None
    # when no position filter applies.
    position: str | None = None


def _clamp(val: int, min_: int, max_: int) -> int:
return max(min_, min(val, max_))


def _reverse_order(order: tuple):
"""
Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`.
"""
def invert(x):
return x[1:] if x.startswith('-') else f'-{x}'

return tuple(invert(item) for item in order)


def _replace_query_param(url: str, key: str, val: str):
scheme, netloc, path, query, fragment = parse.urlsplit(url)
query_dict = parse.parse_qs(query, keep_blank_values=True)
query_dict[key] = [val]
query = parse.urlencode(sorted(query_dict.items()), doseq=True)
return parse.urlunsplit((scheme, netloc, path, query, fragment))


class CursorPagination(PaginationBase):
    """
    Cursor-based pagination for django-ninja, ported from DRF's implementation.

    Pages are addressed by an opaque, base64-encoded ``cursor`` query
    parameter that carries an offset, a direction flag and the position
    (string value of the leading ordering field) to filter against. The
    response contains the page of results, the total count, and absolute
    ``next``/``previous`` URLs (``None`` when there is no such page).
    """

    class Input(Schema):
        # Both parameters are optional; explicit defaults keep them optional
        # regardless of how the schema library treats bare ``X | None`` fields.
        limit: int | None = None
        cursor: str | None = None

    class Output(Schema):
        results: list[Any]
        count: int
        next: str | None
        previous: str | None

    items_attribute = "results"
    max_page_size = 100
    _offset_cutoff = 100  # limit to protect against possibly malicious queries

    def paginate_queryset(
        self, queryset: QuerySet, pagination: Input, request: HttpRequest, **params
    ) -> dict:
        """
        Slice ``queryset`` into a single page described by ``pagination``.

        Returns a dict with the keys ``results``, ``count``, ``next`` and
        ``previous``, matching the ``Output`` schema.
        """
        # A missing or zero limit falls back to the maximum page size. The lower
        # bound is 1 because an empty page cannot produce a usable cursor: the
        # "next" link would filter past the look-ahead item and skip it.
        limit = _clamp(pagination.limit or self.max_page_size, 1, self.max_page_size)
        # NOTE(review): this appears to rely on the queryset carrying an explicit
        # ``order_by()``; with an empty tuple the ``order[0]`` lookups below would
        # raise IndexError as soon as a position is needed -- confirm with callers.
        order = queryset.query.order_by  # e.g. ('-created',)
        total_count = queryset.count()

        base_url = request.build_absolute_uri()
        cursor = self._decode_cursor(pagination.cursor)

        if cursor.reverse:
            queryset = queryset.order_by(*_reverse_order(order))

        if cursor.position is not None:
            is_reversed = order[0].startswith("-")
            order_attr = order[0].lstrip("-")

            # Walking "downwards" through the field's values needs ``__lt``;
            # walking "upwards" needs ``__gt``.
            if cursor.reverse != is_reversed:
                queryset = queryset.filter(**{f"{order_attr}__lt": cursor.position})
            else:
                queryset = queryset.filter(**{f"{order_attr}__gt": cursor.position})

        # If we have an offset cursor then offset the entire page by that amount.
        # We also always fetch an extra item in order to determine if there is a
        # page following on from this one.
        results = list(queryset[cursor.offset : cursor.offset + limit + 1])
        page = list(results[:limit])

        # Determine the position of the final item following the page.
        if len(results) > len(page):
            has_following_position = True
            following_position = self._get_position_from_instance(results[-1], order)
        else:
            has_following_position = False
            following_position = None

        # Both positions are handed to the link builders unconditionally, so
        # they must always be bound even when only one of the links exists.
        next_position = None
        previous_position = None

        if cursor.reverse:
            # If we have a reverse queryset, then the query ordering was in reverse
            # so we need to reverse the items again before returning them to the user.
            page = list(reversed(page))

            has_next = (cursor.position is not None) or (cursor.offset > 0)
            has_previous = has_following_position
            if has_next:
                next_position = cursor.position
            if has_previous:
                previous_position = following_position
        else:
            has_next = has_following_position
            has_previous = (cursor.position is not None) or (cursor.offset > 0)
            if has_next:
                next_position = following_position
            if has_previous:
                previous_position = cursor.position

        next_url = None
        if has_next:
            next_url = self.next_link(
                base_url,
                page,
                cursor,
                order,
                has_previous,
                limit,
                next_position,
                previous_position,
            )

        previous_url = None
        if has_previous:
            previous_url = self.previous_link(
                base_url,
                page,
                cursor,
                order,
                has_next,
                limit,
                next_position,
                previous_position,
            )

        return {
            "results": page,
            "count": total_count,
            "next": next_url,
            "previous": previous_url,
        }

    def _decode_cursor(self, encoded_cursor: str | None) -> Cursor:
        """
        Turn the ``cursor`` query parameter into a ``Cursor``.

        ``None`` yields the default cursor. The decoded offset is clamped to
        ``[0, _offset_cutoff]``. Malformed input currently propagates as
        ``TypeError``/``ValueError``.
        """
        if encoded_cursor is None:
            return Cursor()

        try:
            # ``b64decode`` returns bytes; decode to text so that ``parse_qs``
            # produces str keys/values that match the lookups below.
            querystring = b64decode(encoded_cursor).decode("ascii")
            tokens = parse.parse_qs(querystring, keep_blank_values=True)

            offset = int(tokens.get("o", ["0"])[0])
            offset = _clamp(offset, 0, self._offset_cutoff)

            reverse = tokens.get("r", ["0"])[0]
            reverse = bool(int(reverse))

            position = tokens.get("p", [None])[0]
        except (TypeError, ValueError):
            # TODO what should be done here if someone passed an invalid cursor string?
            raise

        return Cursor(offset=offset, reverse=reverse, position=position)

    def _encode_cursor(self, cursor: Cursor, base_url: str) -> str:
        """
        Return ``base_url`` with its ``cursor`` parameter set to encode ``cursor``.

        Only non-default fields are written: ``o`` for a non-zero offset,
        ``r`` for a reverse cursor, ``p`` for a position.
        """
        tokens = {}
        if cursor.offset != 0:
            tokens["o"] = str(cursor.offset)
        if cursor.reverse:
            tokens["r"] = "1"
        if cursor.position is not None:
            tokens["p"] = cursor.position

        querystring = parse.urlencode(tokens, doseq=True)
        # ``b64encode`` works on bytes and returns bytes; round-trip through
        # ASCII so a plain str ends up in the URL.
        encoded = b64encode(querystring.encode("ascii")).decode("ascii")
        return _replace_query_param(base_url, "cursor", encoded)

    def next_link(
        self,
        base_url: str,
        page: list,
        cursor: Cursor,
        order: tuple,
        has_previous: bool,
        limit: int,
        next_position: str | None,
        previous_position: str | None,
    ) -> str:
        """
        Build the URL of the page that follows ``page``.

        Scans the page from its end for an item whose position differs from
        the one after it, so that the new cursor can use that position as a
        marker plus an offset counting the items sharing the marker.
        """
        if page and cursor.reverse and cursor.offset:
            # If we're reversing direction and we have an offset cursor
            # then we cannot use the first position we find as a marker.
            compare = self._get_position_from_instance(page[-1], order)
        else:
            compare = next_position
        offset = 0

        has_item_with_unique_position = False
        for item in reversed(page):
            position = self._get_position_from_instance(item, order)
            if position != compare:
                # The item in this position and the item following it
                # have different positions. We can use this position as
                # our marker.
                has_item_with_unique_position = True
                break

            # The item in this position has the same position as the item
            # following it, we can't use it as a marker position, so increment
            # the offset and keep seeking to the previous item.
            compare = position
            offset += 1

        if page and not has_item_with_unique_position:
            # There were no unique positions in the page.
            if not has_previous:
                # We are on the first page.
                # Our cursor will have an offset equal to the page size,
                # but no position to filter against yet.
                offset = limit
                position = None
            elif cursor.reverse:
                # The change in direction will introduce a paging artifact,
                # where we end up skipping forward a few extra items.
                offset = 0
                position = previous_position
            else:
                # Use the position from the existing cursor and increment
                # its offset by the page size.
                offset = cursor.offset + limit
                position = previous_position

        if not page:
            position = next_position

        next_cursor = Cursor(offset=offset, reverse=False, position=position)
        return self._encode_cursor(next_cursor, base_url)

    def previous_link(
        self,
        base_url: str,
        page: list,
        cursor: Cursor,
        order: tuple,
        has_next: bool,
        limit: int,
        next_position: str | None,
        previous_position: str | None,
    ):
        """
        Build the URL of the page that precedes ``page``.

        Mirror image of ``next_link``: scans the page from its start and
        produces a reverse cursor.
        """
        if page and not cursor.reverse and cursor.offset:
            # If we're reversing direction and we have an offset cursor
            # then we cannot use the first position we find as a marker.
            compare = self._get_position_from_instance(page[0], order)
        else:
            compare = previous_position
        offset = 0

        has_item_with_unique_position = False
        for item in page:
            position = self._get_position_from_instance(item, order)
            if position != compare:
                # The item in this position and the item following it
                # have different positions. We can use this position as
                # our marker.
                has_item_with_unique_position = True
                break

            # The item in this position has the same position as the item
            # following it, we can't use it as a marker position, so increment
            # the offset and keep seeking to the previous item.
            compare = position
            offset += 1

        if page and not has_item_with_unique_position:
            # There were no unique positions in the page.
            if not has_next:
                # We are on the final page.
                # Our cursor will have an offset equal to the page size,
                # but no position to filter against yet.
                offset = limit
                position = None
            elif cursor.reverse:
                # Use the position from the existing cursor and increment
                # its offset by the page size.
                offset = cursor.offset + limit
                position = next_position
            else:
                # The change in direction will introduce a paging artifact,
                # where we end up skipping back a few extra items.
                offset = 0
                position = next_position

        if not page:
            position = previous_position

        cursor = Cursor(offset=offset, reverse=True, position=position)
        return self._encode_cursor(cursor, base_url)

    def _get_position_from_instance(self, instance, ordering):
        """
        Return the string value of the leading ordering field on ``instance``.

        ``instance`` may be a dict (looked up by key) or any other object
        (looked up by attribute).
        """
        field_name = ordering[0].lstrip("-")
        if isinstance(instance, dict):
            attr = instance[field_name]
        else:
            attr = getattr(instance, field_name)
        return str(attr)
14 changes: 7 additions & 7 deletions isic/core/tests/test_api_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def collections(public_collection, private_collection):
def test_core_api_collection_list_permissions(client, colls, num_visible):
r = client.get("/api/v2/collections/")

assert r.status_code == 200, r.data
assert r.data["count"] == num_visible
assert r.status_code == 200, r.json()
assert r.json()["count"] == num_visible


@pytest.mark.django_db
Expand Down Expand Up @@ -63,10 +63,10 @@ def test_core_api_collection_detail_permissions(client, collection, visible):
r = client.get(f"/api/v2/collections/{collection.pk}/")

if visible:
assert r.status_code == 200, r.data
assert r.data["id"] == collection.id
assert r.status_code == 200, r.json()
assert r.json()["id"] == collection.id
else:
assert r.status_code == 404, r.data
assert r.status_code == 404, r.json()


@pytest.mark.django_db
Expand All @@ -87,7 +87,7 @@ def test_core_api_collection_populate_from_search(
f"/api/v2/collections/{collection.pk}/populate-from-search/", {"query": "sex:male"}
)

assert r.status_code == 202, r.data
assert r.status_code == 202, r.json()
assert collection.images.count() == 1
assert collection.images.first().accession.metadata["sex"] == "male"

Expand All @@ -106,7 +106,7 @@ def test_core_api_collection_modify_locked(endpoint, data, staff_client, collect

r = staff_client.post(f"/api/v2/collections/{collection.pk}/{endpoint}/", data)

assert r.status_code == 409, r.data
assert r.status_code == 409, r.json()


@pytest.mark.django_db
Expand Down

0 comments on commit 874b108

Please sign in to comment.