redash/authentication/saml_auth.py (189 lines of code) (raw):

"""SAML 2.0 single sign-on endpoints for Redash, built on pysaml2."""
import logging
import os
from urllib import parse

from flask import Blueprint, flash, redirect, request, url_for
from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, entity
from saml2.client import Saml2Client
from saml2.config import Config as Saml2Config
from saml2.mdstore import MetaDataExtern
from saml2.saml import NAMEID_FORMAT_TRANSIENT
from saml2.sigver import get_xmlsec_binary

from redash import settings
from redash.authentication import (
    create_and_login_user,
    get_next_path,
    logout_and_redirect_to_index,
)
from redash.authentication.org_resolving import current_org
from redash.handlers.base import org_scoped_rule
from redash.utils import mustache_render

logger = logging.getLogger("saml_auth")
blueprint = Blueprint("saml_auth", __name__)

# Mustache template for IdP metadata, rendered from the org's static settings
# (entity id, signing certificate, SSO URL) when no metadata URL is used.
inline_metadata_template = """<?xml version="1.0" encoding="UTF-8"?><md:EntityDescriptor entityID="{{entity_id}}" xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"><md:IDPSSODescriptor WantAuthnRequestsSigned="false" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"><md:KeyDescriptor use="signing"><ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#"><ds:X509Data><ds:X509Certificate>{{x509_cert}}</ds:X509Certificate></ds:X509Data></ds:KeyInfo></md:KeyDescriptor><md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="{{sso_url}}"/><md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="{{sso_url}}"/></md:IDPSSODescriptor></md:EntityDescriptor>"""


def get_saml_client(org, next_url):
    """Build a pysaml2 client configured from the organization's settings.

    Args:
        org: Organization whose ``auth_saml_*`` settings are read.
        next_url: Path carried as the ``next`` query argument of the
            assertion consumer service (callback) URL.

    Returns:
        A ``Saml2Client`` loaded with the service-provider configuration.
    """
    saml_type = org.get_setting("auth_saml_type")
    entity_id = org.get_setting("auth_saml_entity_id")
    sso_url = org.get_setting("auth_saml_sso_url")
    x509_cert = org.get_setting("auth_saml_x509_cert")
    metadata_url = org.get_setting("auth_saml_metadata_url")

    # The callback URL is always absolute; the scheme is forced only when an
    # override is configured.
    acs_url_kwargs = {"org_slug": org.slug, "next": next_url, "_external": True}
    if settings.SAML_SCHEME_OVERRIDE:
        acs_url_kwargs["_scheme"] = settings.SAML_SCHEME_OVERRIDE
    acs_url = url_for("saml_auth.idp_initiated", **acs_url_kwargs)

    saml_settings = {
        "metadata": {"remote": [{"url": metadata_url}]},
        "service": {
            "sp": {
                "endpoints": {
                    "assertion_consumer_service": [
                        (acs_url, BINDING_HTTP_REDIRECT),
                        (acs_url, BINDING_HTTP_POST),
                    ]
                },
                # Don't verify that the incoming requests originate from us via
                # the built-in cache for authn request ids in pysaml2
                "allow_unsolicited": True,
                # Don't sign authn requests, since signed requests only make
                # sense in a situation where you control both the SP and IdP
                "authn_requests_signed": False,
                "logout_requests_signed": True,
                "want_assertions_signed": True,
                "want_response_signed": False,
            }
        },
    }

    if settings.SAML_ENCRYPTION_ENABLED:
        saml_settings.update(
            {
                "xmlsec_binary": get_xmlsec_binary(),
                "encryption_keypairs": [
                    {
                        "key_file": settings.SAML_ENCRYPTION_PEM_PATH,
                        "cert_file": settings.SAML_ENCRYPTION_CERT_PATH,
                    }
                ],
            }
        )

    if saml_type == "static":
        # Static configuration: replace the remote metadata source with
        # metadata rendered inline from the org settings.
        metadata_inline = mustache_render(
            inline_metadata_template,
            entity_id=entity_id,
            x509_cert=x509_cert,
            sso_url=sso_url,
        )
        saml_settings["metadata"] = {"inline": [metadata_inline]}

    if entity_id:
        saml_settings["entityid"] = entity_id

    sp_config = Saml2Config()
    sp_config.load(saml_settings)
    sp_config.allow_unknown_attributes = True
    return Saml2Client(config=sp_config)


def _requested_next_path(org_slug):
    """Return the sanitized ``next`` path of the request, defaulting to the index."""
    index_url = url_for("redash.index", org_slug=org_slug)
    return get_next_path(request.args.get("next", index_url))


def _saml_login_failed(org_slug):
    """Flash the generic SAML failure message and redirect to the login page."""
    flash("SAML login failed. Please try again later.")
    return redirect(url_for("redash.login", org_slug=org_slug))


@blueprint.route(org_scoped_rule("/saml/callback"), methods=["POST"])
def idp_initiated(org_slug=None):
    """Handle the SAML response POSTed by the IdP and log the user in.

    Redirects to the index when SAML is disabled, to the login page (with a
    flashed message) when the response cannot be used, and to the sanitized
    ``next`` path after a successful login.
    """
    if not current_org.get_setting("auth_saml_enabled"):
        logger.error("SAML Login is not enabled")
        return redirect(url_for("redash.index", org_slug=org_slug))

    next_path = _requested_next_path(org_slug)
    saml_client = get_saml_client(current_org, next_url=next_path)
    saml_client_urls_upgrade(saml_client)

    try:
        authn_response = saml_client.parse_authn_request_response(
            request.form["SAMLResponse"], entity.BINDING_HTTP_POST
        )
    except Exception:
        logger.error("Failed to parse SAML response", exc_info=True)
        return _saml_login_failed(org_slug)

    # pysaml2 can hand back None instead of raising (e.g. for an empty
    # response); treat that as a failed login rather than crashing below.
    if authn_response is None:
        logger.error("Failed to parse SAML response: no response returned")
        return _saml_login_failed(org_slug)

    # Return value intentionally unused.
    # NOTE(review): presumably called so pysaml2 populates `ava` -- confirm.
    authn_response.get_identity()
    user_info = authn_response.get_subject()
    email = getattr(user_info, "text", None)
    if not email:
        logger.error("SAML response does not contain a subject identifier")
        return _saml_login_failed(org_slug)

    # `ava` may be empty or None; normalize once so every lookup below is safe.
    ava = authn_response.ava or {}

    try:
        name = "%s %s" % (ava["firstName"][0], ava["lastName"][0])
    except (KeyError, IndexError, TypeError):
        # No usable first/last name attributes: fall back to the local part
        # of the email address.
        name = email.split("@")[0]

    # Single-valued attributes are unwrapped; multi-valued ones are kept as-is.
    attributes = {k: (v[0] if len(v) == 1 else v) for k, v in ava.items()}

    # This is what as known as "Just In Time (JIT) provisioning".
    # What that means is that, if a user in a SAML assertion
    # isn't in the user store, we create that user first, then log them in
    user = create_and_login_user(current_org, name, email, attributes=attributes)
    if user is None:
        return logout_and_redirect_to_index()

    if "RedashGroups" in ava:
        user.update_group_assignments(ava.get("RedashGroups"))

    return redirect(next_path)


@blueprint.route(org_scoped_rule("/saml/login"))
def sp_initiated(org_slug=None):
    """Start a SAML login by redirecting the browser to the IdP.

    Redirects to the index when SAML is disabled and to the login page (with
    a flashed message) when pysaml2 yields no redirect location.
    """
    if not current_org.get_setting("auth_saml_enabled"):
        logger.error("SAML Login is not enabled")
        return redirect(url_for("redash.index", org_slug=org_slug))

    next_path = _requested_next_path(org_slug)
    saml_client = get_saml_client(current_org, next_url=next_path)

    nameid_format = current_org.get_setting("auth_saml_nameid_format")
    if not nameid_format:
        nameid_format = NAMEID_FORMAT_TRANSIENT

    saml_client_urls_upgrade(saml_client)
    _, info = saml_client.prepare_for_authenticate(nameid_format=nameid_format)

    # Select the IdP URL to send the AuthN request to (the last "Location"
    # header wins if there are several).
    redirect_url = None
    for key, value in info["headers"]:
        if key == "Location":
            redirect_url = value

    if redirect_url is None:
        # Without a Location header there is nowhere to send the browser;
        # fail the same way the callback does instead of calling redirect(None).
        logger.error("SAML authentication request has no Location header")
        return _saml_login_failed(org_slug)

    response = redirect(redirect_url, code=302)

    # NOTE:
    #   I realize I _technically_ don't need to set Cache-Control or Pragma:
    #     https://stackoverflow.com/a/5494469
    #   However, Section 3.2.3.2 of the SAML spec suggests they are set:
    #     http://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf
    #   We set those headers here as a "belt and suspenders" approach,
    #   since enterprise environments don't always conform to RFCs
    response.headers["Cache-Control"] = "no-cache, no-store"
    response.headers["Pragma"] = "no-cache"
    return response


def saml_client_urls_upgrade(saml_client):
    """Rewrite IdP URLs in the client's metadata to ``REDASH_SAML_REDIRECT_URL``.

    When the environment variable is set, the scheme and host of the first
    IdP entity id are replaced by those of the variable, both in dictionary
    keys and in ``location`` values of the loaded metadata. The metadata is
    left untouched when the variable is unset or not an absolute URL, when no
    IdP descriptor is loaded, or when the IdP entity id has no host part.

    Args:
        saml_client: The ``Saml2Client`` whose metadata is modified in place.
    """
    saml_redirect_url = os.getenv("REDASH_SAML_REDIRECT_URL")
    if not saml_redirect_url:
        return

    new_netloc = parse.urlparse(saml_redirect_url)
    if not (new_netloc.scheme and new_netloc.netloc):
        # A value without scheme/host would blank out those parts of every
        # rewritten URL.
        logger.warning(
            "REDASH_SAML_REDIRECT_URL %r is not an absolute URL; "
            "SAML metadata left unchanged",
            saml_redirect_url,
        )
        return

    eids = saml_client.metadata.with_descriptor("idpsso")
    if not eids:
        logger.warning("No IdP descriptor in SAML metadata; nothing to rewrite")
        return

    default_netloc = parse.urlparse(next(iter(eids)))
    if not default_netloc.netloc:
        # A non-URL entity id (e.g. a URN) has an empty host, which is a
        # substring of every key and would make the rewrite corrupt the
        # metadata instead of targeting the IdP host.
        logger.warning(
            "IdP entity id has no host component; SAML metadata left unchanged"
        )
        return

    keys_replacer(saml_client.metadata.metadata, default_netloc, new_netloc)
    location_replacer(saml_client.metadata.metadata, default_netloc, new_netloc)
    keys_replacer(eids, default_netloc, new_netloc)
    location_replacer(eids, default_netloc, new_netloc)


def keys_replacer(storage, default_netloc, new_netloc):
    """Recursively rename string keys that contain the default host.

    Args:
        storage: A dict, list or ``MetaDataExtern``; modified in place.
        default_netloc: Parsed URL whose ``netloc`` is searched for in keys.
        new_netloc: Parsed URL providing the replacement scheme and host.
    """
    renamed = {}
    stale_keys = []
    for k, v in m_enum(storage):
        if isinstance(k, str) and default_netloc.netloc in k:
            renamed[location_replace(k, default_netloc, new_netloc)] = v
            stale_keys.append(k)
        if isinstance(v, (dict, list, MetaDataExtern)):
            # Nested containers are updated in place, so nothing has to be
            # written back into `storage` here.
            keys_replacer(v, default_netloc, new_netloc)

    # Keys are swapped only after iteration so the container does not change
    # size while it is being walked.
    for k in stale_keys:
        del storage[k]
    for k, v in renamed.items():
        storage[k] = v


def location_replacer(storage, default_netloc, new_netloc):
    """Recursively rewrite every ``location`` value that points at the default host.

    Args:
        storage: A dict or list (or ``None``); modified in place.
        default_netloc: Parsed URL whose ``netloc`` identifies URLs to rewrite.
        new_netloc: Parsed URL providing the replacement scheme and host.

    Returns:
        The same ``storage`` object that was passed in.
    """
    if storage is None:
        return storage

    for k, v in m_enum(storage):
        if k == "location":
            storage[k] = location_replace(v, default_netloc, new_netloc)
        elif isinstance(v, (dict, list)):
            location_replacer(v, default_netloc, new_netloc)
    return storage


def location_replace(location_url, default_netloc, new_netloc):
    """Return ``location_url`` with its scheme and host swapped for the new ones.

    The URL is returned unchanged unless its host equals
    ``default_netloc.netloc``.
    """
    location_parsed = parse.urlparse(location_url)
    if location_parsed.netloc != default_netloc.netloc:
        return location_url
    return parse.urlunparse(
        location_parsed._replace(
            netloc=new_netloc.netloc, scheme=new_netloc.scheme
        )
    )


def m_enum(d):
    """Return ``(key, value)`` pairs for a mapping-like or ``(index, item)`` for a list.

    Any other object is returned unchanged.
    """
    if isinstance(d, (dict, MetaDataExtern)):
        return d.items()
    if isinstance(d, list):
        return enumerate(d)
    return d