From fab4d1e83531671d3ce838925d7ba8e2fcde0cf7 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 15 Feb 2026 19:11:33 +0800 Subject: [PATCH 1/4] feat: add rule to support ::regclass::oid cast --- datafusion-pg-catalog/src/sql/parser.rs | 2 + datafusion-pg-catalog/src/sql/rules.rs | 165 ++++++++++++++++++++++++ 2 files changed, 167 insertions(+) diff --git a/datafusion-pg-catalog/src/sql/parser.rs b/datafusion-pg-catalog/src/sql/parser.rs index a110ca7..719df90 100644 --- a/datafusion-pg-catalog/src/sql/parser.rs +++ b/datafusion-pg-catalog/src/sql/parser.rs @@ -18,6 +18,7 @@ use super::rules::RemoveSubqueryFromProjection; use super::rules::RemoveUnsupportedTypes; use super::rules::ResolveUnqualifiedIdentifer; use super::rules::RewriteArrayAnyAllOperation; +use super::rules::RewriteRegclassCastToSubquery; use super::rules::SqlStatementRewriteRule; const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ @@ -228,6 +229,7 @@ impl PostgresCompatibilityParser { Arc::new(RewriteArrayAnyAllOperation), Arc::new(PrependUnqualifiedPgTableName), Arc::new(RemoveQualifier), + Arc::new(RewriteRegclassCastToSubquery::new()), Arc::new(RemoveUnsupportedTypes::new()), Arc::new(FixArrayLiteral), Arc::new(CurrentUserVariableToSessionUserFunctionCall), diff --git a/datafusion-pg-catalog/src/sql/rules.rs b/datafusion-pg-catalog/src/sql/rules.rs index 91c6553..8538a37 100644 --- a/datafusion-pg-catalog/src/sql/rules.rs +++ b/datafusion-pg-catalog/src/sql/rules.rs @@ -31,6 +31,8 @@ use datafusion::sql::sqlparser::ast::Value; use datafusion::sql::sqlparser::ast::ValueWithSpan; use datafusion::sql::sqlparser::ast::VisitMut; use datafusion::sql::sqlparser::ast::VisitorMut; +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion::sql::sqlparser::parser::Parser; pub trait SqlStatementRewriteRule: Send + Sync + Debug { fn rewrite(&self, s: Statement) -> Statement; @@ -382,6 +384,139 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes { } } +/// Rewrite regclass::oid cast to subquery +/// +/// This rewrites patterns like `$1::regclass::oid` to +/// `(SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)` +#[derive(Debug)] +pub struct RewriteRegclassCastToSubquery(Box); + +impl RewriteRegclassCastToSubquery { + pub fn new() -> Self { + let sql = "SELECT oid FROM pg_catalog.pg_class WHERE relname = $1"; + let dialect = PostgreSqlDialect {}; + let query = Parser::parse_sql(&dialect, sql) + .map(|mut stmts| { + let stmt = stmts.remove(0); + if let Statement::Query(query) = stmt { + query + } else { + unreachable!() + } + }) + .expect("Failed to parse prepared query"); + Self(query) + } +} + +struct RewriteRegclassCastToSubqueryVisitor(Box); + +impl RewriteRegclassCastToSubqueryVisitor { + pub fn new(query: Box) -> Self { + Self(query) + } + + fn create_subquery(&self, expr: &Expr) -> Expr { + let mut query = self.0.clone(); + if let SetExpr::Select(select) = query.body.as_mut() { + if let Some(selection) = &mut select.selection { + if let Expr::BinaryOp { right, .. } = selection { + *right = Box::new(expr.clone()); + } + } + } + Expr::Subquery(query) + } + + fn is_regclass_to_oid_cast(&self, expr: &Expr) -> bool { + if let Expr::Cast { + kind, + data_type, + expr: inner_expr, + format: _, + } = expr + { + if *kind == CastKind::DoubleColon { + let dt_lower = data_type.to_string().to_lowercase(); + if dt_lower == "oid" || dt_lower == "pg_catalog.oid" { + return self.is_regclass_cast(inner_expr); + } + } + } + false + } + + fn is_regclass_cast(&self, expr: &Expr) -> bool { + if let Expr::Cast { + kind, + data_type, + expr: _, + format: _, + } = expr + { + if *kind == CastKind::DoubleColon { + let dt_lower = data_type.to_string().to_lowercase(); + return dt_lower == "regclass" || dt_lower == "pg_catalog.regclass"; + } + } + false + } + + fn extract_inner_expr(&self, expr: &Expr) -> Option { + if let Expr::Cast { + kind, + data_type, + expr: inner_expr, + format: _, + } = expr + { + if *kind == CastKind::DoubleColon { + let dt_lower = data_type.to_string().to_lowercase(); + if dt_lower == "oid" || dt_lower == "pg_catalog.oid" { + if let Expr::Cast { + kind: inner_kind, + data_type: inner_data_type, + expr: inner_inner_expr, + format: _, + } = inner_expr.as_ref() + { + if *inner_kind == CastKind::DoubleColon { + let inner_dt_lower = inner_data_type.to_string().to_lowercase(); + if inner_dt_lower == "regclass" + || inner_dt_lower == "pg_catalog.regclass" + { + return Some((**inner_inner_expr).clone()); + } + } + } + } + } + } + None + } +} + +impl VisitorMut for RewriteRegclassCastToSubqueryVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if self.is_regclass_to_oid_cast(expr) { + if let Some(inner_expr) = self.extract_inner_expr(expr) { + *expr = self.create_subquery(&inner_expr); + } + } + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RewriteRegclassCastToSubquery { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RewriteRegclassCastToSubqueryVisitor::new(self.0.clone()); + let _ = s.visit(&mut visitor); + s + } +} + /// Rewrite Postgres's ANY operator to array_contains #[derive(Debug)] pub struct RewriteArrayAnyAllOperation; @@ -997,6 +1132,36 @@ mod tests { ); } + #[test] + fn test_rewrite_regclass_cast_to_subquery() { + let rules: Vec> = + vec![Arc::new(RewriteRegclassCastToSubquery::new())]; + + assert_rewrite!( + &rules, + "SELECT $1::regclass::oid", + "SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)" + ); + + assert_rewrite!( + &rules, + "SELECT $1::pg_catalog.regclass::oid", + "SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)" + ); + + assert_rewrite!( + &rules, + "SELECT $1::pg_catalog.regclass::pg_catalog.oid", + "SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)" + ); + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_class WHERE oid = 't'::pg_catalog.regclass::pg_catalog.oid", + "SELECT * FROM pg_catalog.pg_class WHERE oid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = 't')" + ); + } + #[test] fn test_any_to_array_contains() { let rules: Vec> = From 4836dd42748744a66999af293ddfb4b494969034 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 15 Feb 2026 19:36:04 +0800 Subject: [PATCH 2/4] test: add pgadbc sql tests --- datafusion-postgres/tests/pgadbc.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 datafusion-postgres/tests/pgadbc.rs diff --git a/datafusion-postgres/tests/pgadbc.rs b/datafusion-postgres/tests/pgadbc.rs new file mode 100644 index 0000000..24db14a --- /dev/null +++ b/datafusion-postgres/tests/pgadbc.rs @@ -0,0 +1,24 @@ +use pgwire::api::query::SimpleQueryHandler; + +use datafusion_postgres::testing::*; + +const PGADBC_QUERIES: &[&str] = &[ + "SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND cls.oid = 1602::regclass::oid ORDER BY attr.attnum", + + +]; + +#[tokio::test] +pub async fn test_pgadbc_metadata_sql() { + env_logger::init(); + let service = setup_handlers(); + let mut client = MockClient::new(); + + for query in PGADBC_QUERIES { + SimpleQueryHandler::do_query(&service, &mut client, query) + .await + .unwrap_or_else(|e| { + panic!("failed to run sql:\n--------------\n {query}\n--------------\n{e}") + }); + } +} From b5bc096fdf8172292c82ddea2780dbf51d104c59 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 15 Feb 2026 19:52:43 +0800 Subject: [PATCH 3/4] fix: lint --- datafusion-pg-catalog/src/sql/rules.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/datafusion-pg-catalog/src/sql/rules.rs b/datafusion-pg-catalog/src/sql/rules.rs index 8538a37..a453860 100644 --- a/datafusion-pg-catalog/src/sql/rules.rs +++ b/datafusion-pg-catalog/src/sql/rules.rs @@ -391,6 +391,12 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes { #[derive(Debug)] pub struct RewriteRegclassCastToSubquery(Box); +impl Default for RewriteRegclassCastToSubquery { + fn default() -> Self { + Self::new() + } +} + impl RewriteRegclassCastToSubquery { pub fn new() -> Self { let sql = "SELECT oid FROM pg_catalog.pg_class WHERE relname = $1"; @@ -419,10 +425,8 @@ impl RewriteRegclassCastToSubqueryVisitor { fn create_subquery(&self, expr: &Expr) -> Expr { let mut query = self.0.clone(); if let SetExpr::Select(select) = query.body.as_mut() { - if let Some(selection) = &mut select.selection { - if let Expr::BinaryOp { right, .. } = selection { - *right = Box::new(expr.clone()); - } + if let Some(Expr::BinaryOp { right, .. }) = &mut select.selection { + **right = expr.clone(); } } Expr::Subquery(query) From 41b8be4cf579640bacc65593a7fa151e579e9423 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 20 Feb 2026 17:49:40 +0800 Subject: [PATCH 4/4] feat: update sql and support schema --- datafusion-pg-catalog/src/sql/rules.rs | 40 ++++++++++++++++++++------ datafusion-postgres/tests/pgadbc.rs | 2 +- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/datafusion-pg-catalog/src/sql/rules.rs b/datafusion-pg-catalog/src/sql/rules.rs index a453860..8da0c13 100644 --- a/datafusion-pg-catalog/src/sql/rules.rs +++ b/datafusion-pg-catalog/src/sql/rules.rs @@ -399,7 +399,15 @@ impl Default for RewriteRegclassCastToSubquery { impl RewriteRegclassCastToSubquery { pub fn new() -> Self { - let sql = "SELECT oid FROM pg_catalog.pg_class WHERE relname = $1"; + let sql = "SELECT c.oid +FROM pg_catalog.pg_class c +JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace +CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p +WHERE n.nspname = COALESCE( + CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, + current_schema() +) +AND c.relname = p.parts[-1]"; let dialect = PostgreSqlDialect {}; let query = Parser::parse_sql(&dialect, sql) .map(|mut stmts| { @@ -423,12 +431,26 @@ impl RewriteRegclassCastToSubqueryVisitor { } fn create_subquery(&self, expr: &Expr) -> Expr { - let mut query = self.0.clone(); - if let SetExpr::Select(select) = query.body.as_mut() { - if let Some(Expr::BinaryOp { right, .. }) = &mut select.selection { - **right = expr.clone(); + struct PlaceholderReplacer(Expr); + + impl VisitorMut for PlaceholderReplacer { + type Break = (); + + fn pre_visit_expr(&mut self, e: &mut Expr) -> ControlFlow { + if let Expr::Value(ValueWithSpan { + value: Value::Placeholder(_placeholder), + .. + }) = e + { + *e = self.0.clone(); + } + ControlFlow::Continue(()) } } + + let mut query = self.0.clone(); + let mut replacer = PlaceholderReplacer(expr.clone()); + let _ = query.visit(&mut replacer); Expr::Subquery(query) } @@ -1144,25 +1166,25 @@ mod tests { assert_rewrite!( &rules, "SELECT $1::regclass::oid", - "SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)" + "SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])" ); assert_rewrite!( &rules, "SELECT $1::pg_catalog.regclass::oid", - "SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)" + "SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])" ); assert_rewrite!( &rules, "SELECT $1::pg_catalog.regclass::pg_catalog.oid", - "SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)" + "SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])" ); assert_rewrite!( &rules, "SELECT * FROM pg_catalog.pg_class WHERE oid = 't'::pg_catalog.regclass::pg_catalog.oid", - "SELECT * FROM pg_catalog.pg_class WHERE oid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = 't')" + "SELECT * FROM pg_catalog.pg_class WHERE oid = (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident('t'::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])" ); } diff --git a/datafusion-postgres/tests/pgadbc.rs b/datafusion-postgres/tests/pgadbc.rs index 24db14a..cd7a515 100644 --- a/datafusion-postgres/tests/pgadbc.rs +++ b/datafusion-postgres/tests/pgadbc.rs @@ -3,7 +3,7 @@ use pgwire::api::query::SimpleQueryHandler; use datafusion_postgres::testing::*; const PGADBC_QUERIES: &[&str] = &[ - "SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND cls.oid = 1602::regclass::oid ORDER BY attr.attnum", + "SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND cls.oid = 'clubs'::regclass::oid ORDER BY attr.attnum", ];