diff --git a/jsql/__init__.py b/jsql/__init__.py index b103431..0a28ad2 100644 --- a/jsql/__init__.py +++ b/jsql/__init__.py @@ -7,6 +7,11 @@ import six import itertools, collections +import sqlalchemy + +SQLALCHEMY_VERSION = sqlalchemy.__version__ +IS_VERSION_2 = SQLALCHEMY_VERSION.startswith("2.") + class UnsafeSqlException(Exception): pass @@ -78,12 +83,21 @@ def dangerously_inject_sql(value): jenv.filters["dangerously_inject_sql"] = dangerously_inject_sql jenv.globals["comma"] = DangerouslyInjectedSql(",") - def execute_sql(engine, query, params): + from sqlalchemy.engine import Engine from sqlalchemy.sql import text q = text(query) - is_session = 'session' in repr(engine.__class__).lower() - return engine.execute(q, params=params) if is_session else engine.execute(q, **params) + + if IS_VERSION_2: + if hasattr(engine, 'execute'): + return engine.execute(q, params) + + if isinstance(engine, Engine): + with engine.connect() as conn: + return conn.execute(q, params) + else: + is_session = 'session' in repr(engine.__class__).lower() + return engine.execute(q, params=params) if is_session else engine.execute(q, **params) BINDPARAM_PREFIX = 'bp' def gen_bindparam(params):