-- kong/kong/db/strategies/cassandra/init.lua
local iteration = require "kong.db.iteration"
local cassandra = require "cassandra"
local cjson = require "cjson"
local new_tab = require "table.new"
local clear_tab = require "table.clear"
local fmt = string.format
local rep = string.rep
local sub = string.sub
local byte = string.byte
local null = ngx.null
local type = type
local error = error
local pairs = pairs
local ipairs = ipairs
local insert = table.insert
local concat = table.concat
local get_phase = ngx.get_phase
local setmetatable = setmetatable
local encode_base64 = ngx.encode_base64
local decode_base64 = ngx.decode_base64
local APPLIED_COLUMN = "[applied]"
local cache_key_field = { type = "string" }
local ws_id_field = { type = "string", uuid = true }
local workspaces_strategy
local _M = {}
local _mt = {}
_mt.__index = _mt
local function format_cql(...)
return (fmt(...):gsub("^%s*", "")
:gsub("%s*$", "")
:gsub("\n%s*", "\n"))
end
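-- Illustrative example (comment only): format_cql trims leading/trailing
-- whitespace and strips the indentation that follows each newline, so
--
--   format_cql([[
--     SELECT %s FROM %s
--   ]], "id", "services")
--
-- returns "SELECT id FROM services"; multi-line templates keep one clause
-- per line, without indentation.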
local function get_ws_id()
local phase = get_phase()
return (phase ~= "init" and phase ~= "init_worker")
and ngx.ctx.workspace or kong.default_workspace
end
local function is_partitioned(self)
local cql
-- Assume release versions 3 and greater use the same schema.
if self.connector.major_version >= 3 then
cql = format_cql([[
SELECT * FROM system_schema.columns
WHERE keyspace_name = '%s'
AND table_name = '%s'
AND column_name = 'partition';
]], self.connector.keyspace, self.schema.table_name)
else
cql = format_cql([[
SELECT * FROM system.schema_columns
WHERE keyspace_name = '%s'
AND columnfamily_name = '%s'
AND column_name = 'partition';
]], self.connector.keyspace, self.schema.table_name)
end
local rows, err = self.connector:query(cql, {}, nil, "read")
if err then
return nil, err
end
-- Assume release versions 3 and greater use the same schema.
if self.connector.major_version >= 3 then
return rows[1] and rows[1].kind == "partition_key"
end
return not not rows[1]
end
local function build_queries(self)
local schema = self.schema
local n_fields = #schema.fields
local n_pk = #schema.primary_key
local has_composite_cache_key = schema.cache_key and #schema.cache_key > 1
local has_ws_id = schema.workspaceable
local select_columns = new_tab(n_fields, 0)
for field_name, field in schema:each_field() do
if field.type == "foreign" then
local db_columns = self.foreign_keys_db_columns[field_name]
for i = 1, #db_columns do
insert(select_columns, db_columns[i].col_name)
end
else
insert(select_columns, field_name)
end
end
select_columns = concat(select_columns, ", ")
local insert_columns = select_columns
local insert_bind_args = rep("?, ", n_fields):sub(1, -3)
if has_composite_cache_key then
insert_columns = select_columns .. ", cache_key"
insert_bind_args = insert_bind_args .. ", ?"
end
if has_ws_id then
insert_columns = insert_columns .. ", ws_id"
insert_bind_args = insert_bind_args .. ", ?"
select_columns = select_columns .. ", ws_id"
end
local select_bind_args = new_tab(n_pk, 0)
for _, field_name in self.each_pk_field() do
if schema.fields[field_name].type == "foreign" then
field_name = field_name .. "_" .. schema.fields[field_name].schema.primary_key[1]
end
insert(select_bind_args, field_name .. " = ?")
end
select_bind_args = concat(select_bind_args, " AND ")
local partitioned, err = is_partitioned(self)
if err then
return nil, err
end
if schema.ttl == true then
select_columns = select_columns .. fmt(", TTL(%s) as ttl", self.ttl_field())
end
if partitioned then
return {
insert = format_cql([[
INSERT INTO %s (partition, %s) VALUES ('%s', %s) IF NOT EXISTS
]], schema.table_name, insert_columns, schema.name, insert_bind_args),
insert_ttl = format_cql([[
INSERT INTO %s (partition, %s) VALUES ('%s', %s) IF NOT EXISTS USING TTL %s
]], schema.table_name, insert_columns, schema.name, insert_bind_args, "%u"),
insert_no_transaction = format_cql([[
INSERT INTO %s (partition, %s) VALUES ('%s', %s)
]], schema.table_name, insert_columns, schema.name, insert_bind_args),
insert_no_transaction_ttl = format_cql([[
INSERT INTO %s (partition, %s) VALUES ('%s', %s) USING TTL %s
]], schema.table_name, insert_columns, schema.name, insert_bind_args, "%u"),
select = format_cql([[
SELECT %s FROM %s WHERE partition = '%s' AND %s
]], select_columns, schema.table_name, schema.name, select_bind_args),
select_page = format_cql([[
SELECT %s FROM %s WHERE partition = '%s'
]], select_columns, schema.table_name, schema.name),
select_with_filter = format_cql([[
SELECT %s FROM %s WHERE partition = '%s' AND %s
]], select_columns, schema.table_name, schema.name, "%s"),
select_tags_cond_and_first_tag = format_cql([[
SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ?
]], schema.table_name),
select_tags_cond_and_next_tags = format_cql([[
SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ? AND entity_id IN ?
]], schema.table_name),
select_tags_cond_or = format_cql([[
SELECT tag, entity_id, other_tags FROM tags WHERE entity_name = '%s' AND tag IN ?
]], schema.table_name),
update = format_cql([[
UPDATE %s SET %s WHERE partition = '%s' AND %s IF EXISTS
]], schema.table_name, "%s", schema.name, select_bind_args),
update_ttl = format_cql([[
UPDATE %s USING TTL %s SET %s WHERE partition = '%s' AND %s IF EXISTS
]], schema.table_name, "%u", "%s", schema.name, select_bind_args),
upsert = format_cql([[
UPDATE %s SET %s WHERE partition = '%s' AND %s
]], schema.table_name, "%s", schema.name, select_bind_args),
upsert_ttl = format_cql([[
UPDATE %s USING TTL %s SET %s WHERE partition = '%s' AND %s
]], schema.table_name, "%u", "%s", schema.name, select_bind_args),
delete = format_cql([[
DELETE FROM %s WHERE partition = '%s' AND %s
]], schema.table_name, schema.name, select_bind_args),
}
end
return {
insert = format_cql([[
INSERT INTO %s (%s) VALUES (%s) IF NOT EXISTS
]], schema.table_name, insert_columns, insert_bind_args),
insert_ttl = format_cql([[
INSERT INTO %s (%s) VALUES (%s) IF NOT EXISTS USING TTL %s
]], schema.table_name, insert_columns, insert_bind_args, "%u"),
insert_no_transaction = format_cql([[
INSERT INTO %s (%s) VALUES (%s)
]], schema.table_name, insert_columns, insert_bind_args),
insert_no_transaction_ttl = format_cql([[
INSERT INTO %s (%s) VALUES (%s) USING TTL %s
]], schema.table_name, insert_columns, insert_bind_args, "%u"),
-- might raise a "you must enable ALLOW FILTERING" error
select = format_cql([[
SELECT %s FROM %s WHERE %s
]], select_columns, schema.table_name, select_bind_args),
-- might raise a "you must enable ALLOW FILTERING" error
select_page = format_cql([[
SELECT %s FROM %s
]], select_columns, schema.table_name),
-- might raise a "you must enable ALLOW FILTERING" error
select_with_filter = format_cql([[
SELECT %s FROM %s WHERE %s
]], select_columns, schema.table_name, "%s"),
select_tags_cond_and_first_tag = format_cql([[
SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ?
]], schema.table_name),
select_tags_cond_and_next_tags = format_cql([[
SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ? AND entity_id IN ?
]], schema.table_name),
select_tags_cond_or = format_cql([[
SELECT tag, entity_id, other_tags FROM tags WHERE entity_name = '%s' AND tag IN ?
]], schema.table_name),
update = format_cql([[
UPDATE %s SET %s WHERE %s IF EXISTS
]], schema.table_name, "%s", select_bind_args),
update_ttl = format_cql([[
UPDATE %s USING TTL %s SET %s WHERE %s IF EXISTS
]], schema.table_name, "%u", "%s", select_bind_args),
upsert = format_cql([[
UPDATE %s SET %s WHERE %s
]], schema.table_name, "%s", select_bind_args),
upsert_ttl = format_cql([[
UPDATE %s USING TTL %s SET %s WHERE %s
]], schema.table_name, "%u", "%s", select_bind_args),
delete = format_cql([[
DELETE FROM %s WHERE %s
]], schema.table_name, select_bind_args),
}
end
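-- Illustrative example (hypothetical schema, comment only): for a
-- workspaceable, non-partitioned schema with table_name = "services",
-- primary_key = { "id" } and fields id, name, created_at, build_queries
-- would produce, among others:
--
--   insert: INSERT INTO services (id, name, created_at, ws_id)
--           VALUES (?, ?, ?, ?) IF NOT EXISTS
--   select: SELECT id, name, created_at, ws_id FROM services WHERE id = ?
--   delete: DELETE FROM services WHERE id = ?
--
-- The partitioned variant additionally writes a literal partition value
-- (the schema name) and constrains every WHERE clause with it.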
local function get_query(self, query_name)
if not self.queries then
local err
self.queries, err = build_queries(self)
if err then
return nil, err
end
end
return self.queries[query_name]
end
local function serialize_arg(field, arg, ws_id)
local serialized_arg
if arg == null then
serialized_arg = cassandra.null
elseif field.uuid then
serialized_arg = cassandra.uuid(arg)
elseif field.timestamp then
serialized_arg = cassandra.timestamp(arg * 1000)
elseif field.type == "integer" then
serialized_arg = cassandra.int(arg)
elseif field.type == "number" then
serialized_arg = cassandra.float(arg)
elseif field.type == "boolean" then
serialized_arg = cassandra.boolean(arg)
elseif field.type == "string" then
if field.unique and ws_id and not field.unique_across_ws then
arg = ws_id .. ":" .. arg
end
serialized_arg = cassandra.text(arg)
elseif field.type == "array" then
local t = {}
for i = 1, #arg do
t[i] = serialize_arg(field.elements, arg[i], ws_id)
end
serialized_arg = cassandra.list(t)
elseif field.type == "set" then
local t = {}
for i = 1, #arg do
t[i] = serialize_arg(field.elements, arg[i], ws_id)
end
serialized_arg = cassandra.set(t)
elseif field.type == "map" then
local t = {}
for k, v in pairs(arg) do
t[k] = serialize_arg(field.values, arg[k], ws_id)
end
serialized_arg = cassandra.map(t)
elseif field.type == "record" then
serialized_arg = cassandra.text(cjson.encode(arg))
elseif field.type == "foreign" then
local fk_pk = field.schema.primary_key[1]
local fk_field = field.schema.fields[fk_pk]
serialized_arg = serialize_arg(fk_field, arg[fk_pk], ws_id)
else
error("[cassandra strategy] don't know how to serialize field")
end
return serialized_arg
end
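-- Illustrative example (hypothetical field definitions, comment only):
-- a workspace-scoped unique string is stored with the workspace UUID as a
-- prefix, and timestamps are converted from seconds to milliseconds:
--
--   serialize_arg({ type = "string", unique = true }, "example.test", ws_id)
--     --> cassandra.text(ws_id .. ":example.test")
--   serialize_arg({ type = "integer", timestamp = true }, 1600000000, ws_id)
--     --> cassandra.timestamp(1600000000000)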
local function serialize_foreign_pk(db_columns, args, args_names, foreign_pk, ws_id)
for _, db_column in ipairs(db_columns) do
local to_serialize
if foreign_pk == null then
to_serialize = null
else
to_serialize = foreign_pk[db_column.foreign_field_name]
end
insert(args, serialize_arg(db_column.foreign_field, to_serialize, ws_id))
if args_names then
insert(args_names, db_column.col_name)
end
end
end
-- Check existence of foreign entity.
--
-- Note: this follows an inevitable "read-before-write" pattern in
-- our Cassandra strategy. While unfortunate, this pattern is made
-- necessary for Kong to behave in a database-agnostic fashion between
-- its supported RDBMSs and Cassandra. This pattern is judged acceptable
-- given the relatively low number of expected writes (more or less at
-- a human pace), and mitigated by the introduction of different levels
-- of consistency for read vs. write queries, as well as the linearizable
-- consistency of lightweight transactions (IF [NOT] EXISTS).
local function foreign_pk_exists(self, field_name, field, foreign_pk, ws_id)
local foreign_schema = field.schema
local foreign_strategy = _M.new(self.connector, foreign_schema,
self.errors)
local foreign_row, err_t = foreign_strategy:select(foreign_pk, { workspace = ws_id or null })
if err_t then
return nil, err_t
end
if not foreign_row then
return nil, self.errors:foreign_key_violation_invalid_reference(foreign_pk,
field_name,
foreign_schema.name)
end
if ws_id and foreign_row.ws_id and foreign_row.ws_id ~= ws_id then
return nil, self.errors:invalid_workspace(foreign_row.ws_id or "null")
end
return true
end
local function set_difference(old_set, new_set)
local new_set_hash = new_tab(0, #new_set)
for _, elem in ipairs(new_set) do
new_set_hash[elem] = true
end
local old_set_hash = new_tab(0, #old_set)
for _, elem in ipairs(old_set) do
old_set_hash[elem] = true
end
local elem_to_add = {}
local elem_to_delete = {}
local elem_not_changed = {}
for _, elem in ipairs(new_set) do
if not old_set_hash[elem] then
insert(elem_to_add, elem)
end
end
for _, elem in ipairs(old_set) do
if not new_set_hash[elem] then
insert(elem_to_delete, elem)
else
insert(elem_not_changed, elem)
end
end
return elem_to_add, elem_to_delete, elem_not_changed
end
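-- Illustrative example: given old_set = { "a", "b", "c" } and
-- new_set = { "b", "c", "d" }, set_difference returns
--   elem_to_add      = { "d" }
--   elem_to_delete   = { "a" }
--   elem_not_changed = { "b", "c" }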
-- Calculate the difference between the current tags and the updated tags
-- of an entity; return the CQL statements to execute, and an error if any
--
-- Note: this follows an inevitable "read-before-write" pattern in
-- our Cassandra strategy. While unfortunate, this pattern is made
-- necessary for Kong to behave in a database-agnostic fashion between
-- its supported RDBMSs and Cassandra. This pattern is judged acceptable
-- given the relatively low number of expected writes (more or less at
-- a human pace), and mitigated by the introduction of different levels
-- of consistency for read vs. write queries, as well as the linearizable
-- consistency of lightweight transactions (IF [NOT] EXISTS).
local function build_tags_cql(primary_key, schema, new_tags, ttl, rbw_entity)
local tags_to_add, tags_to_delete, tags_not_changed
new_tags = (not new_tags or new_tags == null) and {} or new_tags
if rbw_entity then
if rbw_entity and rbw_entity['tags'] and rbw_entity['tags'] ~= null then
tags_to_add, tags_to_delete, tags_not_changed = set_difference(rbw_entity['tags'], new_tags)
else
tags_to_add = new_tags
tags_to_delete = {}
tags_not_changed = {}
end
else
tags_to_add = new_tags
tags_to_delete = {}
tags_not_changed = {}
end
if #tags_to_add == 0 and #tags_to_delete == 0 then
return nil, nil
end
-- Note: here we assume the tags column only exists
-- on entities that use `id` as their primary key
local entity_id = primary_key['id']
local cqls = {}
local update_cql = "UPDATE tags SET other_tags=? WHERE tag=? AND entity_name=? AND entity_id=?"
if ttl then
-- in CQL, the USING TTL clause must follow the table name (before SET)
update_cql = fmt("UPDATE tags USING TTL %u SET other_tags=? WHERE tag=? AND entity_name=? AND entity_id=?", ttl)
end
for _, tag in ipairs(tags_not_changed) do
insert(cqls,
{
update_cql,
{ cassandra.set(new_tags), cassandra.text(tag), cassandra.text(schema.name), cassandra.text(entity_id) }
}
)
end
local insert_cql = "INSERT INTO tags (tag, entity_name, entity_id, other_tags) VALUES (?, ?, ?, ?)"
if ttl then
insert_cql = insert_cql .. fmt(" USING TTL %u", ttl)
end
for _, tag in ipairs(tags_to_add) do
insert(cqls,
{
insert_cql,
{ cassandra.text(tag), cassandra.text(schema.name), cassandra.text(entity_id), cassandra.set(new_tags) }
}
)
end
-- Note: CQL DELETE statements do not accept a TTL clause
local delete_cql = "DELETE FROM tags WHERE tag=? AND entity_name=? AND entity_id=?"
for _, tag in ipairs(tags_to_delete) do
insert(cqls,
{
delete_cql,
{ cassandra.text(tag), cassandra.text(schema.name), cassandra.text(entity_id) }
}
)
end
return cqls, nil
end
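-- Illustrative example (hypothetical entity, comment only): updating an
-- entity whose tags change from { "a", "b" } to { "b", "c" } produces a
-- batch of three statements against the tags table: an UPDATE of other_tags
-- for the unchanged tag "b", an INSERT for the new tag "c", and a DELETE for
-- the removed tag "a", each bound to the entity's id and schema name.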
function _M.new(connector, schema, errors)
local n_fields = #schema.fields
local n_pk = #schema.primary_key
local each_pk_field
local each_non_pk_field
local ttl_field
do
local non_pk_fields = new_tab(n_fields - n_pk, 0)
local pk_fields = new_tab(n_pk, 0)
for field_name, field in schema:each_field() do
local is_pk
for _, pk_field_name in ipairs(schema.primary_key) do
if field_name == pk_field_name then
is_pk = true
break
end
end
insert(is_pk and pk_fields or non_pk_fields, {
field_name = field_name,
field = field,
})
end
local function iter(t, i)
i = i + 1
local f = t[i]
if f then
return i, f.field_name, f.field
end
end
each_pk_field = function()
return iter, pk_fields, 0
end
each_non_pk_field = function()
return iter, non_pk_fields, 0
end
ttl_field = function()
return schema.ttl and non_pk_fields[1] and non_pk_fields[1].field_name
end
end
-- self instantiation
local self = {
connector = connector, -- instance of kong.db.strategies.cassandra.connector
schema = schema,
errors = errors,
each_pk_field = each_pk_field,
each_non_pk_field = each_non_pk_field,
ttl_field = ttl_field,
foreign_keys_db_columns = {},
queries = nil,
}
-- foreign key constraints and page_for_* selector methods
for field_name, field in schema:each_field() do
if field.type == "foreign" then
local foreign_schema = field.schema
local foreign_pk = foreign_schema.primary_key
local foreign_pk_len = #foreign_pk
local db_columns = new_tab(foreign_pk_len, 0)
for i = 1, foreign_pk_len do
for foreign_field_name, foreign_field in foreign_schema:each_field() do
if foreign_field_name == foreign_pk[i] then
insert(db_columns, {
col_name = field_name .. "_" .. foreign_pk[i],
foreign_field = foreign_field,
foreign_field_name = foreign_field_name,
})
end
end
end
local db_columns_args_names = new_tab(#db_columns, 0)
for i = 1, #db_columns do
-- keep args_names for 'page_for_*' methods
db_columns_args_names[i] = db_columns[i].col_name .. " = ?"
end
db_columns.args_names = concat(db_columns_args_names, " AND ")
self.foreign_keys_db_columns[field_name] = db_columns
end
end
-- generate page_for_ method for inverse selection
-- e.g. routes:page_for_service(service_pk)
for field_name, field in schema:each_field() do
if field.type == "foreign" then
local method_name = "page_for_" .. field_name
local db_columns = self.foreign_keys_db_columns[field_name]
self[method_name] = function(self, foreign_key, size, offset, options)
return self:page(size, offset, options, foreign_key, db_columns)
end
end
end
return setmetatable(self, _mt)
end
local function deserialize_aggregates(value, field)
if field.type == "record" then
if type(value) == "string" then
value = cjson.decode(value)
end
elseif field.type == "set" then
if type(value) == "table" then
for i = 1, #value do
value[i] = deserialize_aggregates(value[i], field.elements)
end
end
end
if value == nil then
return null
end
return value
end
function _mt:deserialize_row(row)
if not row then
error("row must be a table", 2)
end
-- deserialize rows
-- replace `nil` fields with `ngx.null`
-- replace `foreign_key` with `foreign = { key = "" }`
-- return timestamps in seconds instead of ms
for field_name, field in self.schema:each_field() do
local ws_unique = field.unique and not field.unique_across_ws
if field.type == "foreign" then
local db_columns = self.foreign_keys_db_columns[field_name]
local has_fk
row[field_name] = new_tab(0, #db_columns)
for i = 1, #db_columns do
local col_name = db_columns[i].col_name
if row[col_name] ~= nil then
row[field_name][db_columns[i].foreign_field_name] = row[col_name]
row[col_name] = nil
has_fk = true
end
end
if not has_fk then
row[field_name] = null
end
elseif field.timestamp and row[field_name] ~= nil then
row[field_name] = row[field_name] / 1000
elseif field.type == "string" and ws_unique and row[field_name] ~= nil then
local value = row[field_name]
-- for regular 'unique' values (that are *not* 'unique_across_ws'), the
-- value is stored as "<ws_uuid>:<value>" in the DB; a UUID is 36 characters
-- long, so the ":" separator sits at byte 37: strip the "<ws_uuid>:" prefix
if byte(value, 37) == byte(":") then
row[field_name] = sub(value, 38)
end
else
row[field_name] = deserialize_aggregates(row[field_name], field)
end
end
return row
end
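-- Illustrative example (hypothetical row, comment only): a raw Cassandra row
-- such as
--
--   { id = "...", service_id = "<uuid>", created_at = 1600000000000 }
--
-- comes back from deserialize_row as
--
--   { id = "...", service = { id = "<uuid>" }, created_at = 1600000000 }
--
-- with any absent columns replaced by ngx.null.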
local function _select(self, cql, args, ws_id)
local rows, err = self.connector:query(cql, args, nil, "read")
if not rows then
return nil, self.errors:database_error("could not execute selection query: "
.. err)
end
-- lua-cassandra returns `nil` values for Cassandra's `NULL`. We need to
-- populate `ngx.null` ourselves
local row = rows[1]
if not row then
return nil
end
if row.ws_id and ws_id and row.ws_id ~= ws_id then
return nil
end
return self:deserialize_row(row)
end
local function check_unique(self, primary_key, entity, field_name, ws_id)
-- a UNIQUE constraint is set on this field.
-- We unfortunately follow a read-before-write pattern in this case,
-- but this is made necessary for Kong to behave in a
-- database-agnostic fashion between its supported RDBMSs and
-- Cassandra.
local opts = { workspace = ws_id or null }
local row, err_t = self:select_by_field(field_name, entity[field_name], opts)
if err_t then
return nil, err_t
end
if row then
for _, pk_field_name in self.each_pk_field() do
if primary_key[pk_field_name] ~= row[pk_field_name] then
-- already exists
if field_name == "cache_key" then
local keys = {}
local schema = self.schema
for _, k in ipairs(schema.cache_key) do
local field = schema.fields[k]
if field.type == "foreign" and entity[k] ~= ngx.null then
keys[k] = field.schema:extract_pk_values(entity[k])
else
keys[k] = entity[k]
end
end
return nil, self.errors:unique_violation(keys)
end
return nil, self.errors:unique_violation {
[field_name] = entity[field_name],
}
end
end
end
return true
end
-- Determine if a workspace is to be used, and if so, which one.
-- If a workspace is given in `options.workspace` and the entity is
-- workspaceable, that workspace is used.
-- If `options.workspace` is not given (or it is ngx.null while `use_null`
-- is false, i.e. the calling query does not accept global queries), the
-- current workspace UUID is obtained from the execution context.
-- @tparam table self The strategy instance (its schema decides workspaceability)
-- @tparam table options The DAO request options table
-- @tparam boolean use_null If true, accept ngx.null as a possible
-- value of options.workspace and use it to signal a global query
-- @treturn boolean,string?,table? One of the following:
-- * false, nil, nil = entity is not workspaceable
-- * true, uuid, nil = entity is workspaceable, this is the workspace to use
-- * true, nil, nil = entity is workspaceable, but a global query was requested
-- * nil, nil, err = database error or selected workspace does not exist
local function check_workspace(self, options, use_null)
local workspace = options and options.workspace
local schema = self.schema
local ws_id
local has_ws_id = schema.workspaceable
if has_ws_id then
if use_null and workspace == null then
ws_id = nil
elseif workspace ~= nil and workspace ~= null then
ws_id = workspace
else
ws_id = get_ws_id()
end
end
-- check that workspace actually exists
if ws_id then
if not workspaces_strategy then
local Entity = require("kong.db.schema.entity")
local schema = Entity.new(require("kong.db.schema.entities.workspaces"))
workspaces_strategy = _M.new(self.connector, schema, self.errors)
end
local row, err_t = workspaces_strategy:select({ id = ws_id })
if err_t then
return nil, nil, err_t
end
if not row then
return nil, nil, self.errors:invalid_workspace(ws_id)
end
end
return has_ws_id, ws_id
end
function _mt:insert(entity, options)
local schema = self.schema
local args = new_tab(#schema.fields, 0)
local ttl = schema.ttl and options and options.ttl
local has_composite_cache_key = schema.cache_key and #schema.cache_key > 1
local has_ws_id, ws_id, err = check_workspace(self, options, false)
if err then
return nil, err
end
local cql_batch
local batch_mode
local mode = 'insert'
local primary_key
if schema.fields.tags then
primary_key = schema:extract_pk_values(entity)
local err_t
cql_batch, err_t = build_tags_cql(primary_key, schema, entity["tags"], ttl)
if err_t then
return nil, err_t
end
if cql_batch then
-- Batch with conditions cannot span multiple tables
-- Note this also disables the APPLIED_COLUMN check
mode = 'insert_no_transaction'
batch_mode = true
end
end
local cql, err
if ttl then
cql, err = get_query(self, mode .. "_ttl")
if err then
return nil, err
end
cql = fmt(cql, ttl)
else
cql, err = get_query(self, mode)
if err then
return nil, err
end
end
-- serialize VALUES clause args
for field_name, field in schema:each_field() do
if field.type == "foreign" then
local foreign_pk = entity[field_name]
if foreign_pk ~= null then
-- if given, check if this foreign entity exists
local exists, err_t = foreign_pk_exists(self, field_name, field, foreign_pk, ws_id)
if not exists then
return nil, err_t
end
end
local db_columns = self.foreign_keys_db_columns[field_name]
serialize_foreign_pk(db_columns, args, nil, foreign_pk, ws_id)
else
if field.unique
and entity[field_name] ~= null
and entity[field_name] ~= nil
then
-- a UNIQUE constraint is set on this field.
-- We unfortunately follow a read-before-write pattern in this case,
-- but this is made necessary for Kong to behave in a database-agnostic
-- fashion between its supported RDBMSs and Cassandra.
primary_key = primary_key or schema:extract_pk_values(entity)
local _, err_t = check_unique(self, primary_key, entity, field_name, ws_id)
if err_t then
return nil, err_t
end
end
insert(args, serialize_arg(field, entity[field_name], ws_id))
end
end
if has_composite_cache_key then
primary_key = primary_key or schema:extract_pk_values(entity)
local _, err_t = check_unique(self, primary_key, entity, "cache_key", ws_id)
if err_t then
return nil, err_t
end
insert(args, serialize_arg(cache_key_field, entity["cache_key"], ws_id))
end
if has_ws_id then
insert(args, serialize_arg(ws_id_field, ws_id, ws_id))
end
-- execute query
local res, err
if batch_mode then
-- insert the CQL for the current entity's table at the first position of the batch
insert(cql_batch, 1, {cql, args})
res, err = self.connector:batch(cql_batch, nil, "write", true)
else
res, err = self.connector:query(cql, args, nil, "write")
end
if not res then
return nil, self.errors:database_error("could not execute insertion query: "
.. err)
end
-- check for linearizable consistency (Paxos)
-- in batch_mode, we currently don't know the APPLIED_COLUMN
if not batch_mode and res[1][APPLIED_COLUMN] == false then
-- lightweight transaction (IF NOT EXISTS) failed,
-- retrieve PK values for the PK violation error
primary_key = primary_key or schema:extract_pk_values(entity)
return nil, self.errors:primary_key_violation(primary_key)
end
-- return foreign keys as if they were fetched from :select()
-- this means foreign relationship fields only contain
-- the primary key of the foreign entity
clear_tab(res)
for field_name, field in schema:each_field() do
local value = entity[field_name]
if field.type == "foreign" then
if value ~= null and value ~= nil then
value = field.schema:extract_pk_values(value)
else
value = null
end
end
res[field_name] = value
end
if has_ws_id then
res.ws_id = ws_id
end
return res
end
function _mt:select(primary_key, options)
local schema = self.schema
local _, ws_id, err = check_workspace(self, options, true)
if err then
return nil, err
end
local cql
cql, err = get_query(self, "select")
if err then
return nil, err
end
local args = new_tab(#schema.primary_key, 0)
-- serialize WHERE clause args
for i, field_name, field in self.each_pk_field() do
args[i] = serialize_arg(field, primary_key[field_name], ws_id)
end
-- execute query
return _select(self, cql, args, ws_id)
end
function _mt:select_by_field(field_name, field_value, options)
local has_ws_id, ws_id, err = check_workspace(self, options, true)
if err then
return nil, err
end
if has_ws_id and ws_id == nil
and not self.schema.fields[field_name].unique_across_ws then
-- fail with error: this is not a database failure, this is a programmer error
error("cannot select on field " .. field_name .. " without a workspace " ..
"because it is not marked unique_across_ws")
end
local cql, err = get_query(self, "select_with_filter")
if err then
return nil, err
end
local field
if field_name == "cache_key" then
field = cache_key_field
else
field = self.schema.fields[field_name]
if field
and field.reference
and self.foreign_keys_db_columns[field_name]
and self.foreign_keys_db_columns[field_name][1]
then
field_name = self.foreign_keys_db_columns[field_name][1].col_name
end
end
local select_cql = fmt(cql, field_name .. " = ?")
local bind_args = new_tab(1, 0)
bind_args[1] = serialize_arg(field, field_value, ws_id)
return _select(self, select_cql, bind_args, ws_id)
end
do
local function execute_page(self, cql, args, offset, opts)
local rows, err = self.connector:query(cql, args, opts, "read")
if not rows then
if err:match("Invalid value for the paging state") then
return nil, self.errors:invalid_offset(offset, err)
end
return nil, self.errors:database_error("could not execute page query: "
.. err)
end
local next_offset
if rows.meta and rows.meta.paging_state then
next_offset = rows.meta.paging_state
end
rows.meta = nil
rows.type = nil
return rows, nil, next_offset
end
local function query_page(self, offset, foreign_key, foreign_key_db_columns, opts)
local _, ws_id, err = check_workspace(self, opts, true)
if err then
return nil, err
end
local cql
local args
if not foreign_key then
if ws_id then
cql, err = get_query(self, "select_with_filter")
if err then
return nil, err
end
cql = fmt(cql, "ws_id = ?")
args = { serialize_arg(ws_id_field, ws_id, ws_id) }
else
cql, err = get_query(self, "select_page")
if err then
return nil, err
end
end
elseif foreign_key and foreign_key_db_columns then
args = new_tab(#foreign_key_db_columns, 0)
cql, err = get_query(self, "select_with_filter")
if err then
return nil, err
end
cql = fmt(cql, foreign_key_db_columns.args_names)
serialize_foreign_pk(foreign_key_db_columns, args, nil, foreign_key, ws_id)
else
error("should provide both of: foreign_key, foreign_key_db_columns", 2)
end
local rows, err_t, next_offset = execute_page(self, cql, args, offset, opts)
if err_t then
return nil, err_t
end
for i = 1, #rows do
rows[i] = self:deserialize_row(rows[i])
end
return rows, nil, next_offset and encode_base64(next_offset)
end
--[[
Define the maximum number of query rounds we will run when filtering
entities by tags.
For each round with the "and" condition we send at most one query per tag
provided and intersect the results in Lua land, keeping only entities that
carry all of the provided tags; with the "or" condition we send one query
per round.
Depending on how tags are distributed across entities, the number of
entities attached to one tag but to none of the others may exceed the page
size requested. In that case we cap the number of rounds, which limits the
total number of queries sent per paging request to at most
(number of tags) * PAGING_MAX_QUERY_ROUNDS.
This number may not suit all workloads: if paging requests return too few
results, the limit can be bumped up; to reduce the latency of the Admin API
paging endpoints, it can be decreased.
]]--
local PAGING_MAX_QUERY_ROUNDS = 20
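-- Worked example: filtering by 3 tags with the "and" condition issues at
-- most 3 tag-table queries per round, so a single paging request is bounded
-- by 3 * PAGING_MAX_QUERY_ROUNDS = 60 queries against the tags table (plus
-- the per-entity dereferencing selects below).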
-- helper function used in query_page_for_tags to translate
-- rows carrying an entity_id into full entities
local function dereference_rows(self, entity_ids, entity_count)
if not entity_ids then
return {}, nil, nil
end
local entity_index = 0
entity_count = entity_count or #entity_ids
local entities = new_tab(entity_count, 0)
-- TODO: send one query using IN
for i, row in ipairs(entity_ids) do
-- TODO: pk name id is hardcoded
local entity, err, err_t = self:select{ id = row.entity_id }
if err then
return nil, err, err_t
end
if entity then
entity_index = entity_index + 1
entities[entity_index] = entity
end
end
return entities, nil, nil
end
local function query_page_for_tags(self, size, offset, tags, cond, opts)
-- TODO: if we don't sort, we could instead give users a performance hint:
-- "always put the tags with the fewest entities at the front of the query"
table.sort(tags)
local tags_count = #tags
-- treat a query for a single tag as an "and" condition
local cond_or = cond == "or" and tags_count > 1
local cql
local args
local next_offset = opts.paging_state
local tags_hash = new_tab(0, tags_count)
local cond_and_cql_first, cond_and_cql_next
if cond_or then
cql = get_query(self, "select_tags_cond_or")
args = { cassandra.list(tags) }
for _, tag in ipairs(tags) do
tags_hash[tag] = true
end
else
cond_and_cql_first = get_query(self, "select_tags_cond_and_first_tag")
cond_and_cql_next = get_query(self, "select_tags_cond_and_next_tags")
end
-- the entity_ids to return
local entity_ids = new_tab(size, 0)
local entity_count = 0
-- a temporary table for current query
local current_entity_ids = new_tab(size, 0)
local rows, err_t
for _=1, PAGING_MAX_QUERY_ROUNDS, 1 do
local current_next_offset
local current_entity_count = 0
if cond_or then
rows, err_t, next_offset = execute_page(self, cql, args, offset, opts)
if err_t then
return nil, err_t, nil
end
clear_tab(current_entity_ids)
for _, row in ipairs(rows) do
local row_tag = row.tag
local duplicated = false
for _, attached_tag in ipairs(row.other_tags) do
-- To ensure we don't add the same entity_id twice (within the current
-- Admin API request or across different requests), we only add an
-- entity_id when it first appears in the result set.
-- In practice we skip the current row whenever row.tag is not the
-- alphabetically smallest of the entity's tags that match the provided
-- tags: a smaller matching tag in row.other_tags means the entity is
-- (or will be) added under that tag instead.
if tags_hash[attached_tag] and attached_tag < row_tag then
duplicated = true
break
end
end
if not duplicated then
current_entity_count = current_entity_count + 1
current_entity_ids[current_entity_count] = row.entity_id
end
end
else
for i=1, #tags, 1 do
local tag = tags[i]
if i == 1 then
opts.paging_state = next_offset
cql = cond_and_cql_first
-- TODO: cache me
args = { cassandra.text(tag) }
else
opts.paging_state = nil
cql = cond_and_cql_next
-- TODO: cache me
args = { cassandra.text(tag), cassandra.list(current_entity_ids) }
end
rows, err_t, current_next_offset = execute_page(self, cql, args, nil, opts)
if err_t then
return nil, err_t, nil
end
if i == 1 then
next_offset = current_next_offset
end
-- No rows left, stop filtering
if not rows or #rows == 0 then
current_entity_count = 0
break
end
clear_tab(current_entity_ids)
current_entity_count = 0
for _, row in ipairs(rows) do
current_entity_count = current_entity_count + 1
current_entity_ids[current_entity_count] = row.entity_id
end
end
end
if current_entity_count > 0 then
for i=1, current_entity_count do
entity_count = entity_count + 1
entity_ids[entity_count] = { entity_id = current_entity_ids[i] }
if entity_count >= size then
-- (entity_count should never actually exceed size here)
break
end
end
if entity_count < size then
-- next round we only read what is still missing
-- so that we return at most `size` rows
opts.page_size = size - entity_count
end
end
-- break the loop once we have read enough rows
-- or there is no more data available in the datastore
if entity_count >= size or not next_offset then
break
end
end
local entities, err_t = dereference_rows(self, entity_ids, entity_count)
if err_t then
return nil, err_t, nil
end
if next_offset and entity_count < size then
-- Note: don't cache ngx.log so we can test it in 02-integration/07-tags_spec.lua
ngx.log(ngx.WARN, "maximum ", PAGING_MAX_QUERY_ROUNDS, " rounds exceeded ",
"without retrieving required size of rows, ",
"consider lowering the sparsity of tags, or increasing the paging size per request"
)
end
return entities, nil, next_offset and encode_base64(next_offset)
end
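-- Note on offsets: the opaque offset returned by :page() below is the
-- base64-encoded Cassandra paging state; an offset that fails to decode is
-- rejected with an invalid_offset error before any query is sent.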
function _mt:page(size, offset, options, foreign_key, foreign_key_db_columns)
local opts = new_tab(0, 2)
if not size then
size = self.connector:get_page_size(options)
end
if offset then
local offset_decoded = decode_base64(offset)
if not offset_decoded then
return nil, self.errors:invalid_offset(offset, "bad base64 encoding")
end
offset = offset_decoded
end
opts.page_size = size
opts.paging_state = offset
opts.workspace = options and options.workspace
if not foreign_key and options and options.tags then
return query_page_for_tags(self, size, offset, options.tags, options.tags_cond, opts)
end
return query_page(self, offset, foreign_key, foreign_key_db_columns, opts)
end
end
do
local function update(self, primary_key, entity, mode, options)
local schema = self.schema
local ttl = schema.ttl and options and options.ttl
local has_ws_id, ws_id, err = check_workspace(self, options, false)
if err then
return nil, err
end
local cql_batch
local batch_mode
if schema.fields.tags then
local rbw_entity, err_t = self:select(primary_key, { workspace = ws_id or null })
if err_t then
return nil, err_t
end
cql_batch, err_t = build_tags_cql(primary_key, schema, entity["tags"], ttl, rbw_entity)
if err_t then
return nil, err_t
end
if cql_batch then
-- Batch with conditions cannot span multiple tables
-- Note this also disables the APPLIED_COLUMN check
mode = 'upsert'
batch_mode = true
end
end
local query_name
if ttl then
query_name = mode .. "_ttl"
else
query_name = mode
end
local cql, err = get_query(self, query_name)
if err then
return nil, err
end
local args = new_tab(#schema.fields, 0)
local args_names = new_tab(#schema.fields, 0)
-- serialize SET clause args
for _, field_name, field in self.each_non_pk_field() do
if entity[field_name] ~= nil then
if field.unique and entity[field_name] ~= null then
local _, err_t = check_unique(self, primary_key, entity, field_name, ws_id)
if err_t then
return nil, err_t
end
end
if field.type == "foreign" then
local foreign_pk = entity[field_name]
if foreign_pk ~= null then
-- if given, check if this foreign entity exists
local exists, err_t = foreign_pk_exists(self, field_name, field, foreign_pk, ws_id)
if not exists then
return nil, err_t
end
end
local db_columns = self.foreign_keys_db_columns[field_name]
serialize_foreign_pk(db_columns, args, args_names, foreign_pk, ws_id)
else
insert(args, serialize_arg(field, entity[field_name], ws_id))
insert(args_names, field_name)
end
end
end
local has_composite_cache_key = schema.cache_key and #schema.cache_key > 1
if has_composite_cache_key then
local _, err_t = check_unique(self, primary_key, entity, "cache_key", ws_id)
if err_t then
return nil, err_t
end
insert(args, serialize_arg(cache_key_field, entity["cache_key"], ws_id))
end
if has_ws_id then
insert(args, serialize_arg(ws_id_field, ws_id, ws_id))
end
-- serialize WHERE clause args
for i, field_name, field in self.each_pk_field() do
insert(args, serialize_arg(field, primary_key[field_name], ws_id))
end
-- inject SET clause bindings
local n_args = #args_names
local update_columns_binds = new_tab(n_args, 0)
for i = 1, n_args do
update_columns_binds[i] = args_names[i] .. " = ?"
end
if has_composite_cache_key then
insert(update_columns_binds, "cache_key = ?")
end
if has_ws_id then
insert(update_columns_binds, "ws_id = ?")
end
if ttl then
cql = fmt(cql, ttl, concat(update_columns_binds, ", "))
else
cql = fmt(cql, concat(update_columns_binds, ", "))
end
-- execute query
local res, err
if batch_mode then
-- insert the CQL for the current entity's table at the first position of the batch
insert(cql_batch, 1, {cql, args})
res, err = self.connector:batch(cql_batch, nil, "write", true)
else
res, err = self.connector:query(cql, args, nil, "write")
end
if not res then
return nil, self.errors:database_error("could not execute update query: "
.. err)
end
if not batch_mode and mode == "update" and res[1][APPLIED_COLUMN] == false then
return nil, self.errors:not_found(primary_key)
end
-- SELECT after write
local row, err_t = self:select(primary_key, { workspace = ws_id or null })
if err_t then
return nil, err_t
end
if not row then
return nil, self.errors:not_found(primary_key)
end
return row
end
local function update_by_field(self, field_name, field_value, entity, mode, options)
local row, err_t = self:select_by_field(field_name, field_value)
if err_t then
return nil, err_t
end
if not row then
if mode == "upsert" then
row = entity
row[field_name] = field_value
else
return nil, self.errors:not_found_by_field({
[field_name] = field_value,
})
end
end
local pk = self.schema:extract_pk_values(row)
return self[mode](self, pk, entity, options)
end
function _mt:update(primary_key, entity, options)
return update(self, primary_key, entity, "update", options)
end
function _mt:upsert(primary_key, entity, options)
return update(self, primary_key, entity, "upsert", options)
end
function _mt:update_by_field(field_name, field_value, entity, options)
return update_by_field(self, field_name, field_value, entity, "update", options)
end
function _mt:upsert_by_field(field_name, field_value, entity, options)
return update_by_field(self, field_name, field_value, entity, "upsert", options)
end
end
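-- Illustrative usage (hypothetical unique field, comment only): the
-- *_by_field variants resolve the row through its unique field and then
-- delegate to :update() / :upsert() with the extracted primary key, e.g.
--
--   local row, err_t = strategy:upsert_by_field("name", "example",
--                                               updated_entity, options)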
do
local function select_by_foreign_key(self, foreign_schema,
foreign_field_name, foreign_key, ws_id)
local n_fields = #foreign_schema.fields
local strategy = _M.new(self.connector, foreign_schema, self.errors)
local cql, err = get_query(strategy, "select_with_filter")
if err then
return nil, err
end
local args = new_tab(n_fields, 0)
local args_names = new_tab(n_fields, 0)
if foreign_field_name then
local db_columns = strategy.foreign_keys_db_columns[foreign_field_name]
serialize_foreign_pk(db_columns, args, args_names, foreign_key, ws_id)
local n_args = #args_names
local where_clause_binds = new_tab(n_args, 0)
for i = 1, n_args do
where_clause_binds[i] = args_names[i] .. " = ?"
end
cql = fmt(cql, concat(where_clause_binds, " AND "))
else
-- workspaces don't have a foreign_field_name
-- and the query needs to be different than in "regular" foreign keys
cql = fmt(cql, "ws_id = ?")
args = { serialize_arg(ws_id_field, ws_id, ws_id) }
end
return _select(strategy, cql, args, ws_id)
end
function _mt:delete(primary_key, options)
local _, ws_id, err = check_workspace(self, options, true)
if err then
return nil, err
end
if self.schema.name == "workspaces" then
ws_id = primary_key.id
end
local schema = self.schema
local cql, err = get_query(self, "delete")
if err then
return nil, err
end
local args = new_tab(#schema.primary_key, 0)
local constraints = schema:get_constraints()
for i = 1, #constraints do
local constraint = constraints[i]
-- foreign keys could be pointing to this entity;
-- this mimics the "ON DELETE" constraint of supported
-- RDBMSs (e.g. PostgreSQL)
--
-- The possible behaviors for such a constraint are:
-- * RESTRICT (default)
-- * CASCADE (on_delete = "cascade")
-- * SET NULL (NYI)
local behavior = constraint.on_delete or "restrict"
if behavior == "restrict" then
local row, err_t = select_by_foreign_key(self,
constraint.schema,
constraint.field_name,
primary_key, ws_id)
if err_t then
return nil, err_t
end
if row then
-- a row is referring to this entity, we cannot delete it.
-- deleting the parent entity would violate the foreign key
-- constraint
return nil, self.errors:foreign_key_violation_restricted(schema.name,
constraint.schema.name)
end
elseif behavior == "cascade" then
local strategy = _M.new(self.connector, constraint.schema, self.errors)
local method = "page_for_" .. constraint.field_name
local pager = function(size, offset)
return strategy[method](strategy, primary_key, size, offset)
end
for row, err in iteration.by_row(self, pager) do
if err then
return nil, self.errors:database_error("could not gather " ..
"associated entities " ..
"for delete cascade: " .. err)
end
local row_pk = constraint.schema:extract_pk_values(row)
local _
_, err = strategy:delete(row_pk)
if err then
return nil, self.errors:database_error("could not cascade " ..
"delete entity: " .. err)
end
end
end
end
-- serialize WHERE clause args
for i, field_name, field in self.each_pk_field() do
args[i] = serialize_arg(field, primary_key[field_name], ws_id)
end
local rbw_entity, err_t
if schema.workspaceable or schema.fields.tags then
rbw_entity, err_t = self:select(primary_key, { workspace = ws_id or null })
if err_t then
return nil, err_t
end
if not rbw_entity then
return true
end
end
local cql_batch
if schema.fields.tags then
cql_batch, err_t = build_tags_cql(primary_key, self.schema, {}, nil, rbw_entity)
if err_t then
return nil, err_t
end
end
-- execute query
local res, err
if cql_batch then
insert(cql_batch, 1, {cql, args} )
res, err = self.connector:batch(cql_batch, nil, "write", true)
else
res, err = self.connector:query(cql, args, nil, "write")
end
if not res then
return nil, self.errors:database_error("could not execute deletion query: "
.. err)
end
return true
end
end
function _mt:delete_by_field(field_name, field_value, options)
local row, err_t = self:select_by_field(field_name, field_value)
if err_t then
return nil, err_t
end
if not row then
return nil
end
local pk = self.schema:extract_pk_values(row)
return self:delete(pk, options)
end
function _mt:truncate(options)
return self.connector:truncate_table(self.schema.name, options)
end
return _M