From a0ad376840daac8fdfecee5a4988c585350c629b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20Heres?=
Date: Fri, 2 Aug 2024 02:47:27 +0200
Subject: [PATCH] [Minor] Refactor approx_percentile (#11769)

* Refactor approx_percentile

* Refactor approx_percentile

* Types

* Types

* Types
---
 .../functions-aggregate/src/approx_median.rs  |  2 +-
 .../src/approx_percentile_cont.rs             |  8 +--
 .../src/aggregate/tdigest.rs                  | 62 +++++++++++--------
 3 files changed, 41 insertions(+), 31 deletions(-)

diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs
index e12e3445a83e..c386ad89f0fb 100644
--- a/datafusion/functions-aggregate/src/approx_median.rs
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -78,7 +78,7 @@ impl AggregateUDFImpl for ApproxMedian {
         Ok(vec![
             Field::new(format_state_name(args.name, "max_size"), UInt64, false),
             Field::new(format_state_name(args.name, "sum"), Float64, false),
-            Field::new(format_state_name(args.name, "count"), Float64, false),
+            Field::new(format_state_name(args.name, "count"), UInt64, false),
             Field::new(format_state_name(args.name, "max"), Float64, false),
             Field::new(format_state_name(args.name, "min"), Float64, false),
             Field::new_list(
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index 844e48f0a44d..af2a26fd05ec 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -214,7 +214,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
             ),
             Field::new(
                 format_state_name(args.name, "count"),
-                DataType::Float64,
+                DataType::UInt64,
                 false,
             ),
             Field::new(
@@ -406,7 +406,7 @@ impl Accumulator for ApproxPercentileAccumulator {
     }
 
     fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
-        if self.digest.count() == 0.0 {
+        if self.digest.count() == 0 {
             return ScalarValue::try_from(self.return_type.clone());
         }
         let q = self.digest.estimate_quantile(self.percentile);
@@ -487,8 +487,8 @@ mod tests {
             ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);
 
         accumulator.merge_digests(&[t1]);
-        assert_eq!(accumulator.digest.count(), 50_000.0);
+        assert_eq!(accumulator.digest.count(), 50_000);
         accumulator.merge_digests(&[t2]);
-        assert_eq!(accumulator.digest.count(), 100_000.0);
+        assert_eq!(accumulator.digest.count(), 100_000);
     }
 }
diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
index 1da3d7180d84..070ebc46483b 100644
--- a/datafusion/physical-expr-common/src/aggregate/tdigest.rs
+++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
@@ -47,6 +47,17 @@ macro_rules! cast_scalar_f64 {
     };
 }
 
+// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or
+// panic.
+macro_rules! cast_scalar_u64 {
+    ($value:expr ) => {
+        match &$value {
+            ScalarValue::UInt64(Some(v)) => *v,
+            v => panic!("invalid type {:?}", v),
+        }
+    };
+}
+
 /// This trait is implemented for each type a [`TDigest`] can operate on,
 /// allowing it to support both numerical rust types (obtained from
 /// `PrimitiveArray` instances), and [`ScalarValue`] instances.
@@ -142,7 +153,7 @@ pub struct TDigest {
     centroids: Vec<Centroid>,
     max_size: usize,
     sum: f64,
-    count: f64,
+    count: u64,
     max: f64,
     min: f64,
 }
@@ -153,7 +164,7 @@ impl TDigest {
             centroids: Vec::new(),
             max_size,
             sum: 0_f64,
-            count: 0_f64,
+            count: 0,
             max: f64::NAN,
             min: f64::NAN,
         }
@@ -164,14 +175,14 @@ impl TDigest {
             centroids: vec![centroid.clone()],
             max_size,
             sum: centroid.mean * centroid.weight,
-            count: 1_f64,
+            count: 1,
             max: centroid.mean,
             min: centroid.mean,
         }
     }
 
     #[inline]
-    pub fn count(&self) -> f64 {
+    pub fn count(&self) -> u64 {
         self.count
     }
 
@@ -203,7 +214,7 @@ impl Default for TDigest {
             centroids: Vec::new(),
             max_size: 100,
             sum: 0_f64,
-            count: 0_f64,
+            count: 0,
             max: f64::NAN,
             min: f64::NAN,
         }
@@ -211,8 +222,8 @@ impl Default for TDigest {
 }
 
 impl TDigest {
-    fn k_to_q(k: f64, d: f64) -> f64 {
-        let k_div_d = k / d;
+    fn k_to_q(k: u64, d: usize) -> f64 {
+        let k_div_d = k as f64 / d as f64;
         if k_div_d >= 0.5 {
             let base = 1.0 - k_div_d;
             1.0 - 2.0 * base * base
@@ -244,12 +255,12 @@ impl TDigest {
         }
 
         let mut result = TDigest::new(self.max_size());
-        result.count = self.count() + (sorted_values.len() as f64);
+        result.count = self.count() + sorted_values.len() as u64;
 
         let maybe_min = *sorted_values.first().unwrap();
         let maybe_max = *sorted_values.last().unwrap();
 
-        if self.count() > 0.0 {
+        if self.count() > 0 {
             result.min = self.min.min(maybe_min);
             result.max = self.max.max(maybe_max);
         } else {
@@ -259,10 +270,10 @@ impl TDigest {
 
         let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
 
-        let mut k_limit: f64 = 1.0;
+        let mut k_limit: u64 = 1;
         let mut q_limit_times_count =
-            Self::k_to_q(k_limit, self.max_size as f64) * result.count();
-        k_limit += 1.0;
+            Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
+        k_limit += 1;
 
         let mut iter_centroids = self.centroids.iter().peekable();
         let mut iter_sorted_values = sorted_values.iter().peekable();
@@ -309,8 +320,8 @@ impl TDigest {
 
                 compressed.push(curr.clone());
                 q_limit_times_count =
-                    Self::k_to_q(k_limit, self.max_size as f64) * result.count();
-                k_limit += 1.0;
+                    Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
+                k_limit += 1;
                 curr = next;
             }
         }
@@ -381,7 +392,7 @@ impl TDigest {
        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
 
-        let mut count: f64 = 0.0;
+        let mut count = 0;
         let mut min = f64::INFINITY;
         let mut max = f64::NEG_INFINITY;
 
@@ -389,8 +400,8 @@ impl TDigest {
         let mut start: usize = 0;
         for digest in digests.iter() {
             starts.push(start);
-            let curr_count: f64 = digest.count();
-            if curr_count > 0.0 {
+            let curr_count = digest.count();
+            if curr_count > 0 {
                 min = min.min(digest.min);
                 max = max.max(digest.max);
                 count += curr_count;
@@ -424,8 +435,8 @@ impl TDigest {
         let mut result = TDigest::new(max_size);
         let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
 
-        let mut k_limit: f64 = 1.0;
-        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count);
+        let mut k_limit = 1;
+        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
 
         let mut iter_centroids = centroids.iter_mut();
         let mut curr = iter_centroids.next().unwrap();
@@ -444,8 +455,8 @@ impl TDigest {
                 sums_to_merge = 0_f64;
                 weights_to_merge = 0_f64;
                 compressed.push(curr.clone());
-                q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count);
-                k_limit += 1.0;
+                q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
+                k_limit += 1;
                 curr = centroid;
             }
         }
@@ -468,8 +479,7 @@ impl TDigest {
             return 0.0;
         }
 
-        let count_ = self.count;
-        let rank = q * count_;
+        let rank = q * self.count as f64;
 
         let mut pos: usize;
         let mut t;
@@ -479,7 +489,7 @@ impl TDigest {
             }
 
             pos = 0;
-            t = count_;
+            t = self.count as f64;
 
             for (k, centroid) in self.centroids.iter().enumerate().rev() {
                 t -= centroid.weight();
@@ -581,7 +591,7 @@ impl TDigest {
         vec![
             ScalarValue::UInt64(Some(self.max_size as u64)),
             ScalarValue::Float64(Some(self.sum)),
-            ScalarValue::Float64(Some(self.count)),
+            ScalarValue::UInt64(Some(self.count)),
             ScalarValue::Float64(Some(self.max)),
             ScalarValue::Float64(Some(self.min)),
             ScalarValue::List(arr),
@@ -627,7 +637,7 @@ impl TDigest {
         Self {
             max_size,
             sum: cast_scalar_f64!(state[1]),
-            count: cast_scalar_f64!(&state[2]),
+            count: cast_scalar_u64!(&state[2]),
             max,
             min,
             centroids,
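
Reviewer note (not part of the patch): the hunks above change `TDigest::k_to_q` to take the integer counter `k` and the `usize` max digest size directly, casting to f64 only inside the function, and they convert the now-u64 `count` to f64 at each q-limit computation. Below is a minimal standalone sketch of that scale function: the `>= 0.5` branch is copied from the hunk, the `else` branch is assumed unchanged from the existing t-digest implementation, and the small `main` is illustrative only, not DataFusion API.

// Sketch of the refactored scale function; mirrors the patched `TDigest::k_to_q`.
fn k_to_q(k: u64, d: usize) -> f64 {
    let k_div_d = k as f64 / d as f64;
    if k_div_d >= 0.5 {
        let base = 1.0 - k_div_d;
        1.0 - 2.0 * base * base
    } else {
        // Assumed-unchanged else branch of the existing scale function.
        2.0 * k_div_d * k_div_d
    }
}

fn main() {
    // With `count` stored as u64, the conversion happens at the call site,
    // as in the patch: `Self::k_to_q(k_limit, self.max_size) * result.count() as f64`.
    let count: u64 = 100_000;
    let max_size: usize = 100;
    let mut k_limit: u64 = 1;
    let q_limit_times_count = k_to_q(k_limit, max_size) * count as f64;
    k_limit += 1;
    println!("first q-limit: {q_limit_times_count}, next k_limit: {k_limit}");
}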