2121
2222import logging
2323from datetime import datetime , date
24+ from types import ModuleType
2425
2526from sqlalchemy import types as sqltypes
2627from sqlalchemy .engine import default , reflection
@@ -202,6 +203,12 @@ def initialize(self, connection):
202203 self .default_schema_name = \
203204 self ._get_default_schema_name (connection )
204205
206+ def set_isolation_level (self , dbapi_connection , level ):
207+ """
208+ For CrateDB, this is implemented as a noop.
209+ """
210+ pass
211+
205212 def do_rollback (self , connection ):
206213 # if any exception is raised by the dbapi, sqlalchemy by default
207214 # attempts to do a rollback crate doesn't support rollbacks.
@@ -220,7 +227,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
220227 use_ssl = asbool (kwargs .pop ("ssl" , False ))
221228 if use_ssl :
222229 servers = ["https://" + server for server in servers ]
223- return self .dbapi .connect (servers = servers , ** kwargs )
230+
231+ is_module = isinstance (self .dbapi , ModuleType )
232+ if is_module :
233+ driver_name = self .dbapi .__name__
234+ else :
235+ driver_name = self .dbapi .__class__ .__name__
236+ if driver_name == "crate.client" :
237+ if "database" in kwargs :
238+ del kwargs ["database" ]
239+ return self .dbapi .connect (servers = servers , ** kwargs )
240+ elif driver_name in ["psycopg" , "PsycopgAdaptDBAPI" , "AsyncAdapt_asyncpg_dbapi" ]:
241+ return self .dbapi .connect (host = host , port = port , ** kwargs )
242+ else :
243+ raise ValueError (f"Unknown driver variant: { driver_name } " )
244+
224245 return self .dbapi .connect (** kwargs )
225246
226247 def _get_default_schema_name (self , connection ):
@@ -266,11 +287,11 @@ def get_schema_names(self, connection, **kw):
266287 def get_table_names (self , connection , schema = None , ** kw ):
267288 if schema is None :
268289 schema = self ._get_effective_schema_name (connection )
269- cursor = connection .exec_driver_sql (
290+ cursor = connection .exec_driver_sql (self . _format_query (
270291 "SELECT table_name FROM information_schema.tables "
271292 "WHERE {0} = ? "
272293 "AND table_type = 'BASE TABLE' "
273- "ORDER BY table_name ASC, {0} ASC" .format (self .schema_column ),
294+ "ORDER BY table_name ASC, {0} ASC" ) .format (self .schema_column ),
274295 (schema or self .default_schema_name , )
275296 )
276297 return [row [0 ] for row in cursor .fetchall ()]
@@ -292,7 +313,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
292313 "AND column_name !~ ?" \
293314 .format (self .schema_column )
294315 cursor = connection .exec_driver_sql (
295- query ,
316+ self . _format_query ( query ) ,
296317 (table_name ,
297318 schema or self .default_schema_name ,
298319 r"(.*)\[\'(.*)\'\]" ) # regex to filter subscript
@@ -331,7 +352,7 @@ def result_fun(result):
331352 return set (rows [0 ] if rows else [])
332353
333354 pk_result = engine .exec_driver_sql (
334- query ,
355+ self . _format_query ( query ) ,
335356 (table_name , schema or self .default_schema_name )
336357 )
337358 pks = result_fun (pk_result )
@@ -372,6 +393,17 @@ def has_ilike_operator(self):
372393 server_version_info = self .server_version_info
373394 return server_version_info is not None and server_version_info >= (4 , 1 , 0 )
374395
396+ def _format_query (self , query ):
397+ """
398+ When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
399+ the paramstyle is not `qmark`, but `pyformat`.
400+
401+ TODO: Review: Is it legit and sane? Are there alternatives?
402+ """
403+ if self .paramstyle == "pyformat" :
404+ query = query .replace ("= ?" , "= %s" ).replace ("!~ ?" , "!~ %s" )
405+ return query
406+
375407
376408class DateTrunc (functions .GenericFunction ):
377409 name = "date_trunc"
0 commit comments