From 46ab697a607273f4bfb251aa27fe7ffd1f2bc25d Mon Sep 17 00:00:00 2001
From: adolph liu
Date: Wed, 1 Jan 2025 23:38:40 -0600
Subject: [PATCH] Add support for multiple aggregations

---
Review notes (free-form text placed after the three-dash line; not part of
the commit message):
* Re-flowed from a copy that had been collapsed onto a few lines. Generic
  arguments that had been stripped from that copy (`Option<..>`, `Vec<..>`,
  `Result<..>`, `Arc<dyn ..>`, turbofish types) were restored from context,
  and indentation was re-derived assuming rustfmt defaults; confirm both
  against the original commit before applying. The author e-mail was also
  missing from that copy and has not been reconstructed.
* Amended in review: `get_aggregate_result` falls back to `None` instead of
  `unimplemented!()`, and the two new helpers gained doc comments. The
  post-image blob id on the `index` line predates these amendments.
* Possible follow-up: `extract_filter_expr` is re-evaluated for every
  aggregate expression although its arguments do not change inside the loop.

 datafusion-uwheel/src/lib.rs | 233 +++++++++++++++++++++++++++++++++--
 1 file changed, 222 insertions(+), 11 deletions(-)

diff --git a/datafusion-uwheel/src/lib.rs b/datafusion-uwheel/src/lib.rs
index 5d6d95c..54e6c3d 100644
--- a/datafusion-uwheel/src/lib.rs
+++ b/datafusion-uwheel/src/lib.rs
@@ -265,6 +265,16 @@ impl UWheelOptimizer {
         agg.group_expr.is_empty() && agg.aggr_expr.len() == 1
     }
 
+    /// Checks whether the Aggregate has exactly one group_by expression
+    fn single_group_by(agg: &Aggregate) -> bool {
+        agg.group_expr.len() == 1
+    }
+
+    /// Checks whether the Aggregate has no group_expr and more than one aggr_expr
+    fn multiple_aggregates(agg: &Aggregate) -> bool {
+        agg.group_expr.is_empty() && agg.aggr_expr.len() > 1
+    }
+
     // Attemps to rewrite a top-level Projection plan
     fn try_rewrite_projection(
         &self,
@@ -315,7 +325,7 @@ impl UWheelOptimizer {
                 }
             }
 
-            LogicalPlan::Aggregate(agg) => {
+            LogicalPlan::Aggregate(agg) if Self::single_group_by(agg) => {
                 let group_expr = agg.group_expr.first()?;
 
                 // Only continue if the aggregation has a filter
@@ -372,6 +382,57 @@ impl UWheelOptimizer {
                 }
                 None
             }
+
+            LogicalPlan::Aggregate(agg) if Self::multiple_aggregates(agg) => {
+                // Only continue if the aggregation has a filter
+                let LogicalPlan::Filter(filter) = agg.input.as_ref() else {
+                    return None;
+                };
+
+                let agg_exprs = &agg.aggr_expr;
+
+                let mut agg_results = Vec::new();
+
+                for agg_expr in agg_exprs {
+                    match agg_expr {
+                        // Single Aggregate Function (e.g., SUM(col))
+                        Expr::AggregateFunction(agg) if agg.args.len() == 1 => {
+                            if let Expr::Column(col) = &agg.args[0] {
+                                // Fetch temporal filter range and expr key which is used to identify a wheel
+                                let (range, expr_key) = match extract_filter_expr(
+                                    &filter.predicate,
+                                    &self.time_column,
+                                )? {
+                                    (range, Some(expr)) => {
+                                        (range, maybe_replace_table_name(&expr, &self.name))
+                                    }
+                                    (range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()),
+                                };
+
+                                // build the key for the wheel
+                                let wheel_key = format!("{}.{}.{}", self.name, col.name, expr_key);
+
+                                let agg_type = func_def_to_aggregate_type(&agg.func)?;
+
+                                // get aggregation result
+                                let result =
+                                    self.get_aggregate_result(agg_type, &wheel_key, range)?;
+
+                                agg_results.push(result);
+                            } else {
+                                return None;
+                            }
+                        }
+                        _ => {
+                            return None;
+                        }
+                    }
+                }
+
+                let schema = Arc::new(plan.schema().clone().as_arrow().clone());
+
+                uwheel_multiple_aggregations_to_table_scan(agg_results, schema).ok()
+            }
             // Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
             LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan),
             _ => None,
@@ -453,27 +514,42 @@ impl UWheelOptimizer {
         range: WheelRange,
         schema: SchemaRef,
     ) -> Option<LogicalPlan> {
+        let result = self.get_aggregate_result(agg_type, wheel_key, range)?;
+        uwheel_agg_to_table_scan(result, schema).ok()
+    }
+
+    /// Looks up the wheel stored under `wheel_key` for `agg_type` and combines
+    /// it over `range` via `combine_range_and_lower`.
+    ///
+    /// Returns `None` if no wheel is stored under `wheel_key`, if the combine
+    /// yields `None`, or if `agg_type` is not Sum/Avg/Min/Max.
+    ///
+    /// # Panics
+    ///
+    /// Panics if locking the wheel map fails (`lock().unwrap()`).
+    fn get_aggregate_result(
+        &self,
+        agg_type: UWheelAggregate,
+        wheel_key: &str,
+        range: WheelRange,
+    ) -> Option<f64> {
         match agg_type {
             UWheelAggregate::Sum => {
                 let wheel = self.wheels.sum.lock().unwrap().get(wheel_key)?.clone();
-                let result = wheel.combine_range_and_lower(range)?;
-                uwheel_agg_to_table_scan(result, schema).ok()
+                wheel.combine_range_and_lower(range)
             }
             UWheelAggregate::Avg => {
                 let wheel = self.wheels.avg.lock().unwrap().get(wheel_key)?.clone();
-                let result = wheel.combine_range_and_lower(range)?;
-
-                uwheel_agg_to_table_scan(result, schema).ok()
+                wheel.combine_range_and_lower(range)
             }
             UWheelAggregate::Min => {
                 let wheel = self.wheels.min.lock().unwrap().get(wheel_key)?.clone();
-                let result = wheel.combine_range_and_lower(range)?;
-                uwheel_agg_to_table_scan(result, schema).ok()
+                wheel.combine_range_and_lower(range)
             }
             UWheelAggregate::Max => {
                 let wheel = self.wheels.max.lock().unwrap().get(wheel_key)?.clone();
-                let result = wheel.combine_range_and_lower(range)?;
-                uwheel_agg_to_table_scan(result, schema).ok()
+                wheel.combine_range_and_lower(range)
             }
-            _ => unimplemented!(),
+            // Unsupported aggregate type: skip the rewrite instead of panicking
+            _ => None,
         }
@@ -517,6 +593,29 @@ fn uwheel_group_by_to_table_scan(
     mem_table_as_table_scan(mem_table, df_schema)
 }
 
+/// Builds a one-row in-memory table holding one `Float64` column per entry of
+/// `agg_results` (in order) and hands it to `mem_table_as_table_scan`.
+///
+/// `schema` describes the output columns; any error from building the record
+/// batch, the `DFSchema` or the `MemTable` is propagated to the caller.
+fn uwheel_multiple_aggregations_to_table_scan(
+    agg_results: Vec<f64>,
+    schema: SchemaRef,
+) -> Result<LogicalPlan> {
+    let mut columns = Vec::new();
+
+    for result in agg_results {
+        let data = Float64Array::from(vec![result]);
+        columns.push(Arc::new(data) as Arc<dyn Array>);
+    }
+
+    let record_batch = RecordBatch::try_new(schema.clone(), columns)?;
+
+    let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
+    let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
+    mem_table_as_table_scan(mem_table, df_schema)
+}
+
 // helper for possibly removing the table name from the expression key
 fn maybe_replace_table_name(expr: &Expr, table_name: &str) -> String {
     let expr_str = expr.to_string();
@@ -1575,4 +1674,116 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn multiple_aggregation_rewrite() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Avg,
+            ))
+            .await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Sum,
+            ))
+            .await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    Vec::<Expr>::new(),
+                    vec![avg(col("agg_col")), sum(col("agg_col"))],
+                )?
+                .project(vec![avg(col("agg_col")), sum(col("agg_col"))])?
+                .build()?;
+
+        // Assert that the original plan is a Projection
+        assert!(matches!(plan, LogicalPlan::Projection(_)));
+
+        let rewritten = optimizer.try_rewrite(&plan).unwrap();
+        // assert it was rewritten to a TableScan
+        assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn multiple_aggregation_exec() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Avg,
+            ))
+            .await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Sum,
+            ))
+            .await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    Vec::<Expr>::new(),
+                    vec![avg(col("agg_col")), sum(col("agg_col"))],
+                )?
+                .project(vec![avg(col("agg_col")), sum(col("agg_col"))])?
+                .build()?;
+
+        let ctx = SessionContext::new();
+        ctx.register_table("test", optimizer.provider().clone())?;
+
+        // Set UWheelOptimizer as optimizer rule
+        let session_state = SessionStateBuilder::new()
+            .with_optimizer_rules(vec![optimizer.clone()])
+            .build();
+        let uwheel_ctx = SessionContext::new_with_state(session_state);
+
+        // Run the query through the ctx that has our OptimizerRule
+        let df = uwheel_ctx.execute_logical_plan(plan).await?;
+        let results = df.collect().await?;
+
+        assert_eq!(results.len(), 1);
+
+        assert_eq!(
+            results[0]
+                .column(0)
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .unwrap()
+                .value(0),
+            5.5
+        );
+
+        assert_eq!(
+            results[0]
+                .column(1)
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .unwrap()
+                .value(0),
+            55.0
+        );
+
+        Ok(())
+    }
 }