Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datafusion-pg-catalog/src/sql/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use super::rules::RemoveSubqueryFromProjection;
use super::rules::RemoveUnsupportedTypes;
use super::rules::ResolveUnqualifiedIdentifer;
use super::rules::RewriteArrayAnyAllOperation;
use super::rules::RewriteRegclassCastToSubquery;
use super::rules::SqlStatementRewriteRule;

const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
Expand Down Expand Up @@ -228,6 +229,7 @@ impl PostgresCompatibilityParser {
Arc::new(RewriteArrayAnyAllOperation),
Arc::new(PrependUnqualifiedPgTableName),
Arc::new(RemoveQualifier),
Arc::new(RewriteRegclassCastToSubquery::new()),
Arc::new(RemoveUnsupportedTypes::new()),
Arc::new(FixArrayLiteral),
Arc::new(CurrentUserVariableToSessionUserFunctionCall),
Expand Down
169 changes: 169 additions & 0 deletions datafusion-pg-catalog/src/sql/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use datafusion::sql::sqlparser::ast::Value;
use datafusion::sql::sqlparser::ast::ValueWithSpan;
use datafusion::sql::sqlparser::ast::VisitMut;
use datafusion::sql::sqlparser::ast::VisitorMut;
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;

pub trait SqlStatementRewriteRule: Send + Sync + Debug {
fn rewrite(&self, s: Statement) -> Statement;
Expand Down Expand Up @@ -382,6 +384,143 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
}
}

/// Rewrite regclass::oid cast to subquery
///
/// This rewrites patterns like `$1::regclass::oid` to
/// `(SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)`
#[derive(Debug)]
pub struct RewriteRegclassCastToSubquery(Box<Query>);

impl Default for RewriteRegclassCastToSubquery {
fn default() -> Self {
Self::new()
}
}

impl RewriteRegclassCastToSubquery {
pub fn new() -> Self {
let sql = "SELECT oid FROM pg_catalog.pg_class WHERE relname = $1";
let dialect = PostgreSqlDialect {};
let query = Parser::parse_sql(&dialect, sql)
.map(|mut stmts| {
let stmt = stmts.remove(0);
if let Statement::Query(query) = stmt {
query
} else {
unreachable!()
}
})
.expect("Failed to parse prepared query");
Self(query)
}
}

struct RewriteRegclassCastToSubqueryVisitor(Box<Query>);

impl RewriteRegclassCastToSubqueryVisitor {
pub fn new(query: Box<Query>) -> Self {
Self(query)
}

fn create_subquery(&self, expr: &Expr) -> Expr {
let mut query = self.0.clone();
if let SetExpr::Select(select) = query.body.as_mut() {
if let Some(Expr::BinaryOp { right, .. }) = &mut select.selection {
**right = expr.clone();
}
}
Expr::Subquery(query)
}

fn is_regclass_to_oid_cast(&self, expr: &Expr) -> bool {
if let Expr::Cast {
kind,
data_type,
expr: inner_expr,
format: _,
} = expr
{
if *kind == CastKind::DoubleColon {
let dt_lower = data_type.to_string().to_lowercase();
if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
return self.is_regclass_cast(inner_expr);
}
}
}
false
}

fn is_regclass_cast(&self, expr: &Expr) -> bool {
if let Expr::Cast {
kind,
data_type,
expr: _,
format: _,
} = expr
{
if *kind == CastKind::DoubleColon {
let dt_lower = data_type.to_string().to_lowercase();
return dt_lower == "regclass" || dt_lower == "pg_catalog.regclass";
}
}
false
}

fn extract_inner_expr(&self, expr: &Expr) -> Option<Expr> {
if let Expr::Cast {
kind,
data_type,
expr: inner_expr,
format: _,
} = expr
{
if *kind == CastKind::DoubleColon {
let dt_lower = data_type.to_string().to_lowercase();
if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
if let Expr::Cast {
kind: inner_kind,
data_type: inner_data_type,
expr: inner_inner_expr,
format: _,
} = inner_expr.as_ref()
{
if *inner_kind == CastKind::DoubleColon {
let inner_dt_lower = inner_data_type.to_string().to_lowercase();
if inner_dt_lower == "regclass"
|| inner_dt_lower == "pg_catalog.regclass"
{
return Some((**inner_inner_expr).clone());
}
}
}
}
}
}
None
}
}

impl VisitorMut for RewriteRegclassCastToSubqueryVisitor {
type Break = ();

fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
if self.is_regclass_to_oid_cast(expr) {
if let Some(inner_expr) = self.extract_inner_expr(expr) {
*expr = self.create_subquery(&inner_expr);
}
}
ControlFlow::Continue(())
}
}

impl SqlStatementRewriteRule for RewriteRegclassCastToSubquery {
fn rewrite(&self, mut s: Statement) -> Statement {
let mut visitor = RewriteRegclassCastToSubqueryVisitor::new(self.0.clone());
let _ = s.visit(&mut visitor);
s
}
}

/// Rewrite Postgres's ANY operator to array_contains
#[derive(Debug)]
pub struct RewriteArrayAnyAllOperation;
Expand Down Expand Up @@ -997,6 +1136,36 @@ mod tests {
);
}

#[test]
fn test_rewrite_regclass_cast_to_subquery() {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(RewriteRegclassCastToSubquery::new())];

assert_rewrite!(
&rules,
"SELECT $1::regclass::oid",
"SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)"
);

assert_rewrite!(
&rules,
"SELECT $1::pg_catalog.regclass::oid",
"SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)"
);

assert_rewrite!(
&rules,
"SELECT $1::pg_catalog.regclass::pg_catalog.oid",
"SELECT (SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)"
);

assert_rewrite!(
&rules,
"SELECT * FROM pg_catalog.pg_class WHERE oid = 't'::pg_catalog.regclass::pg_catalog.oid",
"SELECT * FROM pg_catalog.pg_class WHERE oid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = 't')"
);
}

#[test]
fn test_any_to_array_contains() {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
Expand Down
24 changes: 24 additions & 0 deletions datafusion-postgres/tests/pgadbc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use pgwire::api::query::SimpleQueryHandler;

use datafusion_postgres::testing::*;

const PGADBC_QUERIES: &[&str] = &[
"SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND cls.oid = 1602::regclass::oid ORDER BY attr.attnum",


];

#[tokio::test]
pub async fn test_pgadbc_metadata_sql() {
env_logger::init();
let service = setup_handlers();
let mut client = MockClient::new();

for query in PGADBC_QUERIES {
SimpleQueryHandler::do_query(&service, &mut client, query)
.await
.unwrap_or_else(|e| {
panic!("failed to run sql:\n--------------\n {query}\n--------------\n{e}")
});
}
}
Loading