diff --git a/isic/core/api/image.py b/isic/core/api/image.py index cfa16f6f..4e6f4bea 100644 --- a/isic/core/api/image.py +++ b/isic/core/api/image.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any from django.conf import settings from django.http.request import HttpRequest @@ -9,11 +9,10 @@ from ninja.pagination import paginate from pyparsing.exceptions import ParseException -from isic.core.dsl import es_parser, parse_query from isic.core.models import Collection, Image -from isic.core.pagination import CursorPagination +from isic.core.pagination import CursorPagination, qs_with_hardcoded_count from isic.core.permissions import get_visible_objects -from isic.core.search import build_elasticsearch_query, facets +from isic.core.search import facets, get_elasticsearch_client from isic.core.serializers import SearchQueryIn router = Router() @@ -90,7 +89,17 @@ def resolve_metadata(image: Image) -> dict: ) @paginate(CursorPagination) def list_images(request: HttpRequest): - return get_visible_objects(request.user, "core.view_image", default_qs) + qs = get_visible_objects(request.user, "core.view_image", default_qs) + + if settings.ISIC_USE_ELASTICSEARCH_COUNTS: + es_query = SearchQueryIn().to_es_query(request.user) + es_count = get_elasticsearch_client().count( + index=settings.ISIC_ELASTICSEARCH_INDEX, + body={"query": es_query}, + )["count"] + return qs_with_hardcoded_count(qs, es_count) + + return qs @router.get( @@ -103,30 +112,33 @@ def list_images(request: HttpRequest): @paginate(CursorPagination) def search_images(request: HttpRequest, search: SearchQueryIn = Query(...)): try: - return search.to_queryset(user=request.user, qs=default_qs) + qs = search.to_queryset(user=request.user, qs=default_qs) + if settings.ISIC_USE_ELASTICSEARCH_COUNTS: + es_query = search.to_es_query(request.user) except ParseException as e: # Normally we'd like this to be handled by the input serializer validation, but # for backwards compatibility we must return 400 rather than 422. # The pagination wrapper means we can't just return the response we'd like from here. # The handler for this exception type is defined in urls.py. raise ImageSearchParseError from e + else: + if settings.ISIC_USE_ELASTICSEARCH_COUNTS: + es_count = get_elasticsearch_client().count( + index=settings.ISIC_ELASTICSEARCH_INDEX, + body={"query": es_query}, + )["count"] + return qs_with_hardcoded_count(qs, es_count) + + return qs @router.get("/facets/", response=dict, include_in_schema=False) def get_facets(request: HttpRequest, search: SearchQueryIn = Query(...)): - es_query: dict | None = None - if search.query: - try: - # we know it can't be a Q object because we're using es_parser - es_query = cast(dict | None, parse_query(es_parser, search.query)) - except ParseException as e: - raise ImageSearchParseError from e - - query = build_elasticsearch_query( - es_query or {}, - request.user, - search.collections, - ) + try: + query = search.to_es_query(request.user) + except ParseException as e: + raise ImageSearchParseError from e + # Manually pass the list of visible collection PKs through so buckets with # counts of 0 aren't included in the facets output for non-visible collections. collection_pks = list( diff --git a/isic/core/pagination.py b/isic/core/pagination.py index bb3e00c1..7b804248 100644 --- a/isic/core/pagination.py +++ b/isic/core/pagination.py @@ -10,6 +10,25 @@ from pydantic import field_validator +def qs_with_hardcoded_count(qs: QuerySet, count: int) -> QuerySet: + """ + Modify a queryset to return a hardcoded count rather than querying the database. + + This is useful when the count can be obtained with a cheaper method instead of + the default queryset.count() method, e.g. elasticsearch, a separate query with + fewer joins, etc. + """ + # This is an unfortunate bit of hackery to get around the fact that the CursorPagination class + # adds an order by which clones the queryset, overriding our hardcoded count. We have to repeat + # the logic here to make sure the paginator doesn't modify our queryset. + if not qs.query.order_by: + qs = qs.order_by(*CursorPagination.default_ordering) + + qs.count = lambda: count + + return qs + + @dataclass class Cursor: offset: int = 0 diff --git a/isic/core/serializers.py b/isic/core/serializers.py index 7108d17c..f21c40c8 100644 --- a/isic/core/serializers.py +++ b/isic/core/serializers.py @@ -1,14 +1,18 @@ from __future__ import annotations +from typing import cast + from django.contrib.auth.models import AnonymousUser, User from django.db.models.query import QuerySet from django.shortcuts import get_object_or_404 from ninja import Schema from pydantic import field_validator +from isic.core.dsl import es_parser, parse_query from isic.core.models import Image from isic.core.models.collection import Collection from isic.core.permissions import get_visible_objects +from isic.core.search import build_elasticsearch_query class SearchQueryIn(Schema): @@ -70,3 +74,15 @@ def to_queryset(self, user: User, qs: QuerySet[Image] | None = None) -> QuerySet ) return get_visible_objects(user, "core.view_image", qs).distinct() + + def to_es_query(self, user: User) -> dict: + es_query: dict | None = None + if self.query: + # we know it can't be a Q object because we're using es_parser and not django_parser + es_query = cast(dict | None, parse_query(es_parser, self.query)) + + return build_elasticsearch_query( + es_query or {}, + user, + self.collections, + ) diff --git a/isic/core/tests/test_search.py b/isic/core/tests/test_search.py index d5ad9f05..83f31942 100644 --- a/isic/core/tests/test_search.py +++ b/isic/core/tests/test_search.py @@ -1,3 +1,4 @@ +from django.urls import reverse from isic_metadata.fields import ImageTypeEnum import pytest from pytest_lazy_fixtures import lf @@ -31,6 +32,25 @@ def searchable_images(image_factory, _search_index): return images +@pytest.mark.django_db() +@pytest.mark.parametrize( + "route", + ["api:search_images", "api:list_images"], +) +def test_elasticsearch_counts(searchable_images, settings, client, route): + settings.ISIC_USE_ELASTICSEARCH_COUNTS = False + + r = client.get(reverse(route)) + assert r.status_code == 200, r.json() + assert r.json()["count"] == 1, r.json() + + settings.ISIC_USE_ELASTICSEARCH_COUNTS = True + + r = client.get(reverse(route)) + assert r.status_code == 200, r.json() + assert r.json()["count"] == 1, r.json() + + @pytest.fixture() def searchable_image_with_private_field(image_factory, _search_index): image = image_factory(public=True, accession__age=50) diff --git a/isic/settings.py b/isic/settings.py index b5d1c78f..499b9bb6 100644 --- a/isic/settings.py +++ b/isic/settings.py @@ -2,6 +2,7 @@ from pathlib import Path from botocore.config import Config +from celery.schedules import crontab from composed_configuration import ( ComposedConfiguration, ConfigMixin, @@ -142,6 +143,12 @@ def mutate_configuration(configuration: ComposedConfiguration) -> None: # https://github.com/noripyt/django-cachalot/issues/266. CACHALOT_FINAL_SQL_CHECK = True + # This is an unfortunate feature flag that lets us disable this feature in testing, + # where having a permanently available ES index which is updated consistently in real + # time is too difficult. We hedge by having tests that verify our counts are correct + # with both methods. + ISIC_USE_ELASTICSEARCH_COUNTS = values.BooleanValue(False) + ISIC_ELASTICSEARCH_URI = values.SecretValue() ISIC_ELASTICSEARCH_INDEX = "isic" ISIC_GUI_URL = "https://www.isic-archive.com/" @@ -179,7 +186,7 @@ def mutate_configuration(configuration: ComposedConfiguration) -> None: }, "sync-elasticsearch-index": { "task": "isic.core.tasks.sync_elasticsearch_index_task", - "schedule": timedelta(hours=12), + "schedule": crontab(minute="0", hour="0"), }, } @@ -273,6 +280,7 @@ def mutate_configuration(configuration: ComposedConfiguration): class HerokuProductionConfiguration(IsicMixin, HerokuProductionBaseConfiguration): ISIC_DATACITE_DOI_PREFIX = "10.34970" ISIC_ELASTICSEARCH_URI = values.SecretValue(environ_name="SEARCHBOX_URL", environ_prefix=None) + ISIC_USE_ELASTICSEARCH_COUNTS = True CACHES = CacheURLValue(environ_name="STACKHERO_REDIS_URL_TLS", environ_prefix=None)