From e802b9b7357a0a03540bc070c9a58b97372f6785 Mon Sep 17 00:00:00 2001 From: adolph liu Date: Sat, 11 Jan 2025 15:30:39 -0600 Subject: [PATCH] Support group by & multiple aggregations --- datafusion-uwheel/src/lib.rs | 340 ++++++++++++++++++++++++++++++----- 1 file changed, 297 insertions(+), 43 deletions(-) diff --git a/datafusion-uwheel/src/lib.rs b/datafusion-uwheel/src/lib.rs index b97f0ad..ac824a7 100644 --- a/datafusion-uwheel/src/lib.rs +++ b/datafusion-uwheel/src/lib.rs @@ -337,53 +337,165 @@ impl UWheelOptimizer { return None; }; - let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?; + let (wheel_range, expr_key) = + match extract_filter_expr(&filter.predicate, &self.time_column)? { + (range, Some(expr)) => (range, maybe_replace_table_name(&expr, &self.name)), + (range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()), + }; match group_expr { Expr::ScalarFunction(func) if func.name() == "date_trunc" => { let interval = func.args.first()?; if let Expr::Literal(ScalarValue::Utf8(duration)) = interval { - match duration.as_ref()?.as_str() { - "second" => { - unimplemented!("date_trunc('second') group by is not supported") - } - "minute" => { - unimplemented!("date_trunc('minute') group by is not supported") - } - "hour" => { - unimplemented!("date_trunc('hour') group by is not supported") - } - "day" => { - let res = self - .wheels - .count - .group_by(wheel_range, Duration::DAY) - .unwrap_or_default() - .iter() - .map(|(k, v)| ((*k * 1_000) as i64, *v as i64)) // transform milliseconds to microseconds by multiplying by 1_000 - .collect(); - - let schema = Arc::new(plan.schema().clone().as_arrow().clone()); - - return uwheel_group_by_to_table_scan(res, schema).ok(); - } - "week" => { - unimplemented!("date_trunc('week') group by is not supported") - } - "month" => { - unimplemented!("date_trunc('month') group by is not supported") - } - "year" => { - unimplemented!("date_trunc('year') group by is not supported") + let group_by_interval = match duration.as_ref()?.as_str() { + "second" => Duration::SECOND, + "minute" => Duration::MINUTE, + "hour" => Duration::HOUR, + "day" => Duration::DAY, + "week" => Duration::WEEK, + _ => return None, + }; + + let mut group_agg_result = Vec::new(); + let mut group_col = None; + + let mut count_idx = None; + + for (idx, agg) in agg.aggr_expr.iter().enumerate() { + let (agg_type, col) = match agg { + // COUNT(*) + Expr::AggregateFunction(agg) + if is_count_star_aggregate(agg) => + { + (UWheelAggregate::Count, None) + } + + // COUNT(*) + Expr::Alias(alias) if alias.name == COUNT_STAR_ALIAS => { + (UWheelAggregate::Count, None) + } + + Expr::AggregateFunction(agg) => { + if agg.args.len() > 1 { + return None; + } + let col = match &agg.args[0] { + Expr::Column(col) => col, + _ => return None, + }; + (func_def_to_aggregate_type(&agg.func)?, Some(col)) + } + + _ => return None, + }; + + let res = match agg_type { + UWheelAggregate::Count => { + count_idx = Some(idx); + self.wheels + .count + .group_by(wheel_range, group_by_interval) + .unwrap_or_default() + .iter() + .map(|(k, v)| (*k, *v as f64)) + .collect() + } + + UWheelAggregate::Avg => { + let wheel_key = format!( + "{}.{}.{}", + self.name, + col.unwrap().name, + expr_key + ); + self.wheels + .avg + .lock() + .unwrap() + .get(&wheel_key)? + .group_by(wheel_range, group_by_interval) + .unwrap_or_default() + } + + UWheelAggregate::Min => { + let wheel_key = format!( + "{}.{}.{}", + self.name, + col.unwrap().name, + expr_key + ); + self.wheels + .min + .lock() + .unwrap() + .get(&wheel_key)? + .group_by(wheel_range, group_by_interval) + .unwrap_or_default() + } + + UWheelAggregate::Max => { + let wheel_key = format!( + "{}.{}.{}", + self.name, + col.unwrap().name, + expr_key + ); + self.wheels + .max + .lock() + .unwrap() + .get(&wheel_key)? + .group_by(wheel_range, group_by_interval) + .unwrap_or_default() + } + + UWheelAggregate::Sum => { + let wheel_key = format!( + "{}.{}.{}", + self.name, + col.unwrap().name, + expr_key + ); + self.wheels + .sum + .lock() + .unwrap() + .get(&wheel_key)? + .group_by(wheel_range, group_by_interval) + .unwrap_or_default() + } + + _ => return None, + }; + + if group_col.is_none() { + group_col = Some( + res.iter() + .map(|(k, _)| (*k * 1_000) as i64) + .collect::>(), + ); } - _ => {} + + group_agg_result + .push(res.iter().map(|(_, v)| *v).collect::>()); } + + group_col.as_ref()?; + + let schema = Arc::new(plan.schema().clone().as_arrow().clone()); + + return uwheel_group_by_to_table_scan( + group_col.unwrap(), + group_agg_result, + count_idx, + schema, + ) + .ok(); } } - _ => { - unimplemented!("We only support scalar function date_trunc for group by expression now") - } + _ => return None, } + None } @@ -622,16 +734,26 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result, + group_col: Vec, + agg_result: Vec>, + count_idx: Option, schema: SchemaRef, ) -> Result { - let group_by = - TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::>()); + let group_by = TimestampMicrosecondArray::from(group_col); - let agg = Int64Array::from(result.iter().map(|(_, v)| *v).collect::>()); + let mut columns = vec![Arc::new(group_by) as Arc]; + + for (idx, result) in agg_result.into_iter().enumerate() { + if count_idx.is_some() && idx == count_idx.unwrap() { + let data = Int64Array::from(result.iter().map(|v| *v as i64).collect::>()); + columns.push(Arc::new(data) as Arc); + } else { + let data = Float64Array::from(result); + columns.push(Arc::new(data) as Arc); + } + } - let record_batch = - RecordBatch::try_new(schema.clone(), vec![Arc::new(group_by), Arc::new(agg)])?; + let record_batch = RecordBatch::try_new(schema.clone(), columns)?; let df_schema = Arc::new(DFSchema::try_from(schema.clone())?); let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?; @@ -1898,4 +2020,136 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn group_by_multiple_aggregation_rewrite() -> Result<()> { + let optimizer = test_optimizer().await?; + + optimizer + .build_index(IndexBuilder::with_col_and_aggregate( + "agg_col", + UWheelAggregate::Avg, + )) + .await?; + + optimizer + .build_index(IndexBuilder::with_col_and_aggregate( + "agg_col", + UWheelAggregate::Sum, + )) + .await?; + + let temporal_filter = col("timestamp") + .gt_eq(lit("2024-05-10T00:00:00Z")) + .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z"))); + + let plan = + LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)? + .filter(temporal_filter)? + .aggregate( + vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp) + vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())], + )? + .project(vec![ + date_trunc(lit("day"), col("timestamp")), + sum(col("agg_col")), + avg(col("agg_col")), + count(wildcard()), + ])? + .build()?; + + // Assert that the original plan is a Projection + assert!(matches!(plan, LogicalPlan::Projection(_))); + + let rewritten = optimizer.try_rewrite(&plan).unwrap(); + // assert it was rewritten to a TableScan + assert!(matches!(rewritten, LogicalPlan::TableScan(_))); + + Ok(()) + } + + #[tokio::test] + async fn group_by_multiple_aggregation_exec() -> Result<()> { + let optimizer = test_optimizer().await?; + + optimizer + .build_index(IndexBuilder::with_col_and_aggregate( + "agg_col", + UWheelAggregate::Avg, + )) + .await?; + + optimizer + .build_index(IndexBuilder::with_col_and_aggregate( + "agg_col", + UWheelAggregate::Sum, + )) + .await?; + + let temporal_filter = col("timestamp") + .gt_eq(lit("2024-05-10T00:00:00Z")) + .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z"))); + + let plan = + LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)? + .filter(temporal_filter)? + .aggregate( + vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp) + vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())], + )? + .project(vec![ + date_trunc(lit("day"), col("timestamp")), + sum(col("agg_col")), + avg(col("agg_col")), + count(wildcard()), + ])? + .build()?; + + let ctx = SessionContext::new(); + ctx.register_table("test", optimizer.provider().clone())?; + + // Set UWheelOptimizer as optimizer rule + let session_state = SessionStateBuilder::new() + .with_optimizer_rules(vec![optimizer.clone()]) + .build(); + let uwheel_ctx = SessionContext::new_with_state(session_state); + + // Run the query through the ctx that has our OptimizerRule + let df = uwheel_ctx.execute_logical_plan(plan).await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); + + assert_eq!( + results[0] + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 55.0 + ); + + assert_eq!( + results[0] + .column(2) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 5.5 + ); + + assert_eq!( + results[0] + .column(3) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 10 + ); + + Ok(()) + } }