Remove common_timestamp function from postgis driver #1623

Merged · 11 commits · Sep 20, 2024
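The change drops the common_timestamp(text) SQL helper from the postgis driver and, as the diff suggests, moves timestamp handling to the Python side: search results are now decoded row by row through the field definitions (see the decode_row helper added to search_datasets below) rather than returned as raw SQLAlchemy Row objects. Purely as an illustration of that pattern, here is a minimal, self-contained sketch; it is not driver code, and the Field and DateField classes and sample values are hypothetical stand-ins:

from datetime import datetime, timezone
from typing import Any, Iterable, Sequence


class Field:
    """Hypothetical stand-in for a datacube field definition."""
    def __init__(self, name: str):
        self.name = name

    def normalise_value(self, value: Any) -> Any:
        return value


class DateField(Field):
    def normalise_value(self, value: Any) -> Any:
        # Attach UTC in Python instead of casting via a common_timestamp() SQL function.
        if isinstance(value, datetime) and value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value


def decode_row(raw: Iterable[Any], fields: Sequence[Field]) -> dict[str, Any]:
    # Same shape as the decode_row helper in search_datasets: zip raw column
    # values with their field definitions and let each field normalise itself.
    return {f.name: f.normalise_value(v) for v, f in zip(raw, fields)}


fields = [Field("id"), DateField("indexed_time")]
print(decode_row(("b7d0263e", datetime(2024, 9, 20, 1, 2, 3)), fields))
# {'id': 'b7d0263e', 'indexed_time': datetime.datetime(2024, 9, 20, 1, 2, 3, tzinfo=datetime.timezone.utc)}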
187 changes: 112 additions & 75 deletions datacube/drivers/postgis/_api.py
@@ -24,7 +24,6 @@
from sqlalchemy import select, text, and_, or_, func
from sqlalchemy.dialects.postgresql import INTERVAL
from sqlalchemy.exc import IntegrityError
from sqlalchemy.engine import Row

from typing import Iterable, Sequence, Optional, Set, Any
from typing import cast as type_cast
@@ -50,22 +49,40 @@

# Make a function because it's broken
def _dataset_select_fields() -> tuple:
return tuple(f.alchemy_expression for f in _dataset_fields())


def _dataset_fields() -> tuple:
native_flds = get_native_fields()
return (
Dataset,
# All active URIs, from newest to oldest
func.array(
select(
SelectedDatasetLocation.uri
).where(
and_(
SelectedDatasetLocation.dataset_ref == Dataset.id,
SelectedDatasetLocation.archived == None
)
).order_by(
SelectedDatasetLocation.added.desc(),
SelectedDatasetLocation.id.desc()
).label('uris')
).label('uris')
native_flds["id"],
native_flds["indexed_time"],
native_flds["indexed_by"],
native_flds["product_id"],
native_flds["metadata_type_id"],
native_flds["metadata_doc"],
NativeField(
'archived',
'Archived date',
Dataset.archived
),
NativeField("uris",
"all uris",
func.array(
select(
SelectedDatasetLocation.uri
).where(
and_(
SelectedDatasetLocation.dataset_ref == Dataset.id,
SelectedDatasetLocation.archived == None
)
).order_by(
SelectedDatasetLocation.added.desc(),
SelectedDatasetLocation.id.desc()
).label('uris')
),
alchemy_table=Dataset.__table__ # type: ignore[attr-defined]
)
)


@@ -230,6 +247,29 @@ def extract_dataset_fields(ds_metadata, fields):
return result


# Min/Max aggregating time fields for temporal_extent methods
time_min = DateDocField('acquisition_time_min',
'Min of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:start_datetime'],
['properties', 'datetime']
],
selection='least')


time_max = DateDocField('acquisition_time_max',
'Max of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:end_datetime'],
['properties', 'datetime']
],
selection='greatest')


class PostgisDbAPI:
def __init__(self, parentdb, connection):
self._db = parentdb
@@ -476,19 +516,6 @@ def all_dataset_ids(self, archived: bool | None = False):
)
return self._connection.execute(query).fetchall()

# Not currently implemented.
# def insert_dataset_source(self, classifier, dataset_id, source_dataset_id):
# r = self._connection.execute(
# insert(DatasetSource).on_conflict_do_nothing(
# index_elements=['classifier', 'dataset_ref']
# ).values(
# classifier=classifier,
# dataset_ref=dataset_id,
# source_dataset_ref=source_dataset_id
# )
# )
# return r.rowcount > 0

def archive_dataset(self, dataset_id):
r = self._connection.execute(
update(Dataset).where(
@@ -548,10 +575,10 @@ def get_datasets(self, dataset_ids):
).fetchall()

def get_derived_datasets(self, dataset_id):
raise NotImplementedError
raise NotImplementedError()

def get_dataset_sources(self, dataset_id):
raise NotImplementedError
raise NotImplementedError()

def search_datasets_by_metadata(self, metadata, archived):
"""
@@ -621,14 +648,13 @@ def search_datasets_query(self,
assert source_exprs is None
assert not with_source_ids

if select_fields:
select_columns = tuple(
f.alchemy_expression.label(f.name)
for f in select_fields
)
else:
select_columns = _dataset_select_fields()
if not select_fields:
select_fields = _dataset_fields()

select_columns = tuple(
f.alchemy_expression.label(f.name)
for f in select_fields
)
if geom:
SpatialIndex, spatialquery = self.geospatial_query(geom)
else:
@@ -663,12 +689,21 @@ def search_datasets(self, expressions,
:type with_source_ids: bool
:type select_fields: tuple[datacube.drivers.postgis._fields.PgField]
:type expressions: tuple[datacube.drivers.postgis._fields.PgExpression]

:return: An iterable of tuples of decoded values
"""
if select_fields is None:
select_fields = _dataset_fields()
select_query = self.search_datasets_query(expressions, source_exprs,
select_fields, with_source_ids,
limit, geom=geom, archived=archived)
_LOG.debug("search_datasets SQL: %s", str(select_query))
return self._connection.execute(select_query)

def decode_row(raw: Iterable[Any]) -> dict[str, Any]:
return {f.name: f.normalise_value(r) for r, f in zip(raw, select_fields)}

for row in self._connection.execute(select_query):
yield decode_row(row)

def bulk_simple_dataset_search(self, products=None, batch_size=0):
"""
@@ -690,7 +725,7 @@ def bulk_simple_dataset_search(self, products=None, batch_size=0):
query = select(
*_dataset_bulk_select_fields()
).select_from(Dataset).where(
Dataset.archived == None
Dataset.archived.is_(None)
)
if products:
query = query.where(Dataset.product_ref.in_(products))
@@ -733,10 +768,12 @@ def insert_lineage_bulk(self, values):
)
return res.rowcount, requested - res.rowcount

def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]) -> Iterable[Row]:
def get_duplicates(self,
match_fields: Sequence[PgField],
expressions: Sequence[PgExpression]) -> Iterable[dict[str, Any]]:
# TODO
if "time" in [f.name for f in match_fields]:
return self.get_duplicates_with_time(match_fields, expressions)
yield from self.get_duplicates_with_time(match_fields, expressions)

group_expressions = tuple(f.alchemy_expression for f in match_fields)
join_tables = PostgisDbAPI._join_tables(expressions, match_fields)
@@ -749,43 +786,47 @@ def get_duplicates(self, match_fields: Sequence[PgField], expressions: Sequence[
query = query.join(*joins)

query = query.where(
and_(Dataset.archived == None, *(PostgisDbAPI._alchemify_expressions(expressions)))
and_(Dataset.archived.is_(None), *(PostgisDbAPI._alchemify_expressions(expressions)))
).group_by(
*group_expressions
).having(
func.count(Dataset.id) > 1
)
return self._connection.execute(query)
for row in self._connection.execute(query):
drow = {"ids": row.ids}
for f in match_fields:
drow[f.name] = getattr(row, f.name)
yield drow

def get_duplicates_with_time(
self, match_fields: Sequence[PgField], expressions: Sequence[PgExpression]
) -> Iterable[Row]:
) -> Iterable[dict[str, Any]]:
fields = []
for f in match_fields:
if f.name == "time":
time_field = type_cast(DateRangeDocField, f).expression_with_leniency
for fld in match_fields:
if fld.name == "time":
time_field = type_cast(DateRangeDocField, fld)
else:
fields.append(f.alchemy_expression)
fields.append(fld.alchemy_expression)

join_tables = PostgisDbAPI._join_tables(expressions, match_fields)

cols = [Dataset.id, time_field.label('time'), *fields]
cols = [Dataset.id, time_field.expression_with_leniency.label('time'), *fields]
query = select(
*cols
).select_from(Dataset)
for joins in join_tables:
query = query.join(*joins)

query = query.where(
and_(Dataset.archived == None, *(PostgisDbAPI._alchemify_expressions(expressions)))
and_(Dataset.archived.is_(None), *(PostgisDbAPI._alchemify_expressions(expressions)))
)

t1 = query.alias("t1")
t2 = query.alias("t2")

time_overlap = select(
t1.c.id,
text("t1.time * t2.time as time_intersect"),
t1.c.time.intersection(t2.c.time).label('time_intersect'),
*fields
).select_from(
t1.join(
@@ -797,15 +838,24 @@ def get_duplicates_with_time(
query = select(
func.array_agg(func.distinct(time_overlap.c.id)).label("ids"),
*fields, # type: ignore[arg-type]
text("(lower(time_intersect) at time zone 'UTC', upper(time_intersect) at time zone 'UTC') as time")
text("time_intersect as time")
).select_from(
time_overlap # type: ignore[arg-type]
).group_by(
*fields, text("time_intersect")
).having(
func.count(time_overlap.c.id) > 1
)
return self._connection.execute(query)

for row in self._connection.execute(query):
# TODO: Use decode_rows above - would require creating a field class for the ids array.
drow: dict[str, Any] = {
"ids": row.ids,
}
for f in fields:
drow[f.key] = getattr(row, f.key) # type: ignore[union-attr]
drow["time"] = time_field.normalise_value((row.time.lower, row.time.upper))
yield drow

def count_datasets(self, expressions, archived: bool | None = False, geom: Geometry | None = None):
"""
@@ -1474,33 +1524,20 @@ def remove_lineage_relations(self,
def temporal_extent_by_prod(self, product_id: int) -> tuple[datetime.datetime, datetime.datetime]:
query = self.temporal_extent_full().where(Dataset.product_ref == product_id)
res = self._connection.execute(query)
return res.first()
for tmin, tmax in res:
return (self.time_min.normalise_value(tmin), self.time_max.normalise_value(tmax))
raise RuntimeError("Product has no datasets and therefore no temporal extent")

def temporal_extent_by_ids(self, ids: Iterable[DSID]) -> tuple[datetime.datetime, datetime.datetime]:
query = self.temporal_extent_full().where(Dataset.id.in_(ids))
res = self._connection.execute(query)
return res.first()
for tmin, tmax in res:
return (self.time_min.normalise_value(tmin), self.time_max.normalise_value(tmax))
raise ValueError("no dataset ids provided")

def temporal_extent_full(self) -> Select:
# Hardcode eo3 standard time locations - do not use this approach in a legacy index driver.
time_min = DateDocField('aquisition_time_min',
'Min of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:start_datetime'],
['properties', 'datetime']
],
selection='least')
time_max = DateDocField('aquisition_time_max',
'Max of time when dataset was acquired',
Dataset.metadata_doc,
False, # is it indexed
offset=[
['properties', 'dtr:end_datetime'],
['properties', 'datetime']
],
selection='greatest')

return select(
func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)
func.min(self.time_min.alchemy_expression), func.max(self.time_max.alchemy_expression)
)
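The temporal-extent methods above now reuse the time_min and time_max DateDocFields defined near the top of the file (hardcoded EO3 offsets) instead of rebuilding them inside temporal_extent_full() on every call, and normalise the resulting bounds in Python. As a rough illustration of what those offset and selection settings mean (my reading of selection='least'/'greatest' as SQL LEAST/GREATEST over whichever offsets resolve; this is not driver code), the equivalent extraction over an EO3 metadata document looks something like:

from datetime import datetime


def _get(doc: dict, path: list[str]):
    # Walk a nested offset such as ['properties', 'dtr:start_datetime'].
    for key in path:
        if not isinstance(doc, dict) or key not in doc:
            return None
        doc = doc[key]
    return doc


def acquisition_bounds(metadata_doc: dict) -> tuple[datetime, datetime]:
    start_offsets = [['properties', 'dtr:start_datetime'], ['properties', 'datetime']]
    end_offsets = [['properties', 'dtr:end_datetime'], ['properties', 'datetime']]
    starts = [datetime.fromisoformat(v) for p in start_offsets if (v := _get(metadata_doc, p))]
    ends = [datetime.fromisoformat(v) for p in end_offsets if (v := _get(metadata_doc, p))]
    return min(starts), max(ends)  # least / greatest over the offsets that resolved


doc = {'properties': {'datetime': '2024-09-20T01:02:03',
                      'dtr:start_datetime': '2024-09-20T01:00:00'}}
print(acquisition_bounds(doc))
# (datetime.datetime(2024, 9, 20, 1, 0), datetime.datetime(2024, 9, 20, 1, 2, 3))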
1 change: 0 additions & 1 deletion datacube/drivers/postgis/_core.py
@@ -128,7 +128,6 @@ def ensure_db(engine, with_permissions=True):
c.execute(text(f"""
grant usage on schema {SCHEMA_NAME} to odc_user;
grant select on all tables in schema {SCHEMA_NAME} to odc_user;
grant execute on function {SCHEMA_NAME}.common_timestamp(text) to odc_user;

grant insert on {SCHEMA_NAME}.dataset,
{SCHEMA_NAME}.location,
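In _core.py the only change shown is dropping the GRANT EXECUTE on {SCHEMA_NAME}.common_timestamp(text), since, per the PR title, the driver no longer creates that function. A quick, hypothetical sanity check against a freshly initialised database (not part of the PR; it assumes the driver schema is named odc and uses plain SQLAlchemy):

from sqlalchemy import create_engine, text

# Placeholder DSN; substitute real credentials.
engine = create_engine("postgresql+psycopg2://odc_admin:secret@localhost/datacube")
with engine.connect() as conn:
    remaining = conn.execute(
        text(
            "SELECT count(*) FROM pg_proc p "
            "JOIN pg_namespace n ON n.oid = p.pronamespace "
            "WHERE n.nspname = :schema AND p.proname = 'common_timestamp'"
        ),
        {"schema": "odc"},
    ).scalar()
    assert remaining == 0, "common_timestamp() should no longer be defined in the schema"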