Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::types::format::FormatOptions;

use crate::client;
use crate::hooks::set_show::SetShowHook;
use crate::hooks::transactions::TransactionStatementHook;
use crate::hooks::QueryHook;
use crate::{client, planner};
use arrow_pg::datatypes::df;
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
Expand Down Expand Up @@ -215,8 +215,7 @@ impl ExtendedQueryHandler for DfSessionService {
if !self.query_hooks.is_empty() {
if let (_, Some((statement, plan))) = &portal.statement.statement {
// TODO: in the case where query hooks all return None, we do the param handling again later.
let param_types = plan
.get_parameter_types()
let param_types = planner::get_inferred_parameter_types(plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values: ParamValues =
Expand All @@ -240,8 +239,7 @@ impl ExtendedQueryHandler for DfSessionService {
}

if let (_, Some((statement, plan))) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
let param_types = planner::get_inferred_parameter_types(plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values =
Expand Down Expand Up @@ -381,8 +379,7 @@ impl QueryParser for Parser {

fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
if let (_, Some((_, plan))) = stmt {
let params = plan
.get_parameter_types()
let params = planner::get_inferred_parameter_types(plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let mut param_types = Vec::with_capacity(params.len());
Expand Down
1 change: 1 addition & 0 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod auth;
pub(crate) mod client;
mod handlers;
pub mod hooks;
mod planner;
#[cfg(any(test, debug_assertions))]
pub mod testing;

Expand Down
67 changes: 67 additions & 0 deletions datafusion-postgres/src/planner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::collections::{HashMap, HashSet};

use datafusion::arrow::datatypes::DataType;
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion::error::Result;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::Expr;

fn extract_placeholder_cast_types(plan: &LogicalPlan) -> Result<HashMap<String, Option<DataType>>> {
let mut placeholder_types = HashMap::new();
let mut casted_placeholders = HashSet::new();

plan.apply(|node| {
for expr in node.expressions() {
let _ = expr.apply(|e| {
if let Expr::Cast(cast) = e {
if let Expr::Placeholder(ph) = &*cast.expr {
placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
casted_placeholders.insert(ph.id.clone());
}
}

if let Expr::Placeholder(ph) = e {
if !casted_placeholders.contains(&ph.id)
&& !placeholder_types.contains_key(&ph.id)
{
placeholder_types.insert(ph.id.clone(), None);
}
}

Ok(TreeNodeRecursion::Continue)
});
}
Ok(TreeNodeRecursion::Continue)
})?;

Ok(placeholder_types)
}

pub fn get_inferred_parameter_types(
plan: &LogicalPlan,
) -> Result<HashMap<String, Option<DataType>>> {
let param_types = plan.get_parameter_types()?;

let has_none = param_types.values().any(|v| v.is_none());

if !has_none {
Ok(param_types)
} else {
let cast_types = extract_placeholder_cast_types(plan)?;

let mut merged = param_types;

for (id, opt_type) in cast_types {
merged
.entry(id)
.and_modify(|existing| {
if existing.is_none() {
*existing = opt_type.clone();
}
})
.or_insert(opt_type);
}

Ok(merged)
}
}
Loading