redash/query_runner/databricks.py (208 lines of code) (raw):

import datetime
import logging
import os
from contextlib import closing

import sqlparse

from redash.query_runner import (
    NotSupported,
    register,
    BaseSQLQueryRunner,
    TYPE_STRING,
    TYPE_BOOLEAN,
    TYPE_DATE,
    TYPE_DATETIME,
    TYPE_INTEGER,
    TYPE_FLOAT,
)
from redash.settings import cast_int_or_default
from redash.utils import json_dumps, json_loads
from redash import __version__, settings, statsd_client

try:
    import pyodbc

    enabled = True
except ImportError:
    enabled = False


# Maps the Python type reported in ``cursor.description`` to a Redash column
# type. Types missing from this map fall back to TYPE_STRING in ``run_query``.
TYPES_MAP = {
    str: TYPE_STRING,
    bool: TYPE_BOOLEAN,
    datetime.date: TYPE_DATE,
    datetime.datetime: TYPE_DATETIME,
    int: TYPE_INTEGER,
    float: TYPE_FLOAT,
}

# Maximum number of rows fetched per query; overridable through the
# DATABRICKS_ROW_LIMIT environment variable (20000 is passed as the default).
ROW_LIMIT = cast_int_or_default(os.environ.get("DATABRICKS_ROW_LIMIT"), 20000)

logger = logging.getLogger(__name__)


def _build_odbc_connection_string(**kwargs):
    """Join keyword arguments into a ``key=value;key=value`` string.

    Values are interpolated as-is; no quoting or escaping is applied.
    """
    return ";".join(f"{k}={v}" for k, v in kwargs.items())


def split_sql_statements(query):
    """Split *query* into a list of individual SQL statement strings.

    Trailing whitespace/comments and the terminating semicolon are removed
    from each statement, and statements that are empty once comments are
    stripped are dropped.

    Returns:
        A non-empty list of statement strings. When every statement is
        empty, the list contains a single empty string.
    """

    def strip_trailing_comments(stmt):
        # Walk backwards, blanking out whitespace and comment tokens until
        # the first "real" token is reached.
        idx = len(stmt.tokens) - 1
        while idx >= 0:
            tok = stmt.tokens[idx]
            if tok.is_whitespace or sqlparse.utils.imt(
                tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment
            ):
                stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
            else:
                break
            idx -= 1
        return stmt

    def strip_trailing_semicolon(stmt):
        # Only the last non-whitespace token is inspected: if it is a ";"
        # it is blanked out, otherwise the statement is left untouched.
        idx = len(stmt.tokens) - 1
        while idx >= 0:
            tok = stmt.tokens[idx]
            # we expect that trailing comments already are removed
            if not tok.is_whitespace:
                if (
                    sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation)
                    and tok.value == ";"
                ):
                    stmt.tokens[idx] = sqlparse.sql.Token(
                        sqlparse.tokens.Whitespace, " "
                    )
                break
            idx -= 1
        return stmt

    def is_empty_statement(stmt):
        strip_comments = sqlparse.filters.StripCommentsFilter()

        # Copy the statement object. `copy.deepcopy` fails to do this, so
        # just re-parse it. The builtin `str` is used instead of the
        # Python 2 compatibility alias `sqlparse.text_type`, so this does
        # not depend on that alias being present.
        st = sqlparse.engine.FilterStack()
        stmt = next(st.run(str(stmt)))

        sql = str(strip_comments.process(stmt))
        return sql.strip() == ""

    stack = sqlparse.engine.FilterStack()

    result = list(stack.run(query))
    result = [strip_trailing_comments(stmt) for stmt in result]
    result = [strip_trailing_semicolon(stmt) for stmt in result]
    result = [str(stmt).strip() for stmt in result if not is_empty_statement(stmt)]

    if result:
        return result

    return [""]  # if all statements were empty - return a single empty statement


def combine_sql_statements(queries):
    """Join statement strings back into one script separated by ``;\\n``."""
    return ";\n".join(queries)


class Databricks(BaseSQLQueryRunner):
    """Query runner that talks to Databricks through pyodbc."""

    noop_query = "SELECT 1"
    should_annotate_query = False

    @classmethod
    def type(cls):
        """Return the identifier this runner is registered under."""
        return "databricks"

    @classmethod
    def enabled(cls):
        """Return True when ``pyodbc`` could be imported."""
        return enabled

    @classmethod
    def configuration_schema(cls):
        """Return the JSON schema describing the data source settings."""
        return {
            "type": "object",
            "properties": {
                "host": {"type": "string"},
                "http_path": {"type": "string", "title": "HTTP Path"},
                # We're using `http_password` here for legacy reasons
                "http_password": {"type": "string", "title": "Access Token"},
            },
            "order": ["host", "http_path", "http_password"],
            "secret": ["http_password"],
            "required": ["host", "http_path", "http_password"],
        }

    def _get_cursor(self):
        """Open a new autocommit ODBC connection and return a cursor on it.

        Callers own the returned cursor and are responsible for closing it.
        """
        user_agent = "Redash/{} (Databricks)".format(__version__.split("-")[0])
        connection_string = _build_odbc_connection_string(
            Driver="Simba",
            UID="token",
            PORT="443",
            SSL="1",
            THRIFTTRANSPORT="2",
            SPARKSERVERTYPE="3",
            AUTHMECH=3,
            # Use the query as is without rewriting:
            UseNativeQuery="1",
            # Automatically reconnect to the cluster if an error occurs
            AutoReconnect="1",
            # Minimum interval between consecutive polls for query execution status (1ms)
            AsyncExecPollInterval="1",
            UserAgentEntry=user_agent,
            HOST=self.configuration["host"],
            PWD=self.configuration["http_password"],
            HTTPPath=self.configuration["http_path"],
        )

        connection = pyodbc.connect(connection_string, autocommit=True)
        return connection.cursor()

    def run_query(self, query, user):
        """Execute *query* and return ``(json_data, error)``.

        The query is split into statements which are executed in order on
        one cursor; only the result of the last statement is returned. At
        most ``ROW_LIMIT`` rows are fetched, and ``"truncated": True`` is
        added to the payload when more rows were available.

        Returns:
            ``(json_string, None)`` on success, or ``(None, message)`` when
            pyodbc raises an error.
        """
        try:
            # `closing` guarantees the cursor is released even when a
            # statement fails part-way through.
            with closing(self._get_cursor()) as cursor:
                statements = split_sql_statements(query)
                for stmt in statements:
                    cursor.execute(stmt)

                if cursor.description is not None:
                    result_set = cursor.fetchmany(ROW_LIMIT)
                    columns = self.fetch_columns(
                        [
                            (i[0], TYPES_MAP.get(i[1], TYPE_STRING))
                            for i in cursor.description
                        ]
                    )

                    rows = [
                        dict(zip((column["name"] for column in columns), row))
                        for row in result_set
                    ]

                    data = {"columns": columns, "rows": rows}

                    # Probe for one more row only when the limit was hit,
                    # to tell "exactly ROW_LIMIT rows" from a truncation.
                    if (
                        len(result_set) >= ROW_LIMIT
                        and cursor.fetchone() is not None
                    ):
                        logger.warning("Truncated result set.")
                        statsd_client.incr("redash.query_runner.databricks.truncated")
                        data["truncated"] = True
                    json_data = json_dumps(data)
                    error = None
                else:
                    # The last statement produced no result set.
                    error = None
                    json_data = json_dumps(
                        {
                            "columns": [{"name": "result", "type": TYPE_STRING}],
                            "rows": [{"result": "No data was returned."}],
                        }
                    )
        except pyodbc.Error as e:
            # Prefer the second exception argument as the message when
            # there is one; otherwise use the whole exception.
            if len(e.args) > 1:
                error = str(e.args[1])
            else:
                error = str(e)
            json_data = None

        return json_data, error

    def get_schema(self):
        """Not supported; use the per-database helpers instead."""
        raise NotSupported()

    def get_databases(self):
        """Return the database names listed by ``SHOW DATABASES``.

        Raises:
            Exception: if the query fails.
        """
        query = "SHOW DATABASES"
        results, error = self.run_query(query, None)

        if error is not None:
            raise Exception("Failed getting schema.")

        results = json_loads(results)

        first_column_name = results["columns"][0]["name"]
        return [row[first_column_name] for row in results["rows"]]

    def get_database_tables(self, database_name):
        """Return the tables of *database_name* with empty column lists.

        Each entry is ``{"name": "<schema>.<table>", "columns": []}``.
        """
        schema = {}

        with closing(self._get_cursor()) as cursor:
            cursor.tables(schema=database_name)

            # Per the pyodbc catalog layout, index 1 is the schema name and
            # index 2 the table name.
            for table in cursor:
                table_name = "{}.{}".format(table[1], table[2])

                if table_name not in schema:
                    schema[table_name] = {"name": table_name, "columns": []}

        return list(schema.values())

    def get_database_tables_with_columns(self, database_name):
        """Return the tables of *database_name* together with their columns.

        Each entry is ``{"name": "<schema>.<table>", "columns": [...]}``
        where every column is ``{"name": ..., "type": ...}``.
        """
        schema = {}

        with closing(self._get_cursor()) as cursor:
            # load tables first, otherwise tables without columns are not showed
            cursor.tables(schema=database_name)

            for table in cursor:
                table_name = "{}.{}".format(table[1], table[2])

                if table_name not in schema:
                    schema[table_name] = {"name": table_name, "columns": []}

            cursor.columns(schema=database_name)

            # Per the pyodbc catalog layout, index 3 is the column name and
            # index 5 the type name.
            for column in cursor:
                table_name = "{}.{}".format(column[1], column[2])

                if table_name not in schema:
                    schema[table_name] = {"name": table_name, "columns": []}

                schema[table_name]["columns"].append(
                    {"name": column[3], "type": column[5]}
                )

        return list(schema.values())

    def get_table_columns(self, database_name, table_name):
        """Return ``[{"name": ..., "type": ...}, ...]`` for a single table."""
        with closing(self._get_cursor()) as cursor:
            cursor.columns(schema=database_name, table=table_name)
            return [{"name": column[3], "type": column[5]} for column in cursor]


register(Databricks)