kong/kong/db/strategies/postgres/connector.lua (784 lines of code) (raw):

-- PostgreSQL strategy connector.
--
-- Wraps pgmoon connections and adds: optional routing of reads to a
-- read-only replica (config_ro), per-operation ("read"/"write") query
-- semaphores, a worker-0 timer purging expired ttl rows, a `locks` table
-- used for cluster locks, and `schema_meta` migration bookkeeping.

local logger       = require "kong.cmd.utils.log"
local utils        = require "kong.tools.utils"
local pgmoon       = require "pgmoon"
local arrays       = require "pgmoon.arrays"
local stringx      = require "pl.stringx"
local semaphore    = require "ngx.semaphore"
local kong_global  = require "kong.global"
local constants    = require "kong.constants"


local setmetatable   = setmetatable
local encode_array   = arrays.encode_array
local tonumber       = tonumber
local tostring       = tostring
local concat         = table.concat
local ipairs         = ipairs
local pairs          = pairs
local error          = error
local floor          = math.floor
local type           = type
local ngx            = ngx
local timer_every    = ngx.timer.every
local update_time    = ngx.update_time
local get_phase      = ngx.get_phase
local null           = ngx.null
local now            = ngx.now
local log            = ngx.log
local match          = string.match
local fmt            = string.format
local sub            = string.sub
local utils_toposort = utils.topological_sort
local insert         = table.insert


local WARN                          = ngx.WARN
local ERR                           = ngx.ERR
local SQL_INFORMATION_SCHEMA_TABLES = [[ SELECT table_name FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA; ]]


-- Tables that `_mt:truncate()` must never empty.
local PROTECTED_TABLES = {
  schema_migrations = true,
  schema_meta       = true,
  locks             = true,
  parameters        = true,
}


-- Every operation kind a connection can be stored under.
local OPERATIONS = {
  read  = true,
  write = true,
}


local ADMIN_API_PHASE = kong_global.phases.admin_api
local CORE_ENTITIES   = constants.CORE_ENTITIES


-- Refresh nginx's cached clock (ngx.update_time) and return ngx.now().
local function now_updated()
  update_time()
  return now()
end


-- Return a stateful iterator yielding rows[1], rows[2], ... and then nil.
local function iterator(rows)
  local i = 0
  return function()
    i = i + 1
    return rows[i]
  end
end


-- List the tables of the current schema as escaped identifiers, skipping
-- every name present as a key in the optional `excluded` set.
-- Returns the array of escaped names, or nil and an error message.
local function get_table_names(self, excluded)
  local i = 0
  local table_names = {}
  for row, err in self:iterate(SQL_INFORMATION_SCHEMA_TABLES) do
    if err then
      return nil, err
    end

    if not excluded or not excluded[row.table_name] then
      i = i + 1
      table_names[i] = self:escape_identifier(row.table_name)
    end
  end

  return table_names
end


-- get_names_of_tables_with_ttl(strategies)
-- Collect the names of all schemas with `ttl` enabled, order them with
-- utils.topological_sort using their foreign keys to other ttl schemas as
-- edges, and prepend "cluster_events".
-- Returns the ordered array, or nil and the topological sort error.
local get_names_of_tables_with_ttl
do
  -- Sort weights for the toposort input: lower scores come first, so
  -- core entities come after plugin tables and "workspaces" comes last.
  local CORE_SCORE = {}
  for _, v in ipairs(CORE_ENTITIES) do
    CORE_SCORE[v] = 1
  end
  CORE_SCORE["workspaces"] = 2

  local function sort_core_tables_first(a, b)
    local sa = CORE_SCORE[a] or 0
    local sb = CORE_SCORE[b] or 0
    if sa == sb then
      -- sort tables in reverse order so that they end up sorted alphabetically,
      -- because utils_topological sort does "dependencies first" and then current.
      return a > b
    end
    return sa < sb
  end

  local sort = table.sort

  get_names_of_tables_with_ttl = function(strategies)
    local s
    local ttl_schemas_by_name = {}
    local table_names = {}
    for _, strategy in pairs(strategies) do
      s = strategy.schema
      if s.ttl then
        table_names[#table_names + 1] = s.name
        ttl_schemas_by_name[s.name] = s
      end
    end

    -- pairs() order is unspecified; sorting makes the toposort input
    -- deterministic.
    sort(table_names, sort_core_tables_first)

    -- Edges of the graph: ttl schemas referenced by `table_name`'s foreign
    -- fields.
    local get_table_name_neighbors = function(table_name)
      local neighbors = {}
      local neighbors_len = 0
      local neighbor
      local schema = ttl_schemas_by_name[table_name]

      for _, field in schema:each_field() do
        if field.type == "foreign" and field.schema.ttl then
          neighbor = field.reference
          if ttl_schemas_by_name[neighbor] then -- the neighbor schema name is on table_names
            neighbors_len = neighbors_len + 1
            neighbors[neighbors_len] = neighbor
          end
          -- else the neighbor points to an unknown/uninteresting schema. This happens in tests.
        end
      end

      return neighbors
    end

    local res, err = utils_toposort(table_names, get_table_name_neighbors)

    if res then
      insert(res, 1, "cluster_events")
    end

    return res, err
  end
end


-- Drop and recreate the configured schema in one transaction. When the
-- current user lacks the privilege (insufficient_privilege), fall back to
-- dropping every table found in the current schema.
-- Returns true, or nil and an error message.
local function reset_schema(self)
  local table_names, err = get_table_names(self)
  if not table_names then
    return nil, err
  end

  local drop_tables
  if #table_names == 0 then
    drop_tables = ""
  else
    drop_tables = concat {
      " DROP TABLE IF EXISTS ", concat(table_names, ", "), " CASCADE;\n"
    }
  end

  local schema = self:escape_identifier(self.config.schema)
  local ok
  ok, err = self:query(concat {
    "BEGIN;\n",
    " DO $$\n",
    " BEGIN\n",
    " DROP SCHEMA IF EXISTS ", schema, " CASCADE;\n",
    " CREATE SCHEMA IF NOT EXISTS ", schema, " AUTHORIZATION CURRENT_USER;\n",
    " GRANT ALL ON SCHEMA ", schema ," TO CURRENT_USER;\n",
    " EXCEPTION WHEN insufficient_privilege THEN\n", drop_tables,
    " END;\n",
    " $$;\n",
    " SET SCHEMA ", self:escape_literal(self.config.schema), ";\n",
    "COMMIT;",
  })

  if not ok then
    return nil, err
  end

  return true
end


-- Release a connection. Connections using luasocket are closed; nginx
-- cosocket connections are handed to pgmoon's keepalive().
-- Returns true (also for nil or already-closed connections), or nil and an
-- error message.
local function setkeepalive(connection)
  if not connection or not connection.sock then
    return true
  end

  if connection.sock_type == "luasocket" then
    local _, err = connection:disconnect()
    if err then
      return nil, err
    end

  else
    local _, err = connection:keepalive()
    if err then
      return nil, err
    end
  end

  return true
end


-- Open a pgmoon connection for `config`. On a socket that was not reused
-- (getreusedtimes() == 0), resolve config.schema when it is "" (falling
-- back to "public") and run SET SCHEMA / SET TIME ZONE 'UTC'.
-- Side effects: writes config.socket_type, and possibly config.schema.
-- Returns the connection, or nil and an error message.
local function reconnect(config)
  local phase = get_phase()
  if phase == "init" or phase == "init_worker" or ngx.IS_CLI then
    -- Force LuaSocket usage in the CLI in order to allow for self-signed
    -- certificates to be trusted (via opts.cafile) in the resty-cli
    -- interpreter (no way to set lua_ssl_trusted_certificate).
    config.socket_type = "luasocket"

  else
    config.socket_type = "nginx"
  end

  local connection = pgmoon.new(config)

  connection.convert_null = true
  connection.NULL         = null

  if config.timeout then
    connection:settimeout(config.timeout)
  end

  local ok, err = connection:connect()
  if not ok then
    return nil, err
  end

  -- Session setup only happens on fresh sockets; reused sockets are
  -- assumed to have been set up when first opened.
  if connection.sock:getreusedtimes() == 0 then
    if config.schema == "" then
      local res = connection:query("SELECT CURRENT_SCHEMA AS schema")
      if res and res[1] and res[1].schema and res[1].schema ~= null then
        config.schema = res[1].schema
      else
        config.schema = "public"
      end
    end

    ok, err = connection:query(concat {
      "SET SCHEMA ",    connection:escape_literal(config.schema), ";\n",
      "SET TIME ZONE ", connection:escape_literal("UTC"), ";",
    })
    if not ok then
      -- Close rather than pool: a pooled socket would skip the setup above
      -- on reuse and run with the wrong schema/time zone.
      connection:disconnect()
      return nil, err
    end
  end

  return connection
end


-- Connect through kong.vault.try(reconnect, config). NOTE(review):
-- presumably this retries with refreshed vault-referenced credentials
-- (see config["$refs"] set in _M.new) — confirm against kong.vault.
local function connect(config)
  return kong.vault.try(reconnect, config)
end


local _mt = {
  reset = reset_schema
}


_mt.__index = _mt


-- Return the connection stored for `operation` by the parent connector
-- (self.super), but only while it still has an open socket.
function _mt:get_stored_connection(operation)
  local conn = self.super.get_stored_connection(self, operation)
  if conn and conn.sock then
    return conn
  end
end


-- Read server_version_num and record self.major_version and
-- self.major_minor_version: e.g. 90612 -> 9.6 / "9.6.12", and
-- 120005 -> 12 / "12.5".
-- Returns true, or nil and an error message.
function _mt:init()
  local res, err = self:query("SHOW server_version_num;")
  local ver = tonumber(res and res[1] and res[1].server_version_num)
  if not ver then
    -- err is nil when the query succeeded but yielded no usable value;
    -- concatenating it directly would raise instead of returning.
    return nil, "failed to retrieve PostgreSQL server_version_num: " ..
                tostring(err or "unexpected result")
  end

  local major = floor(ver / 10000)
  if major < 10 then
    self.major_version       = tonumber(fmt("%u.%u", major, floor(ver / 100 % 100)))
    self.major_minor_version = fmt("%u.%u.%u", major, floor(ver / 100 % 100), ver % 100)

  else
    self.major_version       = major
    self.major_minor_version = fmt("%u.%u", major, ver % 100)
  end

  return true
end


-- On worker 0, start a 60-second ngx.timer.every timer deleting expired rows
-- (column `ttl`, or `expire_at` for cluster_events) from every ttl table.
-- The statements are sent as one batch; if the batch fails and
-- num_queries is known, statements num_queries + 1 .. n are re-run one by
-- one.
-- Returns timer_every's result on worker 0, true elsewhere, or nil and an
-- error when the cleanup order cannot be computed.
function _mt:init_worker(strategies)
  if ngx.worker.id() == 0 then
    local table_names, sort_err = get_names_of_tables_with_ttl(strategies)
    if not table_names then
      return nil, "unable to determine tables with ttl: " .. tostring(sort_err)
    end

    local ttl_escaped = self:escape_identifier("ttl")
    local expire_at_escaped = self:escape_identifier("expire_at")
    local cleanup_statements = {}
    local cleanup_statements_count = #table_names
    for i = 1, cleanup_statements_count do
      local table_name = table_names[i]
      local column_name = table_name == "cluster_events" and expire_at_escaped
                                                         or  ttl_escaped
      cleanup_statements[i] = concat {
        " DELETE FROM ",
        self:escape_identifier(table_name),
        " WHERE ",
        column_name,
        " < CURRENT_TIMESTAMP AT TIME ZONE 'UTC';"
      }
    end

    local cleanup_statement = concat(cleanup_statements, "\n")

    return timer_every(60, function(premature)
      if premature then
        return
      end

      local ok, err, _, num_queries = self:query(cleanup_statement)
      if not ok then
        if num_queries then
          for i = num_queries + 1, cleanup_statements_count do
            local statement = cleanup_statements[i]
            local ok, err = self:query(statement)
            if not ok then
              if err then
                log(WARN, "unable to clean expired rows from table '",
                          table_names[i], "' on PostgreSQL database (",
                          err, ")")
              else
                log(WARN, "unable to clean expired rows from table '",
                          table_names[i], "' on PostgreSQL database")
              end
            end
          end

        else
          log(ERR, "unable to clean expired rows from PostgreSQL database (",
                   err, ")")
        end
      end
    end)
  end

  return true
end


-- Describe this connector (strategy, database, schema, version, and
-- whether a read-only config is present).
function _mt:infos()
  local db_ver
  if self.major_minor_version then
    db_ver = match(self.major_minor_version, "^(%d+%.%d+)")
  end

  return {
    strategy    = "PostgreSQL",
    db_name     = self.config.database,
    db_schema   = self.config.schema,
    db_desc     = "database",
    db_ver      = db_ver or "unknown",
    db_readonly = self.config_ro ~= nil,
  }
end


-- Return the stored connection for `operation` ("read" or "write"; any
-- other non-nil value raises), opening and storing a new one when needed.
-- nil operation, or no read-only config, means "write".
-- Returns the connection, or nil and an error message.
function _mt:connect(operation)
  if operation ~= nil and operation ~= "read" and operation ~= "write" then
    error("operation must be 'read' or 'write', was: " .. tostring(operation), 2)
  end

  if not operation or not self.config_ro then
    operation = "write"
  end

  local conn = self:get_stored_connection(operation)
  if conn then
    return conn
  end

  local connection, err = connect(operation == "write" and
                                  self.config or self.config_ro)
  if not connection then
    return nil, err
  end

  self:store_connection(connection, operation)

  return connection
end


-- Migrations always run against the write connection.
function _mt:connect_migrations()
  return self:connect("write")
end


-- Disconnect and forget every stored connection.
-- Returns true, or nil and the first disconnect error.
function _mt:close()
  for operation in pairs(OPERATIONS) do
    local conn = self:get_stored_connection(operation)
    if conn then
      local _, err = conn:disconnect()

      self:store_connection(nil, operation)

      if err then
        return nil, err
      end
    end
  end

  return true
end


-- Release (see the local setkeepalive) and forget every stored connection.
-- Returns true, or nil and the first release error.
function _mt:setkeepalive()
  for operation in pairs(OPERATIONS) do
    local conn = self:get_stored_connection(operation)
    if conn then
      local _, err = setkeepalive(conn)

      self:store_connection(nil, operation)

      if err then
        return nil, err
      end
    end
  end

  return true
end


-- Wait on the semaphore guarding `operation`, if there is one. Skipped in
-- the init and init_worker phases. The timeout is taken from the config
-- owning that semaphore, so the read semaphore honours config_ro.sem_timeout.
-- Returns true, or nil and the semaphore wait error.
function _mt:acquire_query_semaphore_resource(operation)
  local sem = self["sem_" .. operation]
  if not sem then
    return true
  end

  do
    local phase = get_phase()
    if phase == "init" or phase == "init_worker" then
      return true
    end
  end

  local config = (operation == "read" and self.config_ro) or self.config

  local ok, err = sem:wait(config.sem_timeout)
  if not ok then
    return nil, err
  end

  return true
end


-- Post the semaphore guarding `operation`, mirroring the phase checks of
-- acquire_query_semaphore_resource. Always returns true.
function _mt:release_query_semaphore_resource(operation)
  local sem = self["sem_" .. operation]
  if not sem then
    return true
  end

  do
    local phase = get_phase()
    if phase == "init" or phase == "init_worker" then
      return true
    end
  end

  sem:post()

  return true
end


-- Run `sql` on the connection for `operation` ("read"/"write"), using the
-- stored connection if any, or a temporary one released afterwards.
-- Admin API requests (content phase with KONG_PHASE == admin_api) are
-- forced onto "write".
-- Returns res, nil, partial, count on success — where count is num_queries,
-- or else the driver's second return value (NOTE(review): assumed to be
-- pgmoon's query count; insert_lock relies on it) — or nil, err, partial,
-- num_queries on failure.
function _mt:query(sql, operation)
  if operation ~= nil and operation ~= "read" and operation ~= "write" then
    error("operation must be 'read' or 'write', was: " .. tostring(operation), 2)
  end

  local phase = get_phase()

  if not operation or
     not self.config_ro or
     (phase == "content" and ngx.ctx.KONG_PHASE == ADMIN_API_PHASE)
  then
    -- admin API requests skips the replica optimization
    -- to ensure all its results are always strongly consistent
    operation = "write"
  end

  local res, err, partial, num_queries

  local ok
  ok, err = self:acquire_query_semaphore_resource(operation)
  if not ok then
    return nil, "error acquiring query semaphore: " .. err
  end

  local conn = self:get_stored_connection(operation)
  if conn then
    res, err, partial, num_queries = conn:query(sql)

  else
    local connection
    local config = operation == "write" and self.config or self.config_ro

    connection, err = connect(config)
    if not connection then
      self:release_query_semaphore_resource(operation)
      return nil, err
    end

    res, err, partial, num_queries = connection:query(sql)

    setkeepalive(connection)
  end

  self:release_query_semaphore_resource(operation)

  if res then
    return res, nil, partial, num_queries or err
  end

  return nil, err, partial, num_queries
end


-- Run `sql` as a "read" query and return an iterator over the result rows.
-- On failure the iterator yields (false, err, partial, num_queries) once,
-- then nil. A `true` result is iterated as the single row `true`.
function _mt:iterate(sql)
  local res, err, partial, num_queries = self:query(sql, "read")
  if not res then
    local failed = false
    return function()
      if not failed then
        failed = true
        return false, err, partial, num_queries
      end
      -- return error only once to avoid infinite loop
      return nil
    end
  end

  if res == true then
    return iterator { true }
  end

  return iterator(res)
end


-- TRUNCATE ... RESTART IDENTITY CASCADE every table of the current schema
-- except PROTECTED_TABLES. Returns true, or nil and an error message.
function _mt:truncate()
  local table_names, err = get_table_names(self, PROTECTED_TABLES)
  if not table_names then
    return nil, err
  end

  if #table_names == 0 then
    return true
  end

  local truncate_statement = concat {
    "TRUNCATE ", concat(table_names, ", "), " RESTART IDENTITY CASCADE;"
  }

  local ok
  ok, err = self:query(truncate_statement)
  if not ok then
    return nil, err
  end

  return true
end


-- TRUNCATE ... RESTART IDENTITY CASCADE a single table.
-- Returns true, or nil and an error message.
function _mt:truncate_table(table_name)
  local truncate_statement = concat {
    "TRUNCATE ", self:escape_identifier(table_name), " RESTART IDENTITY CASCADE;"
  }

  local ok, err = self:query(truncate_statement)
  if not ok then
    return nil, err
  end

  return true
end


-- Create the `locks` table and its ttl index if missing. Both arguments are
-- ignored by this strategy. Returns true, or nil and an error message.
function _mt:setup_locks(_, _)
  logger.debug("creating 'locks' table if not existing...")

  local ok, err = self:query([[ BEGIN; CREATE TABLE IF NOT EXISTS locks ( key TEXT PRIMARY KEY, owner TEXT, ttl TIMESTAMP WITH TIME ZONE ); CREATE INDEX IF NOT EXISTS locks_ttl_idx ON locks (ttl); COMMIT;]])

  if not ok then
    return nil, err
  end

  logger.debug("successfully created 'locks' table")

  return true
end


-- In one transaction, purge expired locks and insert (key, owner) expiring
-- `ttl` seconds from now, ignoring conflicts.
-- Returns true if the row was inserted, false if it already existed, or nil
-- and an error message.
function _mt:insert_lock(key, ttl, owner)
  local ttl_escaped = concat {
    "TO_TIMESTAMP(",
    self:escape_literal(tonumber(fmt("%.3f", now_updated() + ttl))),
    ") AT TIME ZONE 'UTC'"
  }

  local sql = concat {
    "BEGIN;\n",
    " DELETE FROM locks\n",
    " WHERE ttl < CURRENT_TIMESTAMP AT TIME ZONE 'UTC';\n",
    " INSERT INTO locks (key, owner, ttl)\n",
    " VALUES (", self:escape_literal(key), ", ",
    self:escape_literal(owner), ", ",
    ttl_escaped, ")\n",
    " ON CONFLICT DO NOTHING;\n",
    "COMMIT;"
  }

  local res, err, _, num_queries = self:query(sql)
  if not res then
    return nil, err
  end

  -- BEGIN, DELETE, INSERT, COMMIT
  if num_queries ~= 4 then
    return nil, "unexpected result"
  end

  -- res[3] is the INSERT's result.
  if res[3] and res[3].affected_rows == 1 then
    return true
  end

  return false
end


-- Returns true if an unexpired lock exists for `key`, false if not, or nil
-- and an error message.
function _mt:read_lock(key)
  local sql = concat {
    "SELECT *\n",
    " FROM locks\n",
    " WHERE key = ", self:escape_literal(key), "\n",
    " AND ttl >= CURRENT_TIMESTAMP AT TIME ZONE 'UTC'\n",
    " LIMIT 1;"
  }

  local res, err = self:query(sql)
  if not res then
    return nil, err
  end

  return res[1] ~= nil
end


-- Delete the lock for `key` held by `owner`.
-- Returns true, or nil and an error message.
function _mt:remove_lock(key, owner)
  local sql = concat {
    "DELETE\n",
    " FROM ", self:escape_identifier("locks"), "\n",
    " WHERE ", self:escape_identifier("key"), " = ", self:escape_literal(key), "\n",
    " AND ", self:escape_identifier("owner"), " = ", self:escape_literal(owner), ";"
  }

  local res, err = self:query(sql)
  if not res then
    return nil, err
  end

  return true
end


-- Return the schema_meta rows (key = 'schema_meta'), with SQL NULL
-- `pending` turned into nil. Returns nil (no error) when the schema_meta
-- table does not exist; raises when no connection is stored.
function _mt:schema_migrations()
  local conn = self:get_stored_connection()
  if not conn then
    error("no connection")
  end

  local table_names, err = get_table_names(self)
  if not table_names then
    return nil, err
  end

  local schema_meta_table_name = self:escape_identifier("schema_meta")
  local schema_meta_table_exists
  for _, table_name in ipairs(table_names) do
    if table_name == schema_meta_table_name then
      schema_meta_table_exists = true
      break
    end
  end

  if not schema_meta_table_exists then
    -- database, but no schema_meta: needs bootstrap
    return nil
  end

  local rows
  rows, err = self:query(concat({
    "SELECT *\n",
    " FROM schema_meta\n",
    " WHERE key = ", self:escape_literal("schema_meta"), ";"
  }), "read")

  if not rows then
    return nil, err
  end

  for _, row in ipairs(rows) do
    if row.pending == null then
      row.pending = nil
    end
  end

  -- no migrations: is bootstrapped but not migrated
  -- migrations: has some migrations
  return rows
end


-- Create the configured schema (ignoring insufficient_privilege), the
-- schema_meta table and the locks table. Raises when no connection is
-- stored. Returns true, or nil and an error message.
function _mt:schema_bootstrap(kong_config, default_locks_ttl)
  local conn = self:get_stored_connection()
  if not conn then
    error("no connection")
  end

  -- create schema if not exists

  logger.debug("creating '%s' schema if not existing...", self.config.schema)

  local schema = self:escape_identifier(self.config.schema)
  local ok, err = self:query(concat {
    "BEGIN;\n",
    " DO $$\n",
    " BEGIN\n",
    " CREATE SCHEMA IF NOT EXISTS ", schema, " AUTHORIZATION CURRENT_USER;\n",
    " GRANT ALL ON SCHEMA ", schema ," TO CURRENT_USER;\n",
    " EXCEPTION WHEN insufficient_privilege THEN\n",
    " -- Do nothing, perhaps the schema has been created already\n",
    " END;\n",
    " $$;\n",
    " SET SCHEMA ", self:escape_literal(self.config.schema), ";\n",
    "COMMIT;",
  })

  if not ok then
    return nil, err
  end

  logger.debug("successfully created '%s' schema", self.config.schema)

  -- create schema meta table if not exists

  logger.debug("creating 'schema_meta' table if not existing...")

  local res
  res, err = self:query([[ CREATE TABLE IF NOT EXISTS schema_meta ( key TEXT, subsystem TEXT, last_executed TEXT, executed TEXT[], pending TEXT[], PRIMARY KEY (key, subsystem) );]])

  if not res then
    return nil, err
  end

  logger.debug("successfully created 'schema_meta' table")

  ok, err = self:setup_locks(default_locks_ttl, true)
  if not ok then
    return nil, err
  end

  return true
end


-- Reset the schema (see reset_schema). Raises when no connection is stored.
function _mt:schema_reset()
  local conn = self:get_stored_connection()
  if not conn then
    error("no connection")
  end

  return reset_schema(self)
end


-- Run `up_sql` (stripped, ";"-terminated) inside BEGIN/COMMIT; on failure
-- send a ROLLBACK. Raises on non-string arguments or when no connection is
-- stored. Returns true, or nil and an error message.
function _mt:run_up_migration(name, up_sql)
  if type(name) ~= "string" then
    error("name must be a string", 2)
  end

  if type(up_sql) ~= "string" then
    error("up_sql must be a string", 2)
  end

  local conn = self:get_stored_connection()
  if not conn then
    error("no connection")
  end

  local sql = stringx.strip(up_sql)
  if sub(sql, -1) ~= ";" then
    sql = sql .. ";"
  end

  sql = concat {
    "BEGIN;\n",
    sql, "\n",
    "COMMIT;\n",
  }

  local res, err = self:query(sql)
  if not res then
    self:query("ROLLBACK;")
    return nil, err
  end

  return true
end


-- Upsert the schema_meta row for `subsystem` according to `state`:
--   "executed" - set last_executed and append `name` to executed
--   "pending"  - append `name` to pending
--   "teardown" - like "executed", and also remove `name` from pending
-- Raises on invalid arguments, an unknown state, or no stored connection.
-- Returns true, or nil and an error message.
function _mt:record_migration(subsystem, name, state)
  if type(subsystem) ~= "string" then
    error("subsystem must be a string", 2)
  end

  if type(name) ~= "string" then
    error("name must be a string", 2)
  end

  local conn = self:get_stored_connection()
  if not conn then
    error("no connection")
  end

  local key_escaped  = self:escape_literal("schema_meta")
  local subsystem_escaped = self:escape_literal(subsystem)
  local name_escaped = self:escape_literal(name)
  local name_array   = encode_array({ name })

  local sql
  if state == "executed" then
    sql = concat({
      "INSERT INTO schema_meta (key, subsystem, last_executed, executed)\n",
      " VALUES (", key_escaped, ", ", subsystem_escaped, ", ", name_escaped, ", ", name_array, ")\n",
      "ON CONFLICT (key, subsystem) DO UPDATE\n",
      " SET last_executed = EXCLUDED.last_executed,\n",
      " executed = ARRAY_APPEND(COALESCE(schema_meta.executed, ARRAY[]::TEXT[]), ", name_escaped, ");",
    })

  elseif state == "pending" then
    sql = concat({
      "INSERT INTO schema_meta (key, subsystem, pending)\n",
      " VALUES (", key_escaped, ", ", subsystem_escaped, ", ", name_array, ")\n",
      "ON CONFLICT (key, subsystem) DO UPDATE\n",
      " SET pending = ARRAY_APPEND(schema_meta.pending, ", name_escaped, ");"
    })

  elseif state == "teardown" then
    sql = concat({
      "INSERT INTO schema_meta (key, subsystem, last_executed, executed)\n",
      " VALUES (", key_escaped, ", ", subsystem_escaped, ", ", name_escaped, ", ", name_array, ")\n",
      "ON CONFLICT (key, subsystem) DO UPDATE\n",
      " SET last_executed = EXCLUDED.last_executed,\n",
      " executed = ARRAY_APPEND(COALESCE(schema_meta.executed, ARRAY[]::TEXT[]), ", name_escaped, "),\n",
      " pending = ARRAY_REMOVE(COALESCE(schema_meta.pending, ARRAY[]::TEXT[]), ", name_escaped, ");",
    })

  else
    error("unknown 'state' argument: " .. tostring(state))
  end

  local res, err = self:query(sql)
  if not res then
    return nil, err
  end

  return true
end


local _M = {}


-- Build a connector from the pg_* settings of `kong_config`. Semaphore
-- timeouts are divided by 1000 (configured in ms, default 60000). Outside
-- the CLI, when pg_ro_host is set, config_ro is built by merging the
-- pg_ro_* overrides onto the write config with utils.table_merge
-- (NOTE(review): unset overrides presumably inherit the write values).
-- A semaphore is created per config when its sem_max > 0; creation failure
-- is logged at CRIT and leaves that operation unthrottled.
function _M.new(kong_config)
  local config = {
    host        = kong_config.pg_host,
    port        = kong_config.pg_port,
    timeout     = kong_config.pg_timeout,
    user        = kong_config.pg_user,
    password    = kong_config.pg_password,
    database    = kong_config.pg_database,
    schema      = kong_config.pg_schema or "",
    ssl         = kong_config.pg_ssl,
    ssl_verify  = kong_config.pg_ssl_verify,
    cafile      = kong_config.lua_ssl_trusted_certificate_combined,
    sem_max     = kong_config.pg_max_concurrent_queries or 0,
    sem_timeout = (kong_config.pg_semaphore_timeout or 60000) / 1000,
  }

  local refs = kong_config["$refs"]
  if refs then
    local user_ref = refs.pg_user
    local password_ref = refs.pg_password
    if user_ref or password_ref then
      config["$refs"] = {
        user = user_ref,
        password = password_ref,
      }
    end
  end

  local db = pgmoon.new(config)

  local sem
  if config.sem_max > 0 then
    local err
    sem, err = semaphore.new(config.sem_max)
    if not sem then
      ngx.log(ngx.CRIT, "failed creating the PostgreSQL connector semaphore: ",
                        err)
    end
  end

  local self = {
    config            = config,
    escape_identifier = db.escape_identifier,
    escape_literal    = db.escape_literal,
    sem_write         = sem,
  }

  if not ngx.IS_CLI and kong_config.pg_ro_host then
    ngx.log(ngx.DEBUG, "PostgreSQL connector readonly connection enabled")

    local ro_override = {
      host        = kong_config.pg_ro_host,
      port        = kong_config.pg_ro_port,
      timeout     = kong_config.pg_ro_timeout,
      user        = kong_config.pg_ro_user,
      password    = kong_config.pg_ro_password,
      database    = kong_config.pg_ro_database,
      schema      = kong_config.pg_ro_schema,
      ssl         = kong_config.pg_ro_ssl,
      ssl_verify  = kong_config.pg_ro_ssl_verify,
      cafile      = kong_config.lua_ssl_trusted_certificate_combined,
      sem_max     = kong_config.pg_ro_max_concurrent_queries,
      sem_timeout = kong_config.pg_ro_semaphore_timeout and
                    (kong_config.pg_ro_semaphore_timeout / 1000) or nil,
    }

    if refs then
      local ro_user_ref = refs.pg_ro_user
      local ro_password_ref = refs.pg_ro_password
      if ro_user_ref or ro_password_ref then
        ro_override["$refs"] = {
          user = ro_user_ref,
          password = ro_password_ref,
        }
      end
    end

    local config_ro = utils.table_merge(config, ro_override)

    local sem
    if config_ro.sem_max > 0 then
      local err
      sem, err = semaphore.new(config_ro.sem_max)
      if not sem then
        ngx.log(ngx.CRIT, "failed creating the PostgreSQL connector semaphore: ",
                          err)
      end
    end

    self.config_ro = config_ro
    self.sem_read = sem
  end

  return setmetatable(self, _mt)
end


-- for tests only
_mt._get_topologically_sorted_table_names = get_names_of_tables_with_ttl


return _M