diff --git a/datafusion-uwheel/src/lib.rs b/datafusion-uwheel/src/lib.rs index 282b315..5d6d95c 100644 --- a/datafusion-uwheel/src/lib.rs +++ b/datafusion-uwheel/src/lib.rs @@ -47,7 +47,7 @@ use uwheel::{ sum::{F64SumAggregator, U32SumAggregator}, }, wheels::read::ReaderWheel, - Aggregator, Conf, Entry, HawConf, RwWheel, WheelRange, + Aggregator, Conf, Duration, Entry, HawConf, RwWheel, WheelRange, }; /// Custom aggregator implementations that are used by this crate. @@ -314,6 +314,64 @@ impl UWheelOptimizer { _ => None, } } + + LogicalPlan::Aggregate(agg) => { + let group_expr = agg.group_expr.first()?; + + // Only continue if the aggregation has a filter + let LogicalPlan::Filter(filter) = agg.input.as_ref() else { + return None; + }; + + let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?; + + match group_expr { + Expr::ScalarFunction(func) if func.name() == "date_trunc" => { + let interval = func.args.first()?; + if let Expr::Literal(ScalarValue::Utf8(duration)) = interval { + match duration.as_ref()?.as_str() { + "second" => { + unimplemented!("date_trunc('second') group by is not supported") + } + "minute" => { + unimplemented!("date_trunc('minute') group by is not supported") + } + "hour" => { + unimplemented!("date_trunc('hour') group by is not supported") + } + "day" => { + let res = self + .wheels + .count + .group_by(wheel_range, Duration::DAY) + .unwrap_or_default() + .iter() + .map(|(k, v)| ((*k * 1_000) as i64, *v as i64)) // transform milliseconds to microseconds by multiplying by 1_000 + .collect(); + + let schema = Arc::new(plan.schema().clone().as_arrow().clone()); + + return uwheel_group_by_to_table_scan(res, schema).ok(); + } + "week" => { + unimplemented!("date_trunc('week') group by is not supported") + } + "month" => { + unimplemented!("date_trunc('month') group by is not supported") + } + "year" => { + unimplemented!("date_trunc('year') group by is not supported") + } + _ => {} + } + } + } + _ => { + unimplemented!("We only support scalar function date_trunc for group by expression now") + } + } + None + } // Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan), _ => None, @@ -440,6 +498,25 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result, + schema: SchemaRef, +) -> Result { + let group_by = + TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::>()); + + let agg = Int64Array::from(result.iter().map(|(_, v)| *v).collect::>()); + + let record_batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(group_by), Arc::new(agg)])?; + + let df_schema = Arc::new(DFSchema::try_from(schema.clone())?); + let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?; + mem_table_as_table_scan(mem_table, df_schema) +} + // helper for possibly removing the table name from the expression key fn maybe_replace_table_name(expr: &Expr, table_name: &str) -> String { let expr_str = expr.to_string(); @@ -568,7 +645,7 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { func, args, .. - } if func.name() == "COUNT" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) + } if (func.name() == "COUNT" || func.name() == "count") && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } // Helper methods to build the UWheelOptimizer @@ -942,9 +1019,11 @@ mod tests { use chrono::TimeZone; use datafusion::arrow::datatypes::{Field, Schema, TimeUnit}; use datafusion::execution::SessionStateBuilder; + use datafusion::functions_aggregate::count::count; use datafusion::functions_aggregate::expr_fn::avg; use datafusion::functions_aggregate::min_max::{max, min}; - use datafusion::logical_expr::test::function_stub::{count, sum}; + use datafusion::functions_aggregate::sum::sum; + use datafusion::prelude::date_trunc; use super::*; use builder::Builder; @@ -1405,4 +1484,95 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn group_by_count_aggregation_rewrite() -> Result<()> { + let optimizer = test_optimizer().await?; + + let temporal_filter = col("timestamp") + .gt_eq(lit("2024-05-10T00:00:00Z")) + .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z"))); + + let plan = + LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)? + .filter(temporal_filter)? + .aggregate( + vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp) + vec![count(wildcard())], + )? + .project(vec![ + date_trunc(lit("day"), col("timestamp")), + count(wildcard()), + ])? + .build()?; + + // Assert that the original plan is a Projection + assert!(matches!(plan, LogicalPlan::Projection(_))); + + let rewritten = optimizer.try_rewrite(&plan).unwrap(); + // assert it was rewritten to a TableScan + assert!(matches!(rewritten, LogicalPlan::TableScan(_))); + + Ok(()) + } + + #[tokio::test] + async fn group_by_count_aggregation_exec() -> Result<()> { + let optimizer = test_optimizer().await?; + + let temporal_filter = col("timestamp") + .gt_eq(lit("2024-05-10T00:00:00Z")) + .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z"))); + + let plan = + LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)? + .filter(temporal_filter)? + .aggregate( + vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp) + vec![count(wildcard())], + )? + .project(vec![ + date_trunc(lit("day"), col("timestamp")), + count(wildcard()), + ])? + .build()?; + + let ctx = SessionContext::new(); + ctx.register_table("test", optimizer.provider().clone())?; + + // Set UWheelOptimizer as optimizer rule + let session_state = SessionStateBuilder::new() + .with_optimizer_rules(vec![optimizer.clone()]) + .build(); + let uwheel_ctx = SessionContext::new_with_state(session_state); + + // Run the query through the ctx that has our OptimizerRule + let df = uwheel_ctx.execute_logical_plan(plan).await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); + + assert_eq!( + results[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + / 1000, + 1_715_299_200_000 + ); + + assert_eq!( + results[0] + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 10 + ); + + Ok(()) + } }