Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 173 additions & 3 deletions datafusion-uwheel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use uwheel::{
sum::{F64SumAggregator, U32SumAggregator},
},
wheels::read::ReaderWheel,
Aggregator, Conf, Entry, HawConf, RwWheel, WheelRange,
Aggregator, Conf, Duration, Entry, HawConf, RwWheel, WheelRange,
};

/// Custom aggregator implementations that are used by this crate.
Expand Down Expand Up @@ -314,6 +314,64 @@ impl UWheelOptimizer {
_ => None,
}
}

LogicalPlan::Aggregate(agg) => {
    // Attempts to answer `GROUP BY date_trunc('<unit>', ..)` over a time-filtered
    // input from the count wheel. Every shape that cannot be served returns
    // `None`, like the other unsupported cases in this match, instead of
    // panicking: a valid but unsupported query must not crash the optimizer.
    //
    // NOTE(review): only the first group-by expression is inspected and the
    // aggregate expressions are not checked here, yet the result always comes
    // from the count wheel -- confirm that callers only reach this arm for a
    // single COUNT(*) grouped by the time column.
    let group_expr = agg.group_expr.first()?;

    // Only continue if the aggregation has a filter
    let LogicalPlan::Filter(filter) = agg.input.as_ref() else {
        return None;
    };

    let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?;

    // Only the scalar function `date_trunc` is supported as group-by expression.
    let Expr::ScalarFunction(func) = group_expr else {
        return None;
    };
    if func.name() != "date_trunc" {
        return None;
    }

    // The first argument of `date_trunc` must be a non-null string literal
    // naming the granularity.
    let Expr::Literal(ScalarValue::Utf8(Some(granularity))) = func.args.first()? else {
        return None;
    };

    match granularity.as_str() {
        "day" => {
            let res: Vec<(i64, i64)> = self
                .wheels
                .count
                .group_by(wheel_range, Duration::DAY)
                .unwrap_or_default()
                .iter()
                // transform milliseconds to microseconds by multiplying by 1_000
                .map(|(k, v)| ((*k * 1_000) as i64, *v as i64))
                .collect();

            let schema = Arc::new(plan.schema().as_arrow().clone());

            uwheel_group_by_to_table_scan(res, schema).ok()
        }
        // "second", "minute", "hour", "week", "month" and "year" are not
        // supported yet; leave the plan untouched for those and for any
        // unrecognised granularity.
        _ => None,
    }
}
// Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan),
_ => None,
Expand Down Expand Up @@ -440,6 +498,25 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPla
mem_table_as_table_scan(mem_table, df_schema)
}

/// Converts a uwheel group by result to a `TableScan` with a `MemTable` as source.
///
/// Each `(key, value)` pair in `result` becomes one row: the keys fill a
/// `TimestampMicrosecondArray` (first column) and the values an `Int64Array`
/// (second column). Currently only timestamp group by is supported.
///
/// Returns an error if the record batch, the `DFSchema` or the `MemTable`
/// cannot be built from `schema`.
fn uwheel_group_by_to_table_scan(
    result: Vec<(i64, i64)>,
    schema: SchemaRef,
) -> Result<LogicalPlan> {
    // Split the pairs into one vector per output column.
    let (keys, values): (Vec<i64>, Vec<i64>) = result.into_iter().unzip();

    let timestamp_column = TimestampMicrosecondArray::from(keys);
    let aggregate_column = Int64Array::from(values);

    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(timestamp_column), Arc::new(aggregate_column)],
    )?;

    let df_schema = Arc::new(DFSchema::try_from(Arc::clone(&schema))?);
    let table = MemTable::try_new(schema, vec![vec![batch]])?;
    mem_table_as_table_scan(table, df_schema)
}

// helper for possibly removing the table name from the expression key
fn maybe_replace_table_name(expr: &Expr, table_name: &str) -> String {
let expr_str = expr.to_string();
Expand Down Expand Up @@ -568,7 +645,7 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
func,
args,
..
} if func.name() == "COUNT" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
} if (func.name() == "COUNT" || func.name() == "count") && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}

// Helper methods to build the UWheelOptimizer
Expand Down Expand Up @@ -942,9 +1019,11 @@ mod tests {
use chrono::TimeZone;
use datafusion::arrow::datatypes::{Field, Schema, TimeUnit};
use datafusion::execution::SessionStateBuilder;
use datafusion::functions_aggregate::count::count;
use datafusion::functions_aggregate::expr_fn::avg;
use datafusion::functions_aggregate::min_max::{max, min};
use datafusion::logical_expr::test::function_stub::{count, sum};
use datafusion::functions_aggregate::sum::sum;
use datafusion::prelude::date_trunc;

use super::*;
use builder::Builder;
Expand Down Expand Up @@ -1405,4 +1484,95 @@ mod tests {

Ok(())
}

/// Checks that a `COUNT(*) ... GROUP BY date_trunc('day', timestamp)` query with a
/// temporal filter is rewritten by the optimizer into a `TableScan`.
#[tokio::test]
async fn group_by_count_aggregation_rewrite() -> Result<()> {
    let optimizer = test_optimizer().await?;

    // WHERE timestamp >= '2024-05-10T00:00:00Z' AND timestamp < '2024-05-10T00:00:10Z'
    let lower_bound = col("timestamp").gt_eq(lit("2024-05-10T00:00:00Z"));
    let upper_bound = col("timestamp").lt(lit("2024-05-10T00:00:10Z"));
    let time_predicate = lower_bound.and(upper_bound);

    // GROUP BY date_trunc('day', timestamp)
    let day_bucket = date_trunc(lit("day"), col("timestamp"));

    let source = provider_as_source(optimizer.provider());
    let plan = LogicalPlanBuilder::scan("test", source, None)?
        .filter(time_predicate)?
        .aggregate(vec![day_bucket.clone()], vec![count(wildcard())])?
        .project(vec![day_bucket, count(wildcard())])?
        .build()?;

    // The plan handed to the optimizer has a Projection at its root
    assert!(matches!(plan, LogicalPlan::Projection(_)));

    // After the rewrite the whole plan is replaced by a TableScan
    let rewritten = optimizer.try_rewrite(&plan).unwrap();
    assert!(matches!(rewritten, LogicalPlan::TableScan(_)));

    Ok(())
}

/// Executes a `COUNT(*) ... GROUP BY date_trunc('day', timestamp)` query through a
/// session that has `UWheelOptimizer` installed as optimizer rule and checks the
/// single resulting row: the day bucket and the count.
#[tokio::test]
async fn group_by_count_aggregation_exec() -> Result<()> {
    let optimizer = test_optimizer().await?;

    let temporal_filter = col("timestamp")
        .gt_eq(lit("2024-05-10T00:00:00Z"))
        .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));

    // The plan scans the optimizer's provider directly, so no table has to be
    // registered on a session context for it to execute.
    let plan =
        LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
            .filter(temporal_filter)?
            .aggregate(
                vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
                vec![count(wildcard())],
            )?
            .project(vec![
                date_trunc(lit("day"), col("timestamp")),
                count(wildcard()),
            ])?
            .build()?;

    // Set UWheelOptimizer as optimizer rule
    let session_state = SessionStateBuilder::new()
        .with_optimizer_rules(vec![optimizer.clone()])
        .build();
    let uwheel_ctx = SessionContext::new_with_state(session_state);

    // Run the query through the ctx that has our OptimizerRule
    let df = uwheel_ctx.execute_logical_plan(plan).await?;
    let results = df.collect().await?;

    assert_eq!(results.len(), 1);
    let batch = &results[0];

    // Column 0 holds the day bucket in microseconds; dividing by 1000 gives
    // milliseconds, compared against 2024-05-10T00:00:00Z.
    assert_eq!(
        batch
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMicrosecondArray>()
            .unwrap()
            .value(0)
            / 1000,
        1_715_299_200_000
    );

    // Column 1 holds the count for that bucket.
    assert_eq!(
        batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
            .value(0),
        10
    );

    Ok(())
}
}