diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index ea5c953..3eb5e5a 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -19,10 +19,10 @@ use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::types::format::FormatOptions; -use crate::client; use crate::hooks::set_show::SetShowHook; use crate::hooks::transactions::TransactionStatementHook; use crate::hooks::QueryHook; +use crate::{client, planner}; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; @@ -215,8 +215,7 @@ impl ExtendedQueryHandler for DfSessionService { if !self.query_hooks.is_empty() { if let (_, Some((statement, plan))) = &portal.statement.statement { // TODO: in the case where query hooks all return None, we do the param handling again later. - let param_types = plan - .get_parameter_types() + let param_types = planner::get_inferred_parameter_types(plan) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let param_values: ParamValues = @@ -240,8 +239,7 @@ impl ExtendedQueryHandler for DfSessionService { } if let (_, Some((statement, plan))) = &portal.statement.statement { - let param_types = plan - .get_parameter_types() + let param_types = planner::get_inferred_parameter_types(plan) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let param_values = @@ -381,8 +379,7 @@ impl QueryParser for Parser { fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult> { if let (_, Some((_, plan))) = stmt { - let params = plan - .get_parameter_types() + let params = planner::get_inferred_parameter_types(plan) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let mut param_types = Vec::with_capacity(params.len()); diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index a83dfcb..1d1bb66 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -2,6 +2,7 @@ pub mod auth; pub(crate) mod client; mod handlers; pub mod hooks; +mod planner; #[cfg(any(test, debug_assertions))] pub mod testing; diff --git a/datafusion-postgres/src/planner.rs b/datafusion-postgres/src/planner.rs new file mode 100644 index 0000000..db33ef7 --- /dev/null +++ b/datafusion-postgres/src/planner.rs @@ -0,0 +1,67 @@ +use std::collections::{HashMap, HashSet}; + +use datafusion::arrow::datatypes::DataType; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion::error::Result; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::Expr; + +fn extract_placeholder_cast_types(plan: &LogicalPlan) -> Result>> { + let mut placeholder_types = HashMap::new(); + let mut casted_placeholders = HashSet::new(); + + plan.apply(|node| { + for expr in node.expressions() { + let _ = expr.apply(|e| { + if let Expr::Cast(cast) = e { + if let Expr::Placeholder(ph) = &*cast.expr { + placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone())); + casted_placeholders.insert(ph.id.clone()); + } + } + + if let Expr::Placeholder(ph) = e { + if !casted_placeholders.contains(&ph.id) + && !placeholder_types.contains_key(&ph.id) + { + placeholder_types.insert(ph.id.clone(), None); + } + } + + Ok(TreeNodeRecursion::Continue) + }); + } + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(placeholder_types) +} + +pub fn get_inferred_parameter_types( + plan: &LogicalPlan, +) -> Result>> { + let param_types = plan.get_parameter_types()?; + + let has_none = param_types.values().any(|v| v.is_none()); + + if !has_none { + Ok(param_types) + } else { + let cast_types = extract_placeholder_cast_types(plan)?; + + let mut merged = param_types; + + for (id, opt_type) in cast_types { + merged + .entry(id) + .and_modify(|existing| { + if existing.is_none() { + *existing = opt_type.clone(); + } + }) + .or_insert(opt_type); + } + + Ok(merged) + } +}