diff --git a/isic/core/dsl.py b/isic/core/dsl.py
index 31eee327..c3a1298f 100644
--- a/isic/core/dsl.py
+++ b/isic/core/dsl.py
@@ -1,8 +1,10 @@
 from __future__ import annotations

+from dataclasses import dataclass
+
 from django.db.models.query_utils import Q
 from isic_metadata import FIELD_REGISTRY
-from pyparsing import Keyword, ParserElement, Word, alphas, infixNotation, nums, opAssoc
+from pyparsing import Keyword, Optional, ParserElement, Word, alphas, infixNotation, nums, opAssoc
 from pyparsing.common import pyparsing_common
 from pyparsing.core import Literal, OneOrMore, Or, QuotedString, Suppress
 from pyparsing.helpers import one_of
@@ -11,33 +13,48 @@
 ParserElement.enablePackrat()


+@dataclass(frozen=True)
+class SearchTermKey:
+    field_lookup: str
+    negated: bool = False
+
+
 class Value:
-    def to_q(self, key):
-        return Q(**{key: self.value})
+    def to_q(self, key: SearchTermKey) -> Q:
+        if self.value == "*":
+            return Q(**{f"{key.field_lookup}__isnull": False}, _negated=key.negated)
+        else:
+            return Q(**{key.field_lookup: self.value}, _negated=key.negated)


 class BoolValue(Value):
     def __init__(self, toks) -> None:
-        self.value = True if toks[0] == "true" else False
+        if toks[0] == "*":
+            self.value = "*"
+        else:
+            self.value = True if toks[0] == "true" else False


 class StrValue(Value):
     def __init__(self, toks) -> None:
         self.value = toks[0]

-    def to_q(self, key):
+    def to_q(self, key: SearchTermKey) -> Q:
         # Special casing for image type renaming, see
         # https://linear.app/isic/issue/ISIC-138#comment-93029f64
         # TODO: Remove this once better error messages are put in place.
-        if key == "accession__metadata__image_type" and self.value == "clinical":
+        if key.field_lookup == "accession__metadata__image_type" and self.value == "clinical":
             self.value = "clinical: close-up"
-        elif key == "accession__metadata__image_type" and self.value == "overview":
+        elif key.field_lookup == "accession__metadata__image_type" and self.value == "overview":
             self.value = "clinical: overview"

+        # a bare asterisk matches any present (non-null) value
+        if self.value == "*":
+            return Q(**{f"{key.field_lookup}__isnull": False}, _negated=key.negated)
         if self.value.startswith("*"):
-            return Q(**{f"{key}__endswith": self.value[1:]})
+            return Q(**{f"{key.field_lookup}__endswith": self.value[1:]}, _negated=key.negated)
         elif self.value.endswith("*"):
-            return Q(**{f"{key}__startswith": self.value[:-1]})
+            return Q(**{f"{key.field_lookup}__startswith": self.value[:-1]}, _negated=key.negated)
         else:
             return super().to_q(key)

@@ -53,10 +70,13 @@ def __init__(self, toks) -> None:
         self.upper_lookup = "lte" if toks[-1] == "]" else "lt"
         self.value = (toks[1].value, toks[2].value)

-    def to_q(self, key):
-        start_key, end_key = f"{key}__{self.lower_lookup}", f"{key}__{self.upper_lookup}"
+    def to_q(self, key: SearchTermKey) -> Q:
+        start_key, end_key = (
+            f"{key.field_lookup}__{self.lower_lookup}",
+            f"{key.field_lookup}__{self.upper_lookup}",
+        )
         start_value, end_value = self.value
-        return Q(**{start_key: start_value}, **{end_key: end_value})
+        return Q(**{start_key: start_value}, **{end_key: end_value}, _negated=key.negated)


 def q(s, loc, toks):
@@ -97,36 +117,50 @@ def q_or(s, loc, toks):
 AND = Suppress(Keyword("AND"))
 OR = Suppress(Keyword("OR"))

+# Note that the Lucene DSL treats a single asterisk as a check for whether
+# the field exists and has a non-null value.
+# https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#query-string-wildcard
+EXISTS = Literal("*")
+
 # asterisks for wildcard, _ for ISIC ID search, - for license types
 str_value = (Word(alphas + nums + "*" + "_" + "-") | QuotedString('"')).add_parse_action(StrValue)
 number_value = pyparsing_common.number.add_parse_action(NumberValue)
 number_range_value = (
     one_of("[ {") + number_value + Suppress(Literal("TO")) + number_value + one_of("] }")
 ).add_parse_action(NumberRangeValue)
-bool_value = one_of("true false").add_parse_action(BoolValue)
+bool_value = one_of("true false *").add_parse_action(BoolValue)


 def convert_term(s, loc, toks):
+    negate = False
+
+    if len(toks) == 2 and toks[0] == "-":
+        negate = True
+        toks = toks[1:]
+
+    if len(toks) > 1:
+        raise Exception("Something went wrong")
+
     if toks[0] == "public":
-        return toks[0]
+        return SearchTermKey(toks[0], negate)
     elif toks[0] == "isic_id":
         # isic_id can't be used with wildcards since it's a foreign key, so join the table and
         # refer to the __id.
-        return "isic__id"
+        return SearchTermKey("isic__id", negate)
     elif toks[0] == "lesion_id":
-        return "accession__lesion__id"
+        return SearchTermKey("accession__lesion__id", negate)
     elif toks[0] == "patient_id":
-        return "accession__patient__id"
+        return SearchTermKey("accession__patient__id", negate)
     elif toks[0] == "age_approx":
-        return "accession__metadata__age__approx"
+        return SearchTermKey("accession__metadata__age__approx", negate)
     elif toks[0] == "copyright_license":
-        return "accession__copyright_license"
+        return SearchTermKey("accession__copyright_license", negate)
     else:
-        return f"accession__metadata__{toks[0]}"
+        return SearchTermKey(f"accession__metadata__{toks[0]}", negate)


 def make_term_keyword(name):
-    return Keyword(name).add_parse_action(convert_term)
+    return (Optional("-") + Keyword(name)).add_parse_action(convert_term)


 def make_term(name, values):
diff --git a/isic/core/search.py b/isic/core/search.py
index 2180cc96..6ffa5955 100644
--- a/isic/core/search.py
+++ b/isic/core/search.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from functools import lru_cache, partial
 import logging

@@ -18,6 +19,7 @@
 INDEX_MAPPINGS = {"properties": {}}

 DEFAULT_SEARCH_AGGREGATES = {}
+COUNTS_AGGREGATES = {}

 # TODO: include private meta fields (e.g. patient/lesion id)
 for key, definition in FIELD_REGISTRY.items():
@@ -52,6 +54,9 @@
         "extended_bounds": {"min": 0, "max": 85},
     }
 }
+for key, _ in DEFAULT_SEARCH_AGGREGATES.items():
+    COUNTS_AGGREGATES[f"{key}_missing"] = {"missing": {"field": key}}
+    COUNTS_AGGREGATES[f"{key}_present"] = {"value_count": {"field": key}}


 # These are all approaching 10 unique values, which would require passing a size attribute
@@ -116,23 +121,52 @@ def bulk_add_to_search_index(qs: QuerySet[Image], chunk_size: int = 2000) -> Non


 def facets(query: dict | None = None, collections: list[int] | None = None) -> dict:
-    body = {
+    """
+    Generate the facet counts for a given query.
+
+    This requires two elasticsearch queries: one to compute the present/missing
+    counts for each facet, and another to generate the buckets themselves.
+ """ + counts_body = { "size": 0, - "aggs": DEFAULT_SEARCH_AGGREGATES, + "aggs": COUNTS_AGGREGATES, } + if query: + counts_body["query"] = query + + counts = get_elasticsearch_client().search( + index=settings.ISIC_ELASTICSEARCH_INDEX, + body=counts_body, + )["aggregations"] + + facets_body = { + "size": 0, + "aggs": deepcopy(DEFAULT_SEARCH_AGGREGATES), + } + + # pass the counts through as metadata in the final aggregation query + # https://www.elastic.co/guide/en/elasticsearch/reference/8.10/search-aggregations.html#add-metadata-to-an-agg + for field in facets_body["aggs"]: + facets_body["aggs"][field]["meta"] = { + "missing_count": counts[f"{field}_missing"]["doc_count"], + "present_count": counts[f"{field}_present"]["value"], + } + if collections is not None: # Note this include statement means we can only filter by ~65k collections. See: # "By default, Elasticsearch limits the terms query to a maximum of 65,536 terms. # You can change this limit using the index.max_terms_count setting." - body["aggs"]["collections"] = {"terms": {"field": "collections", "include": collections}} + facets_body["aggs"]["collections"] = { + "terms": {"field": "collections", "include": collections} + } if query: - body["query"] = query + facets_body["query"] = query - return get_elasticsearch_client().search(index=settings.ISIC_ELASTICSEARCH_INDEX, body=body)[ - "aggregations" - ] + return get_elasticsearch_client().search( + index=settings.ISIC_ELASTICSEARCH_INDEX, body=facets_body + )["aggregations"] def build_elasticsearch_query( diff --git a/isic/core/tests/test_dsl.py b/isic/core/tests/test_dsl.py index af9c479e..c940a8f7 100644 --- a/isic/core/tests/test_dsl.py +++ b/isic/core/tests/test_dsl.py @@ -10,6 +10,34 @@ [ # test isic_id especially due to the weirdness of the foreign key ["isic_id:ISIC_123*", Q(isic__id__startswith="ISIC_123")], + # test negation and present/missing values + ["-isic_id:*", ~Q(isic__id__isnull=False)], + ["-lesion_id:*", ~Q(accession__lesion__id__isnull=False)], + [ + "-diagnosis:* OR diagnosis:foobar", + ~Q(accession__metadata__diagnosis__isnull=False) + | Q(accession__metadata__diagnosis="foobar"), + ], + ["age_approx:[50 TO *]", ParseException], + ["-melanocytic:*", ~Q(accession__metadata__melanocytic__isnull=False)], + ["melanocytic:*", Q(accession__metadata__melanocytic__isnull=False)], + ["-age_approx:50", ~Q(accession__metadata__age__approx=50)], + ["-diagnosis:foo*", ~Q(accession__metadata__diagnosis__startswith="foo")], + [ + "-age_approx:[50 TO 70]", + ~Q(accession__metadata__age__approx__gte=50, accession__metadata__age__approx__lte=70), + ], + [ + "-diagnosis:foobar OR (diagnosis:foobaz AND (-diagnosis:foo* OR age_approx:50))", + ~Q(accession__metadata__diagnosis="foobar") + | ( + Q(accession__metadata__diagnosis="foobaz") + & ( + ~Q(accession__metadata__diagnosis__startswith="foo") + | Q(accession__metadata__age__approx=50) + ) + ), + ], ["isic_id:*123", Q(isic__id__endswith="123")], ["lesion_id:IL_123*", Q(accession__lesion__id__startswith="IL_123")], ["lesion_id:*123", Q(accession__lesion__id__endswith="123")], diff --git a/isic/core/tests/test_search.py b/isic/core/tests/test_search.py index 712cf813..7f483603 100644 --- a/isic/core/tests/test_search.py +++ b/isic/core/tests/test_search.py @@ -246,6 +246,16 @@ def test_core_api_image_faceting(private_and_public_images_collections, client_) assert buckets[0] == {"key": public_coll.pk, "doc_count": 1}, buckets +@pytest.mark.django_db +def test_core_api_image_faceting_structure(searchable_images, client): + r = 
+        "/api/v2/images/facets/",
+    )
+    assert r.status_code == 200, r.json()
+    assert len(r.json()["diagnosis"]["buckets"]) == 1, r.json()
+    assert r.json()["diagnosis"]["meta"] == {"missing_count": 0, "present_count": 1}, r.json()
+
+
 @pytest.mark.parametrize(
     "client_",
     [
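
A minimal sketch of how a negated "exists" term is expected to flow through the new DSL pieces, using only names this diff introduces (SearchTermKey, StrValue) and bypassing the pyparsing grammar; the plain list stands in for a pyparsing token result, so this is an illustration rather than a test from the change itself:

    from django.db.models import Q

    from isic.core.dsl import SearchTermKey, StrValue

    # "-diagnosis:*" parses into a negated key plus a bare-asterisk value.
    key = SearchTermKey("accession__metadata__diagnosis", negated=True)
    value = StrValue(["*"])  # stand-in for pyparsing toks; only toks[0] is read

    # to_q() turns the asterisk into an "is present" check and applies the negation.
    # Q(..., _negated=True) compares equal to ~Q(...), which is what the tests rely on.
    assert value.to_q(key) == ~Q(accession__metadata__diagnosis__isnull=False)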
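
For the facets() change, the first of the two elasticsearch requests only asks for per-field present/missing counts. A sketch of the request body COUNTS_AGGREGATES yields for a single field (the field name "diagnosis" is illustrative, and the numbers in the response are made up), along with the response fields the code reads back ("doc_count" from the missing aggregation, "value" from the value_count aggregation):

    # First round trip: size 0 means no hits are returned, only aggregations.
    counts_body = {
        "size": 0,
        "aggs": {
            "diagnosis_missing": {"missing": {"field": "diagnosis"}},
            "diagnosis_present": {"value_count": {"field": "diagnosis"}},
        },
    }

    # Shape of the "aggregations" section of the response that facets() consumes,
    # and how it is folded into the second request's per-aggregation meta block.
    counts = {
        "diagnosis_missing": {"doc_count": 0},
        "diagnosis_present": {"value": 1},
    }
    meta = {
        "missing_count": counts["diagnosis_missing"]["doc_count"],
        "present_count": counts["diagnosis_present"]["value"],
    }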
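
Because those counts ride along as aggregation metadata on the second request, each facet in the API response carries its buckets alongside a meta block; this is the shape the new test_core_api_image_faceting_structure asserts. An illustrative fragment, with example bucket values only:

    # Each facet is keyed by field name; the terms buckets come from elasticsearch,
    # while the meta block holds the counts computed in the first request.
    facet_fragment = {
        "diagnosis": {
            "buckets": [{"key": "melanoma", "doc_count": 1}],
            "meta": {"missing_count": 0, "present_count": 1},
        }
    }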