redash/query_runner/__init__.py (317 lines of code) (raw):

"""Base classes, registry and helpers shared by all Redash query runners."""

import logging
from contextlib import ExitStack
from dateutil import parser
from functools import wraps
import socket
import ipaddress
from urllib.parse import urlparse

from six import text_type
from sshtunnel import open_tunnel
from redash import settings, utils
from redash.utils import json_loads, query_is_select_no_limit, add_limit_to_query
from rq.timeouts import JobTimeoutException
from redash.utils.requests_session import (
    requests_or_advocate,
    requests_session,
    UnacceptableAddressException,
)

logger = logging.getLogger(__name__)

__all__ = [
    "BaseQueryRunner",
    "BaseHTTPQueryRunner",
    "InterruptException",
    "JobTimeoutException",
    "BaseSQLQueryRunner",
    "TYPE_DATETIME",
    "TYPE_BOOLEAN",
    "TYPE_INTEGER",
    "TYPE_STRING",
    "TYPE_DATE",
    "TYPE_FLOAT",
    "SUPPORTED_COLUMN_TYPES",
    "register",
    "get_query_runner",
    "import_query_runners",
    "guess_type",
]

# Valid types of columns returned in results:
TYPE_INTEGER = "integer"
TYPE_FLOAT = "float"
TYPE_BOOLEAN = "boolean"
TYPE_STRING = "string"
TYPE_DATETIME = "datetime"
TYPE_DATE = "date"

SUPPORTED_COLUMN_TYPES = {
    TYPE_INTEGER,
    TYPE_FLOAT,
    TYPE_BOOLEAN,
    TYPE_STRING,
    TYPE_DATETIME,
    TYPE_DATE,
}


class InterruptException(Exception):
    """Exception type exported for signalling that a query was interrupted."""

    pass


class NotSupported(Exception):
    """Raised when a query runner does not implement an optional feature."""

    pass


class BaseQueryRunner(object):
    """Base class for every query runner.

    Subclasses implement ``run_query`` and usually ``configuration_schema``;
    the remaining hooks have conservative defaults.
    """

    deprecated = False
    should_annotate_query = True
    # Query used by ``test_connection``; ``None`` means "not supported".
    noop_query = None

    def __init__(self, configuration):
        self.syntax = "sql"
        self.configuration = configuration

    @classmethod
    def name(cls):
        """Human-readable runner name (defaults to the class name)."""
        return cls.__name__

    @classmethod
    def type(cls):
        """Registry key for this runner (defaults to the lower-cased class name)."""
        return cls.__name__.lower()

    @classmethod
    def enabled(cls):
        """Whether this runner can be registered (e.g. its dependencies exist)."""
        return True

    @property
    def host(self):
        """Returns this query runner's configured host.
        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.

        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this function.
        """
        if "host" in self.configuration:
            return self.configuration["host"]
        else:
            raise NotImplementedError()

    @host.setter
    def host(self, host):
        """Sets this query runner's configured host.
        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.

        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this function.
        """
        if "host" in self.configuration:
            self.configuration["host"] = host
        else:
            raise NotImplementedError()

    @property
    def port(self):
        """Returns this query runner's configured port.
        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.

        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this function.
        """
        if "port" in self.configuration:
            return self.configuration["port"]
        else:
            raise NotImplementedError()

    @port.setter
    def port(self, port):
        """Sets this query runner's configured port.
        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.

        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this function.
        """
        if "port" in self.configuration:
            self.configuration["port"] = port
        else:
            raise NotImplementedError()

    @classmethod
    def configuration_schema(cls):
        """JSON-schema describing this runner's configuration (empty by default)."""
        return {}

    def annotate_query(self, query, metadata):
        """Prefix ``query`` with a ``/* key: value, ... */`` comment built from ``metadata``.

        Returns ``query`` unchanged when ``should_annotate_query`` is false.
        """
        if not self.should_annotate_query:
            return query

        annotation = ", ".join(["{}: {}".format(k, v) for k, v in metadata.items()])
        annotated_query = "/* {} */ {}".format(annotation, query)
        return annotated_query

    def test_connection(self):
        """Run ``noop_query`` and raise if it reports an error.

        Raises:
            NotImplementedError: if the runner defines no ``noop_query``.
            Exception: carrying the error returned by ``run_query``.
        """
        if self.noop_query is None:
            raise NotImplementedError()
        _data, error = self.run_query(self.noop_query, None)

        if error is not None:
            raise Exception(error)

    def run_query(self, query, user):
        """Execute ``query``; subclasses must return a ``(data, error)`` pair."""
        raise NotImplementedError()

    def fetch_columns(self, columns):
        """Build Redash column descriptors from ``(name, type)`` pairs.

        Duplicate names get a numeric suffix. A single counter is shared by
        all columns of the result (so a second ``a`` becomes ``a1`` and a
        later duplicate ``b`` becomes ``b2``), matching the names earlier
        versions produced. If a suffixed name would collide with a column
        that is already present (e.g. ``["a", "a1", "a"]``), the counter keeps
        advancing until the name is unique, so the result never contains two
        columns with the same name.

        Returns:
            A list of ``{"name", "friendly_name", "type"}`` dicts, in input order.
        """
        seen_names = set()
        duplicates_counter = 1
        new_columns = []

        for col in columns:
            column_name = col[0]
            if column_name in seen_names:
                candidate = "{}{}".format(column_name, duplicates_counter)
                duplicates_counter += 1
                # A suffixed name may itself already exist; keep counting.
                while candidate in seen_names:
                    candidate = "{}{}".format(column_name, duplicates_counter)
                    duplicates_counter += 1
                column_name = candidate

            seen_names.add(column_name)
            new_columns.append(
                {"name": column_name, "friendly_name": column_name, "type": col[1]}
            )

        return new_columns

    def get_schema(self, get_stats=False):
        """Return the data source schema; unsupported by default."""
        raise NotSupported()

    def _run_query_internal(self, query):
        """Run ``query`` and return the decoded ``rows`` list, raising on error."""
        results, error = self.run_query(query, None)

        if error is not None:
            raise Exception("Failed running query [%s]." % query)
        return json_loads(results)["rows"]

    @classmethod
    def to_dict(cls):
        """Serializable description of this runner type for the API."""
        return {
            "name": cls.name(),
            "type": cls.type(),
            "configuration_schema": cls.configuration_schema(),
            **({"deprecated": True} if cls.deprecated else {}),
        }

    @property
    def supports_auto_limit(self):
        """Whether ``apply_auto_limit`` can rewrite queries for this runner."""
        return False

    def apply_auto_limit(self, query_text, should_apply_auto_limit):
        """Return ``query_text`` unchanged; SQL runners override this."""
        return query_text

    def gen_query_hash(self, query_text, set_auto_limit=False):
        """Hash the query text as it will actually run (after any auto-limit)."""
        query_text = self.apply_auto_limit(query_text, set_auto_limit)
        return utils.gen_query_hash(query_text)


class BaseSQLQueryRunner(BaseQueryRunner):
    """Base class for SQL data sources with table-listing and auto-limit support."""

    def get_schema(self, get_stats=False):
        """Collect tables via ``_get_tables`` and optionally their row counts."""
        schema_dict = {}
        self._get_tables(schema_dict)
        if settings.SCHEMA_RUN_TABLE_SIZE_CALCULATIONS and get_stats:
            self._get_tables_stats(schema_dict)
        return list(schema_dict.values())

    def _get_tables(self, schema_dict):
        """Populate ``schema_dict`` with tables; subclasses override."""
        return []

    def _get_tables_stats(self, tables_dict):
        """Store a ``size`` (row count) on every dict-valued table entry."""
        for table_name, table in tables_dict.items():
            if isinstance(table, dict):
                res = self._run_query_internal(
                    "select count(*) as cnt from %s" % table_name
                )
                table["size"] = res[0]["cnt"]

    @property
    def supports_auto_limit(self):
        return True

    def apply_auto_limit(self, query_text, should_apply_auto_limit):
        """Append a LIMIT to the final statement when it is an unlimited SELECT."""
        if should_apply_auto_limit:
            from redash.query_runner.databricks import (
                split_sql_statements,
                combine_sql_statements,
            )

            queries = split_sql_statements(query_text)
            # we only check for last one in the list because it is the one that we show result
            last_query = queries[-1]
            if query_is_select_no_limit(last_query):
                queries[-1] = add_limit_to_query(last_query)
            return combine_sql_statements(queries)
        else:
            return query_text


class BaseHTTPQueryRunner(BaseQueryRunner):
    """Base class for runners that query an HTTP endpoint with optional Basic Auth."""

    should_annotate_query = False
    response_error = "Endpoint returned unexpected status code"
    requires_authentication = False
    requires_url = True
    url_title = "URL base path"
    username_title = "HTTP Basic Auth Username"
    password_title = "HTTP Basic Auth Password"

    @classmethod
    def configuration_schema(cls):
        """Schema with ``url``/``username``/``password``; required fields follow the class flags."""
        schema = {
            "type": "object",
            "properties": {
                "url": {"type": "string", "title": cls.url_title},
                "username": {"type": "string", "title": cls.username_title},
                "password": {"type": "string", "title": cls.password_title},
            },
            "secret": ["password"],
            "order": ["url", "username", "password"],
        }

        if cls.requires_url or cls.requires_authentication:
            schema["required"] = []

        if cls.requires_url:
            schema["required"] += ["url"]

        if cls.requires_authentication:
            schema["required"] += ["username", "password"]
        return schema

    def get_auth(self):
        """Return a ``(username, password)`` tuple or ``None``.

        Raises:
            ValueError: when credentials are missing but ``requires_authentication`` is set.
        """
        username = self.configuration.get("username")
        password = self.configuration.get("password")
        if username and password:
            return (username, password)
        if self.requires_authentication:
            raise ValueError("Username and Password required")
        else:
            return None

    def get_response(self, url, auth=None, http_method="get", **kwargs):
        """Issue an HTTP request and return ``(response, error)``.

        ``error`` is ``None`` only for a 200 response; request failures are
        logged and reported through ``error`` instead of being raised.
        """
        # Get authentication values if not given
        if auth is None:
            auth = self.get_auth()

        # Then call requests to get the response from the given endpoint
        # URL optionally, with the additional requests parameters.
        error = None
        response = None
        try:
            response = requests_session.request(http_method, url, auth=auth, **kwargs)
            # Raise a requests HTTP exception with the appropriate reason
            # for 4xx and 5xx response status codes which is later caught
            # and passed back.
            response.raise_for_status()

            # Any other responses (e.g. 2xx and 3xx):
            if response.status_code != 200:
                error = "{} ({}).".format(self.response_error, response.status_code)

        except requests_or_advocate.HTTPError as exc:
            logger.exception(exc)
            error = "Failed to execute query. " "Return Code: {} Reason: {}".format(
                response.status_code, response.text
            )
        except UnacceptableAddressException as exc:
            logger.exception(exc)
            error = "Can't query private addresses."
        except requests_or_advocate.RequestException as exc:
            # Catch all other requests exceptions and return the error.
            logger.exception(exc)
            error = str(exc)

        # Return response and error.
        return response, error


query_runners = {}


def register(query_runner_class):
    """Add ``query_runner_class`` to the registry if it reports itself enabled."""
    if query_runner_class.enabled():
        logger.debug(
            "Registering %s (%s) query runner.",
            query_runner_class.name(),
            query_runner_class.type(),
        )
        query_runners[query_runner_class.type()] = query_runner_class
    else:
        logger.debug(
            "%s query runner enabled but not supported, not registering. Either disable or install missing "
            "dependencies.",
            query_runner_class.name(),
        )


def get_query_runner(query_runner_type, configuration):
    """Instantiate the registered runner for ``query_runner_type``, or return ``None``."""
    query_runner_class = query_runners.get(query_runner_type, None)
    if query_runner_class is None:
        return None

    return query_runner_class(configuration)


def get_configuration_schema_for_query_runner_type(query_runner_type):
    """Return the configuration schema of a registered runner type, or ``None``."""
    query_runner_class = query_runners.get(query_runner_type, None)
    if query_runner_class is None:
        return None

    return query_runner_class.configuration_schema()


def import_query_runners(query_runner_imports):
    """Import each module path so its runners can register themselves."""
    for runner_import in query_runner_imports:
        __import__(runner_import)


def guess_type(value):
    """Guess a Redash column type for ``value``.

    ``bool`` is checked before ``int`` because ``bool`` subclasses ``int``;
    anything that is not a Python number is classified by ``guess_type_from_string``.
    """
    if isinstance(value, bool):
        return TYPE_BOOLEAN
    elif isinstance(value, int):
        return TYPE_INTEGER
    elif isinstance(value, float):
        return TYPE_FLOAT

    return guess_type_from_string(value)


def guess_type_from_string(string_value):
    """Guess a Redash column type from a (usually textual) value.

    Tries integer, float, boolean literal and date/time parsing in that order,
    falling back to ``TYPE_STRING``. Values that the converters reject with
    ``TypeError`` (e.g. lists or dicts reaching here via ``guess_type``) are
    treated as strings instead of propagating the error.
    """
    if string_value == "" or string_value is None:
        return TYPE_STRING

    try:
        int(string_value)
        return TYPE_INTEGER
    except (ValueError, OverflowError, TypeError):
        pass

    try:
        float(string_value)
        return TYPE_FLOAT
    except (ValueError, OverflowError, TypeError):
        pass

    if str(string_value).lower() in ("true", "false"):
        return TYPE_BOOLEAN

    try:
        parser.parse(string_value)
        return TYPE_DATETIME
    except (ValueError, OverflowError, TypeError):
        pass

    return TYPE_STRING


def with_ssh_tunnel(query_runner, details):
    """Wrap ``query_runner.run_query`` so each call goes through an SSH tunnel.

    For the duration of a call, the runner's ``host``/``port`` are pointed at the
    tunnel's local endpoint and restored afterwards, even if the query fails.

    Args:
        query_runner: runner whose ``host``/``port`` properties are supported.
        details: mapping with ``ssh_host``, ``ssh_username`` and optional ``ssh_port`` (default 22).

    Returns:
        The same ``query_runner`` with its ``run_query`` replaced.
    """

    def tunnel(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            try:
                remote_host, remote_port = query_runner.host, query_runner.port
            except NotImplementedError:
                raise NotImplementedError(
                    "SSH tunneling is not implemented for this query runner yet."
                )

            stack = ExitStack()
            try:
                bastion_address = (details["ssh_host"], details.get("ssh_port", 22))
                remote_address = (remote_host, remote_port)
                auth = {
                    "ssh_username": details["ssh_username"],
                    **settings.dynamic_settings.ssh_tunnel_auth(),
                }
                server = stack.enter_context(
                    open_tunnel(
                        bastion_address, remote_bind_address=remote_address, **auth
                    )
                )
            except Exception as error:
                # Keep the original exception type for callers, but chain the
                # cause so the real traceback is not lost.
                raise type(error)("SSH tunnel: {}".format(str(error))) from error

            with stack:
                try:
                    query_runner.host, query_runner.port = server.local_bind_address
                    result = f(*args, **kwargs)
                finally:
                    # Always restore the real endpoint, even if the query failed.
                    query_runner.host, query_runner.port = remote_host, remote_port

                return result

        return wrapper

    query_runner.run_query = tunnel(query_runner.run_query)

    return query_runner