Merged
datafusion-uwheel/src/lib.rs (340 changes: 297 additions & 43 deletions)

@@ -337,53 +337,165 @@ impl UWheelOptimizer {
             return None;
         };

-        let (wheel_range, _) = extract_filter_expr(&filter.predicate, &self.time_column)?;
+        let (wheel_range, expr_key) =
+            match extract_filter_expr(&filter.predicate, &self.time_column)? {
+                (range, Some(expr)) => (range, maybe_replace_table_name(&expr, &self.name)),
+                (range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()),
+            };

         match group_expr {
             Expr::ScalarFunction(func) if func.name() == "date_trunc" => {
                 let interval = func.args.first()?;
                 if let Expr::Literal(ScalarValue::Utf8(duration)) = interval {
-                    match duration.as_ref()?.as_str() {
-                        "second" => {
-                            unimplemented!("date_trunc('second') group by is not supported")
-                        }
-                        "minute" => {
-                            unimplemented!("date_trunc('minute') group by is not supported")
-                        }
-                        "hour" => {
-                            unimplemented!("date_trunc('hour') group by is not supported")
-                        }
-                        "day" => {
-                            let res = self
-                                .wheels
-                                .count
-                                .group_by(wheel_range, Duration::DAY)
-                                .unwrap_or_default()
-                                .iter()
-                                .map(|(k, v)| ((*k * 1_000) as i64, *v as i64)) // transform milliseconds to microseconds by multiplying by 1_000
-                                .collect();
-
-                            let schema = Arc::new(plan.schema().clone().as_arrow().clone());
-
-                            return uwheel_group_by_to_table_scan(res, schema).ok();
-                        }
-                        "week" => {
-                            unimplemented!("date_trunc('week') group by is not supported")
-                        }
-                        "month" => {
-                            unimplemented!("date_trunc('month') group by is not supported")
-                        }
-                        "year" => {
-                            unimplemented!("date_trunc('year') group by is not supported")
+                    let group_by_interval = match duration.as_ref()?.as_str() {
+                        "second" => Duration::SECOND,
+                        "minute" => Duration::MINUTE,
+                        "hour" => Duration::HOUR,
+                        "day" => Duration::DAY,
+                        "week" => Duration::WEEK,
+                        _ => return None,
+                    };
+
+                    let mut group_agg_result = Vec::new();
+                    let mut group_col = None;
+
+                    let mut count_idx = None;
+
+                    for (idx, agg) in agg.aggr_expr.iter().enumerate() {
+                        let (agg_type, col) = match agg {
+                            // COUNT(*)
+                            Expr::AggregateFunction(agg)
+                                if is_count_star_aggregate(agg) =>
+                            {
+                                (UWheelAggregate::Count, None)
+                            }

+                            // COUNT(*)
+                            Expr::Alias(alias) if alias.name == COUNT_STAR_ALIAS => {
+                                (UWheelAggregate::Count, None)
+                            }

+                            Expr::AggregateFunction(agg) => {
+                                if agg.args.len() > 1 {
+                                    return None;
+                                }
+                                let col = match &agg.args[0] {
+                                    Expr::Column(col) => col,
+                                    _ => return None,
+                                };
+                                (func_def_to_aggregate_type(&agg.func)?, Some(col))
+                            }

+                            _ => return None,
+                        };

+                        let res = match agg_type {
+                            UWheelAggregate::Count => {
+                                count_idx = Some(idx);
+                                self.wheels
+                                    .count
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                                    .iter()
+                                    .map(|(k, v)| (*k, *v as f64))
+                                    .collect()
+                            }

+                            UWheelAggregate::Avg => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .avg
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }

+                            UWheelAggregate::Min => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .min
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }

+                            UWheelAggregate::Max => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .max
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }

+                            UWheelAggregate::Sum => {
+                                let wheel_key = format!(
+                                    "{}.{}.{}",
+                                    self.name,
+                                    col.unwrap().name,
+                                    expr_key
+                                );
+                                self.wheels
+                                    .sum
+                                    .lock()
+                                    .unwrap()
+                                    .get(&wheel_key)?
+                                    .group_by(wheel_range, group_by_interval)
+                                    .unwrap_or_default()
+                            }

+                            _ => return None,
+                        };

+                        if group_col.is_none() {
+                            group_col = Some(
+                                res.iter()
+                                    .map(|(k, _)| (*k * 1_000) as i64)
+                                    .collect::<Vec<_>>(),
+                            );
                         }
-                        _ => {}
+
+                        group_agg_result
+                            .push(res.iter().map(|(_, v)| *v).collect::<Vec<_>>());
+                    }
+
+                    group_col.as_ref()?;
+
+                    let schema = Arc::new(plan.schema().clone().as_arrow().clone());
+
+                    return uwheel_group_by_to_table_scan(
+                        group_col.unwrap(),
+                        group_agg_result,
+                        count_idx,
+                        schema,
+                    )
+                    .ok();
                 }
             }
-            _ => {
-                unimplemented!("We only support scalar function date_trunc for group by expression now")
-            }
+            _ => return None,
         }

         None
     }
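A detail that is easy to miss in this hunk is the unit handling: uwheel's `group_by` returns buckets keyed by millisecond timestamps (hence the `* 1_000` when building the group column), while the rewritten scan exposes microsecond timestamps, and every aggregate value is carried as `f64` until the COUNT column is cast back later. A minimal, standalone sketch of that key conversion; the helper name is illustrative, not the crate's:

```rust
// Hypothetical helper mirroring the `(*k * 1_000) as i64` conversion above:
// uwheel bucket keys are milliseconds, Arrow's TimestampMicrosecondArray
// expects microseconds.
fn bucket_keys_to_micros(buckets: &[(u64, f64)]) -> (Vec<i64>, Vec<f64>) {
    let keys = buckets.iter().map(|(k, _)| (*k as i64) * 1_000).collect();
    let vals = buckets.iter().map(|(_, v)| *v).collect();
    (keys, vals)
}

fn main() {
    // 2024-05-10T00:00:00Z expressed in milliseconds since the Unix epoch.
    let buckets = vec![(1_715_299_200_000_u64, 55.0_f64)];
    let (keys, vals) = bucket_keys_to_micros(&buckets);
    assert_eq!(keys, vec![1_715_299_200_000_000_i64]);
    assert_eq!(vals, vec![55.0]);
}
```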

@@ -622,16 +734,26 @@ fn uwheel_agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPlan>
 // Converts a uwheel group by result to a TableScan with a MemTable as source
 // currently only supports timestamp group by
 fn uwheel_group_by_to_table_scan(
-    result: Vec<(i64, i64)>,
+    group_col: Vec<i64>,
+    agg_result: Vec<Vec<f64>>,
+    count_idx: Option<usize>,
     schema: SchemaRef,
 ) -> Result<LogicalPlan> {
-    let group_by =
-        TimestampMicrosecondArray::from(result.iter().map(|(k, _)| *k).collect::<Vec<_>>());
+    let group_by = TimestampMicrosecondArray::from(group_col);

-    let agg = Int64Array::from(result.iter().map(|(_, v)| *v).collect::<Vec<_>>());
+    let mut columns = vec![Arc::new(group_by) as Arc<dyn Array>];
+
+    for (idx, result) in agg_result.into_iter().enumerate() {
+        if count_idx.is_some() && idx == count_idx.unwrap() {
+            let data = Int64Array::from(result.iter().map(|v| *v as i64).collect::<Vec<_>>());
+            columns.push(Arc::new(data) as Arc<dyn Array>);
+        } else {
+            let data = Float64Array::from(result);
+            columns.push(Arc::new(data) as Arc<dyn Array>);
+        }
+    }

-    let record_batch =
-        RecordBatch::try_new(schema.clone(), vec![Arc::new(group_by), Arc::new(agg)])?;
+    let record_batch = RecordBatch::try_new(schema.clone(), columns)?;

     let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
     let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
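The reworked `uwheel_group_by_to_table_scan` now assembles one timestamp column plus one column per aggregate, casting only the COUNT column back to `Int64`. Below is a self-contained sketch of that mixed-type `RecordBatch` assembly using the arrow crate directly; the schema and field names are made up for illustration and are not the crate's:

```rust
use std::sync::Arc;

use arrow::array::{Array, Float64Array, Int64Array, TimestampMicrosecondArray};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), arrow::error::ArrowError> {
    // Illustrative schema: group key, a COUNT column, and an AVG column.
    let schema = Arc::new(Schema::new(vec![
        Field::new(
            "date_trunc",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            false,
        ),
        Field::new("count(*)", DataType::Int64, false),
        Field::new("avg(agg_col)", DataType::Float64, false),
    ]));

    // One group: 2024-05-10T00:00:00Z in microseconds since the Unix epoch.
    let group_by = TimestampMicrosecondArray::from(vec![1_715_299_200_000_000_i64]);
    // The count column is cast back to Int64; other aggregates stay Float64.
    let count = Int64Array::from(vec![10_i64]);
    let avg = Float64Array::from(vec![5.5_f64]);

    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(group_by) as Arc<dyn Array>,
            Arc::new(count),
            Arc::new(avg),
        ],
    )?;
    assert_eq!(batch.num_rows(), 1);
    assert_eq!(batch.num_columns(), 3);
    Ok(())
}
```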
@@ -1898,4 +2020,136 @@ mod tests {

         Ok(())
     }
+
+    #[tokio::test]
+    async fn group_by_multiple_aggregation_rewrite() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Avg,
+            ))
+            .await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Sum,
+            ))
+            .await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
+                    vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())],
+                )?
+                .project(vec![
+                    date_trunc(lit("day"), col("timestamp")),
+                    sum(col("agg_col")),
+                    avg(col("agg_col")),
+                    count(wildcard()),
+                ])?
+                .build()?;
+
+        // Assert that the original plan is a Projection
+        assert!(matches!(plan, LogicalPlan::Projection(_)));
+
+        let rewritten = optimizer.try_rewrite(&plan).unwrap();
+        // assert it was rewritten to a TableScan
+        assert!(matches!(rewritten, LogicalPlan::TableScan(_)));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn group_by_multiple_aggregation_exec() -> Result<()> {
+        let optimizer = test_optimizer().await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Avg,
+            ))
+            .await?;
+
+        optimizer
+            .build_index(IndexBuilder::with_col_and_aggregate(
+                "agg_col",
+                UWheelAggregate::Sum,
+            ))
+            .await?;
+
+        let temporal_filter = col("timestamp")
+            .gt_eq(lit("2024-05-10T00:00:00Z"))
+            .and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));
+
+        let plan =
+            LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
+                .filter(temporal_filter)?
+                .aggregate(
+                    vec![date_trunc(lit("day"), col("timestamp"))], // GROUP BY date_trunc('day', timestamp)
+                    vec![sum(col("agg_col")), avg(col("agg_col")), count(wildcard())],
+                )?
+                .project(vec![
+                    date_trunc(lit("day"), col("timestamp")),
+                    sum(col("agg_col")),
+                    avg(col("agg_col")),
+                    count(wildcard()),
+                ])?
+                .build()?;
+
+        let ctx = SessionContext::new();
+        ctx.register_table("test", optimizer.provider().clone())?;
+
+        // Set UWheelOptimizer as optimizer rule
+        let session_state = SessionStateBuilder::new()
+            .with_optimizer_rules(vec![optimizer.clone()])
+            .build();
+        let uwheel_ctx = SessionContext::new_with_state(session_state);
+
+        // Run the query through the ctx that has our OptimizerRule
+        let df = uwheel_ctx.execute_logical_plan(plan).await?;
+        let results = df.collect().await?;
+
+        assert_eq!(results.len(), 1);
+
+        assert_eq!(
+            results[0]
+                .column(1)
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .unwrap()
+                .value(0),
+            55.0
+        );
+
+        assert_eq!(
+            results[0]
+                .column(2)
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .unwrap()
+                .value(0),
+            5.5
+        );
+
+        assert_eq!(
+            results[0]
+                .column(3)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .value(0),
+            10
+        );
+
+        Ok(())
+    }
 }
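For readers following the lookups in the first hunk: non-count aggregates are fetched from per-aggregate wheel maps keyed by a `{table}.{column}.{expr_key}` string, where `expr_key` falls back to a star alias when the filter predicate contributes no keyed expression. A toy sketch of that convention follows; the `"*"` value and the map layout are assumptions for illustration, and the real constant is `STAR_AGGREGATION_ALIAS` in lib.rs:

```rust
use std::collections::HashMap;

// Assumed stand-in for the crate's STAR_AGGREGATION_ALIAS constant.
const STAR_AGGREGATION_ALIAS: &str = "*";

// Builds the "{table}.{column}.{expr_key}" lookup key used for the
// per-aggregate wheels (avg/min/max/sum) in the rewrite above.
fn wheel_key(table: &str, column: &str, expr_key: Option<&str>) -> String {
    format!(
        "{}.{}.{}",
        table,
        column,
        expr_key.unwrap_or(STAR_AGGREGATION_ALIAS)
    )
}

fn main() {
    // Toy stand-in for the optimizer's map of SUM wheels.
    let mut sum_wheels: HashMap<String, Vec<(u64, f64)>> = HashMap::new();
    sum_wheels.insert(wheel_key("test", "agg_col", None), vec![(0, 55.0)]);

    // A grouped SUM over agg_col with no filter-specific key resolves here.
    assert!(sum_wheels.contains_key("test.agg_col.*"));
}
```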