# redash/query_runner/pg.py
import logging
import os
import select
from base64 import b64decode
from tempfile import NamedTemporaryFile
from uuid import uuid4

import psycopg2
from psycopg2.extras import Range

from redash.query_runner import *
from redash.utils import JSONEncoder, json_dumps, json_loads

logger = logging.getLogger(__name__)

# boto3 is only needed for the IAM-based Redshift runner; if it is missing,
# that runner is simply disabled.
try:
    import boto3

    IAM_ENABLED = True
except ImportError:
    IAM_ENABLED = False

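# Map PostgreSQL type OIDs, as reported in cursor.description, to Redash column
# types; for example, int8 columns report OID 20 and become TYPE_INTEGER. OIDs
# missing from this map are passed through without an explicit type.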
types_map = {
20: TYPE_INTEGER,
21: TYPE_INTEGER,
23: TYPE_INTEGER,
700: TYPE_FLOAT,
1700: TYPE_FLOAT,
701: TYPE_FLOAT,
16: TYPE_BOOLEAN,
1082: TYPE_DATE,
1114: TYPE_DATETIME,
1184: TYPE_DATETIME,
1014: TYPE_STRING,
1015: TYPE_STRING,
1008: TYPE_STRING,
1009: TYPE_STRING,
2951: TYPE_STRING,
}


class PostgreSQLJSONEncoder(JSONEncoder):
def default(self, o):
if isinstance(o, Range):
# From: https://github.com/psycopg/psycopg2/pull/779
if o._bounds is None:
return ""
items = [o._bounds[0], str(o._lower), ", ", str(o._upper), o._bounds[1]]
return "".join(items)
return super(PostgreSQLJSONEncoder, self).default(o)
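

# Not part of the runner, just an illustration: psycopg2 returns PostgreSQL
# range columns as Range objects, which plain JSON encoding rejects. With the
# encoder above, a psycopg2.extras.NumericRange(2, 7, "[)") cell serializes to
# the string "[2, 7)".
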
def _wait(conn, timeout=None):
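    """Wait for an asynchronous psycopg2 connection to finish its current
    command, blocking in select() on the connection's file descriptor."""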
    while True:
try:
state = conn.poll()
if state == psycopg2.extensions.POLL_OK:
break
elif state == psycopg2.extensions.POLL_WRITE:
select.select([], [conn.fileno()], [], timeout)
elif state == psycopg2.extensions.POLL_READ:
select.select([conn.fileno()], [], [], timeout)
else:
raise psycopg2.OperationalError("poll() returned %s" % state)
except select.error:
raise psycopg2.OperationalError("select.error received")


def full_table_name(schema, name):
if "." in name:
name = '"{}"'.format(name)
return "{}.{}".format(schema, name)
def build_schema(query_result, schema):
# By default we omit the public schema name from the table name. But there are
# edge cases, where this might cause conflicts. For example:
# * We have a schema named "main" with table "users".
# * We have a table named "main.users" in the public schema.
# (while this feels unlikely, this actually happened)
# In this case if we omit the schema name for the public table, we will have
# a conflict.
    table_names = {
        full_table_name(row["table_schema"], row["table_name"])
        for row in query_result["rows"]
    }
for row in query_result["rows"]:
if row["table_schema"] != "public":
table_name = full_table_name(row["table_schema"], row["table_name"])
else:
if row["table_name"] in table_names:
table_name = full_table_name(row["table_schema"], row["table_name"])
else:
table_name = row["table_name"]
if table_name not in schema:
schema[table_name] = {"name": table_name, "columns": []}
column = row["column_name"]
if row.get("data_type") is not None:
column = {"name": row["column_name"], "type": row["data_type"]}
schema[table_name]["columns"].append(column)
def _create_cert_file(configuration, key, ssl_config):
file_key = key + "File"
if file_key in configuration:
with NamedTemporaryFile(mode="w", delete=False) as cert_file:
cert_bytes = b64decode(configuration[file_key])
cert_file.write(cert_bytes.decode("utf-8"))
ssl_config[key] = cert_file.name


def _cleanup_ssl_certs(ssl_config):
for k, v in ssl_config.items():
if k != "sslmode":
os.remove(v)


def _get_ssl_config(configuration):
ssl_config = {"sslmode": configuration.get("sslmode", "prefer")}
_create_cert_file(configuration, "sslrootcert", ssl_config)
_create_cert_file(configuration, "sslcert", ssl_config)
_create_cert_file(configuration, "sslkey", ssl_config)
return ssl_config
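

# The returned dict is passed straight to psycopg2.connect() as keyword
# arguments; for example, with only a root certificate uploaded it looks like
# {"sslmode": "verify-ca", "sslrootcert": "/tmp/tmpXXXXXXXX"}, where the
# temp file holds the decoded certificate.
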
class PostgreSQL(BaseSQLQueryRunner):
noop_query = "SELECT 1"

    @classmethod
def configuration_schema(cls):
return {
"type": "object",
"properties": {
"user": {"type": "string"},
"password": {"type": "string"},
"host": {"type": "string", "default": "127.0.0.1"},
"port": {"type": "number", "default": 5432},
"dbname": {"type": "string", "title": "Database Name"},
"sslmode": {
"type": "string",
"title": "SSL Mode",
"default": "prefer",
"extendedEnum": [
{"value": "disable", "name": "Disable"},
{"value": "allow", "name": "Allow"},
{"value": "prefer", "name": "Prefer"},
{"value": "require", "name": "Require"},
{"value": "verify-ca", "name": "Verify CA"},
{"value": "verify-full", "name": "Verify Full"},
],
},
"sslrootcertFile": {"type": "string", "title": "SSL Root Certificate"},
"sslcertFile": {"type": "string", "title": "SSL Client Certificate"},
"sslkeyFile": {"type": "string", "title": "SSL Client Key"},
},
"order": ["host", "port", "user", "password"],
"required": ["dbname"],
"secret": ["password", "sslrootcertFile", "sslcertFile", "sslkeyFile"],
"extra_options": [
"sslmode",
"sslrootcertFile",
"sslcertFile",
"sslkeyFile",
],
}

    @classmethod
def type(cls):
return "pg"

    def _get_definitions(self, schema, query):
        results, error = self.run_query(query, None)
        if error is not None:
            raise Exception("Failed getting schema: {}".format(error))

        results = json_loads(results)
        build_schema(results, schema)

    def _get_tables(self, schema):
        """
        relkind constants per https://www.postgresql.org/docs/10/static/catalog-pg-class.html
        r = regular table
        v = view
        m = materialized view
        f = foreign table
        p = partitioned table (new in 10)
        ---
        i = index
        S = sequence
        t = TOAST table
        c = composite type

        Materialized views in particular are missing from
        information_schema.columns, so the query below reads the 'm', 'f'
        and 'p' relkinds from pg_class directly and merges them with
        information_schema.
        """
query = """
SELECT s.nspname as table_schema,
c.relname as table_name,
a.attname as column_name,
null as data_type
FROM pg_class c
JOIN pg_namespace s
ON c.relnamespace = s.oid
AND s.nspname NOT IN ('pg_catalog', 'information_schema')
JOIN pg_attribute a
ON a.attrelid = c.oid
AND a.attnum > 0
AND NOT a.attisdropped
WHERE c.relkind IN ('m', 'f', 'p')
UNION
SELECT table_schema,
table_name,
column_name,
data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
"""
self._get_definitions(schema, query)
return list(schema.values())

    def _get_connection(self):
self.ssl_config = _get_ssl_config(self.configuration)
connection = psycopg2.connect(
user=self.configuration.get("user"),
password=self.configuration.get("password"),
host=self.configuration.get("host"),
port=self.configuration.get("port"),
dbname=self.configuration.get("dbname"),
async_=True,
**self.ssl_config,
)
return connection

    def run_query(self, query, user):
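        """Execute the query on an asynchronous connection so that waiting
        happens in select() and the job can be cancelled mid-flight."""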
connection = self._get_connection()
_wait(connection, timeout=10)
cursor = connection.cursor()
try:
cursor.execute(query)
_wait(connection)
if cursor.description is not None:
columns = self.fetch_columns(
[(i[0], types_map.get(i[1], None)) for i in cursor.description]
)
rows = [
dict(zip((column["name"] for column in columns), row))
for row in cursor
]
data = {"columns": columns, "rows": rows}
error = None
json_data = json_dumps(data, ignore_nan=True, cls=PostgreSQLJSONEncoder)
else:
error = "Query completed but it returned no data."
json_data = None
        except (select.error, OSError):
            error = "Query interrupted. Please retry."
            json_data = None
except psycopg2.DatabaseError as e:
error = str(e)
json_data = None
except (KeyboardInterrupt, InterruptException, JobTimeoutException):
connection.cancel()
raise
finally:
connection.close()
_cleanup_ssl_certs(self.ssl_config)
return json_data, error
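

# A minimal usage sketch (hypothetical configuration values, not part of this
# module):
#
#     runner = PostgreSQL({"host": "127.0.0.1", "port": 5432,
#                          "dbname": "postgres", "user": "redash"})
#     json_data, error = runner.run_query("SELECT 1", None)
#
# On success, error is None and json_data is a JSON string with "columns" and
# "rows" keys.
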
class Redshift(PostgreSQL):
@classmethod
def type(cls):
return "redshift"

    @classmethod
def name(cls):
return "Redshift"

    def _get_connection(self):
        # No client certificate temp files are created here, but run_query()
        # still expects self.ssl_config to exist for cleanup.
        self.ssl_config = {}
        sslrootcert_path = os.path.join(
            os.path.dirname(__file__), "files/redshift-ca-bundle.crt"
        )
connection = psycopg2.connect(
user=self.configuration.get("user"),
password=self.configuration.get("password"),
host=self.configuration.get("host"),
port=self.configuration.get("port"),
dbname=self.configuration.get("dbname"),
sslmode=self.configuration.get("sslmode", "prefer"),
sslrootcert=sslrootcert_path,
async_=True,
)
return connection

    @classmethod
def configuration_schema(cls):
return {
"type": "object",
"properties": {
"user": {"type": "string"},
"password": {"type": "string"},
"host": {"type": "string"},
"port": {"type": "number"},
"dbname": {"type": "string", "title": "Database Name"},
"sslmode": {"type": "string", "title": "SSL Mode", "default": "prefer"},
"adhoc_query_group": {
"type": "string",
"title": "Query Group for Adhoc Queries",
"default": "default",
},
"scheduled_query_group": {
"type": "string",
"title": "Query Group for Scheduled Queries",
"default": "default",
},
},
"order": [
"host",
"port",
"user",
"password",
"dbname",
"sslmode",
"adhoc_query_group",
"scheduled_query_group",
],
"required": ["dbname", "user", "password", "host", "port"],
"secret": ["password"],
}

    def annotate_query(self, query, metadata):
annotated = super(Redshift, self).annotate_query(query, metadata)
if metadata.get("Scheduled", False):
query_group = self.configuration.get("scheduled_query_group")
else:
query_group = self.configuration.get("adhoc_query_group")
if query_group:
set_query_group = "set query_group to {};".format(query_group)
annotated = "{}\n{}".format(set_query_group, annotated)
return annotated
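
    # For example, a scheduled query with scheduled_query_group set to "etl"
    # is rewritten to:
    #
    #     set query_group to etl;
    #     <annotated query>
    #
    # which lets Redshift WLM route it to the matching queue.
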
def _get_tables(self, schema):
# Use svv_columns to include internal & external (Spectrum) tables and views data for Redshift
# https://docs.aws.amazon.com/redshift/latest/dg/r_SVV_COLUMNS.html
# Use HAS_SCHEMA_PRIVILEGE(), SVV_EXTERNAL_SCHEMAS and HAS_TABLE_PRIVILEGE() to filter
# out tables the current user cannot access.
# https://docs.aws.amazon.com/redshift/latest/dg/r_HAS_SCHEMA_PRIVILEGE.html
# https://docs.aws.amazon.com/redshift/latest/dg/r_SVV_EXTERNAL_SCHEMAS.html
# https://docs.aws.amazon.com/redshift/latest/dg/r_HAS_TABLE_PRIVILEGE.html
query = """
WITH tables AS (
SELECT DISTINCT table_name,
table_schema,
column_name,
ordinal_position AS pos
FROM svv_columns
WHERE table_schema NOT IN ('pg_internal','pg_catalog','information_schema')
AND table_schema NOT LIKE 'pg_temp_%'
)
SELECT table_name, table_schema, column_name
FROM tables
WHERE
HAS_SCHEMA_PRIVILEGE(table_schema, 'USAGE') AND
(
table_schema IN (SELECT schemaname FROM SVV_EXTERNAL_SCHEMAS) OR
HAS_TABLE_PRIVILEGE('"' || table_schema || '"."' || table_name || '"', 'SELECT')
)
ORDER BY table_name, pos
"""
self._get_definitions(schema, query)
return list(schema.values())


class RedshiftIAM(Redshift):
@classmethod
def type(cls):
return "redshift_iam"

    @classmethod
def name(cls):
return "Redshift (with IAM User/Role)"

    @classmethod
def enabled(cls):
return IAM_ENABLED

    def _login_method_selection(self):
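        """Pick how cluster credentials will be obtained: assume an IAM role
        (with or without explicit access keys), use static access keys, or
        fall back to the ambient boto3 credentials ("ROLE")."""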
if self.configuration.get("rolename"):
if not self.configuration.get(
"aws_access_key_id"
) or not self.configuration.get("aws_secret_access_key"):
return "ASSUME_ROLE_NO_KEYS"
else:
return "ASSUME_ROLE_KEYS"
elif self.configuration.get("aws_access_key_id") and self.configuration.get(
"aws_secret_access_key"
):
return "KEYS"
elif not self.configuration.get("password"):
return "ROLE"

    @classmethod
def configuration_schema(cls):
return {
"type": "object",
"properties": {
"rolename": {"type": "string", "title": "IAM Role Name"},
"aws_region": {"type": "string", "title": "AWS Region"},
"aws_access_key_id": {"type": "string", "title": "AWS Access Key ID"},
"aws_secret_access_key": {
"type": "string",
"title": "AWS Secret Access Key",
},
"clusterid": {"type": "string", "title": "Redshift Cluster ID"},
"user": {"type": "string"},
"host": {"type": "string"},
"port": {"type": "number"},
"dbname": {"type": "string", "title": "Database Name"},
"sslmode": {"type": "string", "title": "SSL Mode", "default": "prefer"},
"adhoc_query_group": {
"type": "string",
"title": "Query Group for Adhoc Queries",
"default": "default",
},
"scheduled_query_group": {
"type": "string",
"title": "Query Group for Scheduled Queries",
"default": "default",
},
},
"order": [
"rolename",
"aws_region",
"aws_access_key_id",
"aws_secret_access_key",
"clusterid",
"host",
"port",
"user",
"dbname",
"sslmode",
"adhoc_query_group",
"scheduled_query_group",
],
"required": ["dbname", "user", "host", "port", "aws_region"],
"secret": ["aws_secret_access_key"],
}

    def _get_connection(self):
        sslrootcert_path = os.path.join(
            os.path.dirname(__file__), "files/redshift-ca-bundle.crt"
        )
login_method = self._login_method_selection()
if login_method == "KEYS":
client = boto3.client(
"redshift",
region_name=self.configuration.get("aws_region"),
aws_access_key_id=self.configuration.get("aws_access_key_id"),
aws_secret_access_key=self.configuration.get("aws_secret_access_key"),
)
elif login_method == "ROLE":
client = boto3.client(
"redshift", region_name=self.configuration.get("aws_region")
)
else:
            if login_method == "ASSUME_ROLE_KEYS":
                assume_client = boto3.client(
                    "sts",
                    region_name=self.configuration.get("aws_region"),
                    aws_access_key_id=self.configuration.get("aws_access_key_id"),
                    aws_secret_access_key=self.configuration.get(
                        "aws_secret_access_key"
                    ),
                )
            else:
                assume_client = boto3.client(
                    "sts", region_name=self.configuration.get("aws_region")
                )
role_session = f"redash_{uuid4().hex}"
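            # Note: despite the "IAM Role Name" title in the configuration
            # schema, assume_role expects the role's full ARN here.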
session_keys = assume_client.assume_role(
RoleArn=self.configuration.get("rolename"), RoleSessionName=role_session
)["Credentials"]
client = boto3.client(
"redshift",
region_name=self.configuration.get("aws_region"),
aws_access_key_id=session_keys["AccessKeyId"],
aws_secret_access_key=session_keys["SecretAccessKey"],
aws_session_token=session_keys["SessionToken"],
)
credentials = client.get_cluster_credentials(
DbUser=self.configuration.get("user"),
DbName=self.configuration.get("dbname"),
ClusterIdentifier=self.configuration.get("clusterid"),
)
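        # get_cluster_credentials returns short-lived DbUser/DbPassword
        # credentials minted by the Redshift API for this cluster.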
db_user = credentials["DbUser"]
db_password = credentials["DbPassword"]
connection = psycopg2.connect(
user=db_user,
password=db_password,
host=self.configuration.get("host"),
port=self.configuration.get("port"),
dbname=self.configuration.get("dbname"),
sslmode=self.configuration.get("sslmode", "prefer"),
sslrootcert=sslrootcert_path,
async_=True,
)
return connection


class CockroachDB(PostgreSQL):
@classmethod
def type(cls):
return "cockroach"


register(PostgreSQL)
register(Redshift)
register(RedshiftIAM)
register(CockroachDB)