From f7cc822f5774c2006ddc2aeb69cb66ffd57612e1 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Thu, 11 Apr 2024 15:36:02 +0300 Subject: [PATCH] Fix DistinctCount for timestamps with time zone (#10043) * Fix DistinctCount for timestamps with time zone Preserve the original data type in the aggregation state * Add tests for decimal count distinct --- .../src/aggregate/count_distinct/mod.rs | 42 +++++++++++-------- .../src/aggregate/count_distinct/native.rs | 15 +++++-- .../sqllogictest/test_files/aggregate.slt | 37 +++++++++++++--- .../sqllogictest/test_files/decimal.slt | 11 +++++ 4 files changed, 79 insertions(+), 26 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 9c5605f495ea..ee63945eb249 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -109,12 +109,14 @@ impl AggregateExpr for DistinctCount { UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()), UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()), UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()), - Decimal128(_, _) => { - Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new()) - } - Decimal256(_, _) => { - Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new()) - } + dt @ Decimal128(_, _) => Box::new( + PrimitiveDistinctCountAccumulator::<Decimal128Type>::new() + .with_data_type(dt.clone()), + ), + dt @ Decimal256(_, _) => Box::new( + PrimitiveDistinctCountAccumulator::<Decimal256Type>::new() + .with_data_type(dt.clone()), + ), Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()), Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()), @@ -130,18 +132,22 @@ impl AggregateExpr for DistinctCount { Time64(Nanosecond) => { Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new()) } - Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMicrosecondType, - >::new()), - 
Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMillisecondType, - >::new()), - Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampNanosecondType, - >::new()), - Timestamp(Second, _) => { - Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new()) - } + dt @ Timestamp(Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new() + .with_data_type(dt.clone()), + ), + dt @ Timestamp(Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new() + .with_data_type(dt.clone()), + ), + dt @ Timestamp(Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new() + .with_data_type(dt.clone()), + ), + dt @ Timestamp(Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new() + .with_data_type(dt.clone()), + ), Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()), Float32 => Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()), diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs index a44e8b772e5a..8f3ce8acfe07 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs @@ -30,6 +30,7 @@ use ahash::RandomState; use arrow::array::ArrayRef; use arrow_array::types::ArrowPrimitiveType; use arrow_array::PrimitiveArray; +use arrow_schema::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::utils::array_into_list_array; @@ -45,6 +46,7 @@ where T::Native: Eq + Hash, { values: HashSet<T::Native, RandomState>, + data_type: DataType, } impl<T> PrimitiveDistinctCountAccumulator<T> @@ -55,8 +57,14 @@ where pub(super) fn new() -> Self { Self { values: HashSet::default(), + data_type: T::DATA_TYPE, } } + + pub(super) fn with_data_type(mut self, data_type: DataType) -> Self { + self.data_type = data_type; + self + } } impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T> @@ -65,9 +73,10 
@@ where T::Native: Eq + Hash, { fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> { - let arr = Arc::new(PrimitiveArray::<T>::from_iter_values( - self.values.iter().cloned(), - )) as ArrayRef; + let arr = Arc::new( + PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned()) + .with_data_type(self.data_type.clone()), + ); let list = Arc::new(array_into_list_array(arr)); Ok(vec![ScalarValue::List(list)]) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 4929ab485d6d..966236db2732 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1876,18 +1876,22 @@ select arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros, arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis, arrow_cast(column1, 'Timestamp(Second, None)') as secs, + arrow_cast(column1, 'Timestamp(Nanosecond, Some("UTC"))') as nanos_utc, + arrow_cast(column1, 'Timestamp(Microsecond, Some("UTC"))') as micros_utc, + arrow_cast(column1, 'Timestamp(Millisecond, Some("UTC"))') as millis_utc, + arrow_cast(column1, 'Timestamp(Second, Some("UTC"))') as secs_utc, column2 as names, column3 as tag from t_source; # Demonstate the contents -query PPPPTT +query PPPPPPPPTT select * from t; ---- -2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X -2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 X -NULL NULL NULL NULL Row 2 Y -2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 Y +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 X +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 2011-12-13T11:13:10.123450Z 
2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 X +NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 Y # aggregate_timestamps_sum @@ -1933,6 +1937,17 @@ SELECT tag, max(nanos), max(micros), max(millis), max(secs) FROM t GROUP BY tag X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 +# aggregate_timestamps_count_distinct_with_tz +query IIII +SELECT count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t; +---- +3 3 3 3 + +query TIIII +SELECT tag, count(DISTINCT nanos_utc), count(DISTINCT micros_utc), count(DISTINCT millis_utc), count(DISTINCT secs_utc) FROM t GROUP BY tag ORDER BY tag; +---- +X 2 2 2 2 +Y 1 1 1 1 # aggregate_timestamps_avg statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. 
@@ -2285,6 +2300,18 @@ select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2 A 110.0045 Decimal128(14, 7) B -100.0045 Decimal128(14, 7) +# aggregate_decimal_count_distinct +query I +select count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table +---- +4 + +query TI +select c2, count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table GROUP BY c2 ORDER BY c2 +---- +A 2 +B 2 + # Use PostgresSQL dialect statement ok set datafusion.sql_parser.dialect = 'Postgres'; diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index c220a5fc9a52..3f75e42d9304 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -720,5 +720,16 @@ select count(*),c1 from decimal256_simple group by c1 order by c1; 4 0.00004 5 0.00005 +query I +select count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple; +---- +2 + +query BI +select c4, count(DISTINCT cast(c1 AS DECIMAL(42, 4))) from decimal256_simple GROUP BY c4 ORDER BY c4; +---- +false 2 +true 2 + statement ok drop table decimal256_simple;