kong/kong/db/strategies/cassandra/init.lua (1,225 lines of code) (raw):

local iteration = require "kong.db.iteration" local cassandra = require "cassandra" local cjson = require "cjson" local new_tab = require "table.new" local clear_tab = require "table.clear" local fmt = string.format local rep = string.rep local sub = string.sub local byte = string.byte local null = ngx.null local type = type local error = error local pairs = pairs local ipairs = ipairs local insert = table.insert local concat = table.concat local get_phase = ngx.get_phase local setmetatable = setmetatable local encode_base64 = ngx.encode_base64 local decode_base64 = ngx.decode_base64 local APPLIED_COLUMN = "[applied]" local cache_key_field = { type = "string" } local ws_id_field = { type = "string", uuid = true } local workspaces_strategy local _M = {} local _mt = {} _mt.__index = _mt local function format_cql(...) return (fmt(...):gsub("^%s*", "") :gsub("%s*$", "") :gsub("\n%s*", "\n")) end local function get_ws_id() local phase = get_phase() return (phase ~= "init" and phase ~= "init_worker") and ngx.ctx.workspace or kong.default_workspace end local function is_partitioned(self) local cql -- Assume a release version number of 3 & greater will use the same schema. if self.connector.major_version >= 3 then cql = format_cql([[ SELECT * FROM system_schema.columns WHERE keyspace_name = '%s' AND table_name = '%s' AND column_name = 'partition'; ]], self.connector.keyspace, self.schema.table_name) else cql = format_cql([[ SELECT * FROM system.schema_columns WHERE keyspace_name = '%s' AND columnfamily_name = '%s' AND column_name = 'partition'; ]], self.connector.keyspace, self.schema.table_name) end local rows, err = self.connector:query(cql, {}, nil, "read") if err then return nil, err end -- Assume a release version number of 3 & greater will use the same schema. if self.connector.major_version >= 3 then return rows[1] and rows[1].kind == "partition_key" end return not not rows[1] end local function build_queries(self) local schema = self.schema local n_fields = #schema.fields local n_pk = #schema.primary_key local has_composite_cache_key = schema.cache_key and #schema.cache_key > 1 local has_ws_id = schema.workspaceable local select_columns = new_tab(n_fields, 0) for field_name, field in schema:each_field() do if field.type == "foreign" then local db_columns = self.foreign_keys_db_columns[field_name] for i = 1, #db_columns do insert(select_columns, db_columns[i].col_name) end else insert(select_columns, field_name) end end select_columns = concat(select_columns, ", ") local insert_columns = select_columns local insert_bind_args = rep("?, ", n_fields):sub(1, -3) if has_composite_cache_key then insert_columns = select_columns .. ", cache_key" insert_bind_args = insert_bind_args .. ", ?" end if has_ws_id then insert_columns = insert_columns .. ", ws_id" insert_bind_args = insert_bind_args .. ", ?" select_columns = select_columns .. ", ws_id" end local select_bind_args = new_tab(n_pk, 0) for _, field_name in self.each_pk_field() do if schema.fields[field_name].type == "foreign" then field_name = field_name .. "_" .. schema.fields[field_name].schema.primary_key[1] end insert(select_bind_args, field_name .. " = ?") end select_bind_args = concat(select_bind_args, " AND ") local partitioned, err = is_partitioned(self) if err then return nil, err end if schema.ttl == true then select_columns = select_columns .. fmt(", TTL(%s) as ttl", self.ttl_field()) end if partitioned then return { insert = format_cql([[ INSERT INTO %s (partition, %s) VALUES ('%s', %s) IF NOT EXISTS ]], schema.table_name, insert_columns, schema.name, insert_bind_args), insert_ttl = format_cql([[ INSERT INTO %s (partition, %s) VALUES ('%s', %s) IF NOT EXISTS USING TTL %s ]], schema.table_name, insert_columns, schema.name, insert_bind_args, "%u"), insert_no_transaction = format_cql([[ INSERT INTO %s (partition, %s) VALUES ('%s', %s) ]], schema.table_name, insert_columns, schema.name, insert_bind_args), insert_no_transaction_ttl = format_cql([[ INSERT INTO %s (partition, %s) VALUES ('%s', %s) USING TTL %s ]], schema.table_name, insert_columns, schema.name, insert_bind_args, "%u"), select = format_cql([[ SELECT %s FROM %s WHERE partition = '%s' AND %s ]], select_columns, schema.table_name, schema.name, select_bind_args), select_page = format_cql([[ SELECT %s FROM %s WHERE partition = '%s' ]], select_columns, schema.table_name, schema.name), select_with_filter = format_cql([[ SELECT %s FROM %s WHERE partition = '%s' AND %s ]], select_columns, schema.table_name, schema.name, "%s"), select_tags_cond_and_first_tag = format_cql([[ SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ? ]], schema.table_name), select_tags_cond_and_next_tags = format_cql([[ SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ? AND entity_id IN ? ]], schema.table_name), select_tags_cond_or = format_cql([[ SELECT tag, entity_id, other_tags FROM tags WHERE entity_name = '%s' AND tag IN ? ]], schema.table_name), update = format_cql([[ UPDATE %s SET %s WHERE partition = '%s' AND %s IF EXISTS ]], schema.table_name, "%s", schema.name, select_bind_args), update_ttl = format_cql([[ UPDATE %s USING TTL %s SET %s WHERE partition = '%s' AND %s IF EXISTS ]], schema.table_name, "%u", "%s", schema.name, select_bind_args), upsert = format_cql([[ UPDATE %s SET %s WHERE partition = '%s' AND %s ]], schema.table_name, "%s", schema.name, select_bind_args), upsert_ttl = format_cql([[ UPDATE %s USING TTL %s SET %s WHERE partition = '%s' AND %s ]], schema.table_name, "%u", "%s", schema.name, select_bind_args), delete = format_cql([[ DELETE FROM %s WHERE partition = '%s' AND %s ]], schema.table_name, schema.name, select_bind_args), } end return { insert = format_cql([[ INSERT INTO %s (%s) VALUES (%s) IF NOT EXISTS ]], schema.table_name, insert_columns, insert_bind_args), insert_ttl = format_cql([[ INSERT INTO %s (%s) VALUES (%s) IF NOT EXISTS USING TTL %s ]], schema.table_name, insert_columns, insert_bind_args, "%u"), insert_no_transaction = format_cql([[ INSERT INTO %s (%s) VALUES (%s) ]], schema.table_name, insert_columns, insert_bind_args), insert_no_transaction_ttl = format_cql([[ INSERT INTO %s ( %s) VALUES (%s) USING TTL %s ]], schema.table_name, insert_columns, insert_bind_args, "%u"), -- might raise a "you must enable ALLOW FILTERING" error select = format_cql([[ SELECT %s FROM %s WHERE %s ]], select_columns, schema.table_name, select_bind_args), -- might raise a "you must enable ALLOW FILTERING" error select_page = format_cql([[ SELECT %s FROM %s ]], select_columns, schema.table_name), -- might raise a "you must enable ALLOW FILTERING" error select_with_filter = format_cql([[ SELECT %s FROM %s WHERE %s ]], select_columns, schema.table_name, "%s"), select_tags_cond_and_first_tag = format_cql([[ SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ? ]], schema.table_name), select_tags_cond_and_next_tags = format_cql([[ SELECT entity_id FROM tags WHERE entity_name = '%s' AND tag = ? AND entity_id IN ? ]], schema.table_name), select_tags_cond_or = format_cql([[ SELECT tag, entity_id, other_tags FROM tags WHERE entity_name = '%s' AND tag IN ? ]], schema.table_name), update = format_cql([[ UPDATE %s SET %s WHERE %s IF EXISTS ]], schema.table_name, "%s", select_bind_args), update_ttl = format_cql([[ UPDATE %s USING TTL %s SET %s WHERE %s IF EXISTS ]], schema.table_name, "%u", "%s", select_bind_args), upsert = format_cql([[ UPDATE %s SET %s WHERE %s ]], schema.table_name, "%s", select_bind_args), upsert_ttl = format_cql([[ UPDATE %s USING TTL %s SET %s WHERE %s ]], schema.table_name, "%u", "%s", select_bind_args), delete = format_cql([[ DELETE FROM %s WHERE %s ]], schema.table_name, select_bind_args), } end local function get_query(self, query_name) if not self.queries then local err self.queries, err = build_queries(self) if err then return nil, err end end return self.queries[query_name] end local function serialize_arg(field, arg, ws_id) local serialized_arg if arg == null then serialized_arg = cassandra.null elseif field.uuid then serialized_arg = cassandra.uuid(arg) elseif field.timestamp then serialized_arg = cassandra.timestamp(arg * 1000) elseif field.type == "integer" then serialized_arg = cassandra.int(arg) elseif field.type == "number" then serialized_arg = cassandra.float(arg) elseif field.type == "boolean" then serialized_arg = cassandra.boolean(arg) elseif field.type == "string" then if field.unique and ws_id and not field.unique_across_ws then arg = ws_id .. ":" .. arg end serialized_arg = cassandra.text(arg) elseif field.type == "array" then local t = {} for i = 1, #arg do t[i] = serialize_arg(field.elements, arg[i], ws_id) end serialized_arg = cassandra.list(t) elseif field.type == "set" then local t = {} for i = 1, #arg do t[i] = serialize_arg(field.elements, arg[i], ws_id) end serialized_arg = cassandra.set(t) elseif field.type == "map" then local t = {} for k, v in pairs(arg) do t[k] = serialize_arg(field.values, arg[k], ws_id) end serialized_arg = cassandra.map(t) elseif field.type == "record" then serialized_arg = cassandra.text(cjson.encode(arg)) elseif field.type == "foreign" then local fk_pk = field.schema.primary_key[1] local fk_field = field.schema.fields[fk_pk] serialized_arg = serialize_arg(fk_field, arg[fk_pk], ws_id) else error("[cassandra strategy] don't know how to serialize field") end return serialized_arg end local function serialize_foreign_pk(db_columns, args, args_names, foreign_pk, ws_id) for _, db_column in ipairs(db_columns) do local to_serialize if foreign_pk == null then to_serialize = null else to_serialize = foreign_pk[db_column.foreign_field_name] end insert(args, serialize_arg(db_column.foreign_field, to_serialize, ws_id)) if args_names then insert(args_names, db_column.col_name) end end end -- Check existence of foreign entity. -- -- Note: this follows an innevitable "read-before-write" pattern in -- our Cassandra strategy. While unfortunate, this pattern is made -- necessary for Kong to behave in a database-agnostic fashion between -- its supported RDBMs and Cassandra. This pattern is judged acceptable -- given the relatively low number of expected writes (more or less at -- a human pace), and mitigated by the introduction of different levels -- of consistency for read vs. write queries, as well as the linearizable -- consistency of lightweight transactions (IF [NOT] EXISTS). local function foreign_pk_exists(self, field_name, field, foreign_pk, ws_id) local foreign_schema = field.schema local foreign_strategy = _M.new(self.connector, foreign_schema, self.errors) local foreign_row, err_t = foreign_strategy:select(foreign_pk, { workspace = ws_id or null }) if err_t then return nil, err_t end if not foreign_row then return nil, self.errors:foreign_key_violation_invalid_reference(foreign_pk, field_name, foreign_schema.name) end if ws_id and foreign_row.ws_id and foreign_row.ws_id ~= ws_id then return nil, self.errors:invalid_workspace(foreign_row.ws_id or "null") end return true end local function set_difference(old_set, new_set) local new_set_hash = new_tab(0, #new_set) for _, elem in ipairs(new_set) do new_set_hash[elem] = true end local old_set_hash = new_tab(0, #old_set) for _, elem in ipairs(old_set) do old_set_hash[elem] = true end local elem_to_add = {} local elem_to_delete = {} local elem_not_changed = {} for _, elem in ipairs(new_set) do if not old_set_hash[elem] then insert(elem_to_add, elem) end end for _, elem in ipairs(old_set) do if not new_set_hash[elem] then insert(elem_to_delete, elem) else insert(elem_not_changed, elem) end end return elem_to_add, elem_to_delete, elem_not_changed end -- Calculate the difference of current tags and updated tags of an entity -- return the cql to execute, and error if any -- -- Note: this follows an innevitable "read-before-write" pattern in -- our Cassandra strategy. While unfortunate, this pattern is made -- necessary for Kong to behave in a database-agnostic fashion between -- its supported RDBMs and Cassandra. This pattern is judged acceptable -- given the relatively low number of expected writes (more or less at -- a human pace), and mitigated by the introduction of different levels -- of consistency for read vs. write queries, as well as the linearizable -- consistency of lightweight transactions (IF [NOT] EXISTS). local function build_tags_cql(primary_key, schema, new_tags, ttl, rbw_entity) local tags_to_add, tags_to_delete, tags_not_changed new_tags = (not new_tags or new_tags == null) and {} or new_tags if rbw_entity then if rbw_entity and rbw_entity['tags'] and rbw_entity['tags'] ~= null then tags_to_add, tags_to_delete, tags_not_changed = set_difference(rbw_entity['tags'], new_tags) else tags_to_add = new_tags tags_to_delete = {} tags_not_changed = {} end else tags_to_add = new_tags tags_to_delete = {} tags_not_changed = {} end if #tags_to_add == 0 and #tags_to_delete == 0 then return nil, nil end -- Note: here we assume tags column only exists -- with those entities use id as their primary key local entity_id = primary_key['id'] local cqls = {} local update_cql = "UPDATE tags SET other_tags=? WHERE tag=? AND entity_name=? AND entity_id=?" if ttl then update_cql = update_cql .. fmt(" USING TTL %u", ttl) end for _, tag in ipairs(tags_not_changed) do insert(cqls, { update_cql, { cassandra.set(new_tags), cassandra.text(tag), cassandra.text(schema.name), cassandra.text(entity_id) } } ) end local insert_cql = "INSERT INTO tags (tag, entity_name, entity_id, other_tags) VALUES (?, ?, ?, ?)" if ttl then insert_cql = insert_cql .. fmt(" USING TTL %u", ttl) end for _, tag in ipairs(tags_to_add) do insert(cqls, { insert_cql, { cassandra.text(tag), cassandra.text(schema.name), cassandra.text(entity_id), cassandra.set(new_tags) } } ) end local delete_cql = "DELETE FROM tags WHERE tag=? AND entity_name=? and entity_id=?" if ttl then delete_cql = delete_cql .. fmt(" USING TTL %u", ttl) end for _, tag in ipairs(tags_to_delete) do insert(cqls, { delete_cql, { cassandra.text(tag), cassandra.text(schema.name), cassandra.text(entity_id) } } ) end return cqls, nil end function _M.new(connector, schema, errors) local n_fields = #schema.fields local n_pk = #schema.primary_key local each_pk_field local each_non_pk_field local ttl_field do local non_pk_fields = new_tab(n_fields - n_pk, 0) local pk_fields = new_tab(n_pk, 0) for field_name, field in schema:each_field() do local is_pk for _, pk_field_name in ipairs(schema.primary_key) do if field_name == pk_field_name then is_pk = true break end end insert(is_pk and pk_fields or non_pk_fields, { field_name = field_name, field = field, }) end local function iter(t, i) i = i + 1 local f = t[i] if f then return i, f.field_name, f.field end end each_pk_field = function() return iter, pk_fields, 0 end each_non_pk_field = function() return iter, non_pk_fields, 0 end ttl_field = function() return schema.ttl and non_pk_fields[1] and non_pk_fields[1].field_name end end -- self instanciation local self = { connector = connector, -- instance of kong.db.strategies.cassandra.init schema = schema, errors = errors, each_pk_field = each_pk_field, each_non_pk_field = each_non_pk_field, ttl_field = ttl_field, foreign_keys_db_columns = {}, queries = nil, } -- foreign keys constraints and page_for_ selector methods for field_name, field in schema:each_field() do if field.type == "foreign" then local foreign_schema = field.schema local foreign_pk = foreign_schema.primary_key local foreign_pk_len = #foreign_pk local db_columns = new_tab(foreign_pk_len, 0) for i = 1, foreign_pk_len do for foreign_field_name, foreign_field in foreign_schema:each_field() do if foreign_field_name == foreign_pk[i] then insert(db_columns, { col_name = field_name .. "_" .. foreign_pk[i], foreign_field = foreign_field, foreign_field_name = foreign_field_name, }) end end end local db_columns_args_names = new_tab(#db_columns, 0) for i = 1, #db_columns do -- keep args_names for 'page_for_*' methods db_columns_args_names[i] = db_columns[i].col_name .. " = ?" end db_columns.args_names = concat(db_columns_args_names, " AND ") self.foreign_keys_db_columns[field_name] = db_columns end end -- generate page_for_ method for inverse selection -- e.g. routes:page_for_service(service_pk) for field_name, field in schema:each_field() do if field.type == "foreign" then local method_name = "page_for_" .. field_name local db_columns = self.foreign_keys_db_columns[field_name] local select_foreign_bind_args = {} for _, foreign_key_column in ipairs(db_columns) do insert(select_foreign_bind_args, foreign_key_column.col_name .. " = ?") end self[method_name] = function(self, foreign_key, size, offset, options) return self:page(size, offset, options, foreign_key, db_columns) end end end return setmetatable(self, _mt) end local function deserialize_aggregates(value, field) if field.type == "record" then if type(value) == "string" then value = cjson.decode(value) end elseif field.type == "set" then if type(value) == "table" then for i = 1, #value do value[i] = deserialize_aggregates(value[i], field.elements) end end end if value == nil then return null end return value end function _mt:deserialize_row(row) if not row then error("row must be a table", 2) end -- deserialize rows -- replace `nil` fields with `ngx.null` -- replace `foreign_key` with `foreign = { key = "" }` -- return timestamps in seconds instead of ms for field_name, field in self.schema:each_field() do local ws_unique = field.unique and not field.unique_across_ws if field.type == "foreign" then local db_columns = self.foreign_keys_db_columns[field_name] local has_fk row[field_name] = new_tab(0, #db_columns) for i = 1, #db_columns do local col_name = db_columns[i].col_name if row[col_name] ~= nil then row[field_name][db_columns[i].foreign_field_name] = row[col_name] row[col_name] = nil has_fk = true end end if not has_fk then row[field_name] = null end elseif field.timestamp and row[field_name] ~= nil then row[field_name] = row[field_name] / 1000 elseif field.type == "string" and ws_unique and row[field_name] ~= nil then local value = row[field_name] -- for regular 'unique' values (that are *not* 'unique_across_ws') -- value is of the form "<uuid>:<value>" in the DB: strip the "<uuid>:" if byte(value, 37) == byte(":") then row[field_name] = sub(value, 38) end else row[field_name] = deserialize_aggregates(row[field_name], field) end end return row end local function _select(self, cql, args, ws_id) local rows, err = self.connector:query(cql, args, nil, "read") if not rows then return nil, self.errors:database_error("could not execute selection query: " .. err) end -- lua-cassandra returns `nil` values for Cassandra's `NULL`. We need to -- populate `ngx.null` ourselves local row = rows[1] if not row then return nil end if row.ws_id and ws_id and row.ws_id ~= ws_id then return nil end return self:deserialize_row(row) end local function check_unique(self, primary_key, entity, field_name, ws_id) -- a UNIQUE constaint is set on this field. -- We unfortunately follow a read-before-write pattern in this case, -- but this is made necessary for Kong to behave in a -- database-agnostic fashion between its supported RDBMs and -- Cassandra. local opts = { workspace = ws_id or null } local row, err_t = self:select_by_field(field_name, entity[field_name], opts) if err_t then return nil, err_t end if row then for _, pk_field_name in self.each_pk_field() do if primary_key[pk_field_name] ~= row[pk_field_name] then -- already exists if field_name == "cache_key" then local keys = {} local schema = self.schema for _, k in ipairs(schema.cache_key) do local field = schema.fields[k] if field.type == "foreign" and entity[k] ~= ngx.null then keys[k] = field.schema:extract_pk_values(entity[k]) else keys[k] = entity[k] end end return nil, self.errors:unique_violation(keys) end return nil, self.errors:unique_violation { [field_name] = entity[field_name], } end end end return true end -- Determine if a workspace is to be used, and if so, which one. -- If a workspace is given in `options.workspace` and the entity is -- workspaceable, it will use it. -- If `use_null` is false (indicating the query calling this function -- does not accept global queries) or `options.workspace` is not given, -- then this function will obtain the current workspace UUID from -- the execution context. -- @tparam table schema The schema definition table -- @tparam table option The DAO request options table -- @tparam boolean use_null If true, accept ngx.null as a possible -- value of options.workspace and use it to signal a global query -- @treturn boolean,string?,table? One of the following: -- * false, nil, nil = entity is not workspaceable -- * true, uuid, nil = entity is workspaceable, this is the workspace to use -- * true, nil, nil = entity is workspaceable, but a global query was requested -- * nil, nil, err = database error or selected workspace does not exist local function check_workspace(self, options, use_null) local workspace = options and options.workspace local schema = self.schema local ws_id local has_ws_id = schema.workspaceable if has_ws_id then if use_null and workspace == null then ws_id = nil elseif workspace ~= nil and workspace ~= null then ws_id = workspace else ws_id = get_ws_id() end end -- check that workspace actually exists if ws_id then if not workspaces_strategy then local Entity = require("kong.db.schema.entity") local schema = Entity.new(require("kong.db.schema.entities.workspaces")) workspaces_strategy = _M.new(self.connector, schema, self.errors) end local row, err_t = workspaces_strategy:select({ id = ws_id }) if err_t then return nil, nil, err_t end if not row then return nil, nil, self.errors:invalid_workspace(ws_id) end end return has_ws_id, ws_id end function _mt:insert(entity, options) local schema = self.schema local args = new_tab(#schema.fields, 0) local ttl = schema.ttl and options and options.ttl local has_composite_cache_key = schema.cache_key and #schema.cache_key > 1 local has_ws_id, ws_id, err = check_workspace(self, options, false) if err then return nil, err end local cql_batch local batch_mode local mode = 'insert' local primary_key if schema.fields.tags then primary_key = schema:extract_pk_values(entity) local err_t cql_batch, err_t = build_tags_cql(primary_key, schema, entity["tags"], ttl) if err_t then return nil, err_t end if cql_batch then -- Batch with conditions cannot span multiple tables -- Note this will also disables the APPLIED_COLUMN check mode = 'insert_no_transaction' batch_mode = true end end local cql, err if ttl then cql, err = get_query(self, mode .. "_ttl") if err then return nil, err end cql = fmt(cql, ttl) else cql, err = get_query(self, mode) if err then return nil, err end end -- serialize VALUES clause args for field_name, field in schema:each_field() do if field.type == "foreign" then local foreign_pk = entity[field_name] if foreign_pk ~= null then -- if given, check if this foreign entity exists local exists, err_t = foreign_pk_exists(self, field_name, field, foreign_pk, ws_id) if not exists then return nil, err_t end end local db_columns = self.foreign_keys_db_columns[field_name] serialize_foreign_pk(db_columns, args, nil, foreign_pk, ws_id) else if field.unique and entity[field_name] ~= null and entity[field_name] ~= nil then -- a UNIQUE constaint is set on this field. -- We unfortunately follow a read-before-write pattern in this case, -- but this is made necessary for Kong to behave in a database-agnostic -- fashion between its supported RDBMs and Cassandra. primary_key = primary_key or schema:extract_pk_values(entity) local _, err_t = check_unique(self, primary_key, entity, field_name, ws_id) if err_t then return nil, err_t end end insert(args, serialize_arg(field, entity[field_name], ws_id)) end end if has_composite_cache_key then primary_key = primary_key or schema:extract_pk_values(entity) local _, err_t = check_unique(self, primary_key, entity, "cache_key", ws_id) if err_t then return nil, err_t end insert(args, serialize_arg(cache_key_field, entity["cache_key"], ws_id)) end if has_ws_id then insert(args, serialize_arg(ws_id_field, ws_id, ws_id)) end -- execute query local res, err if batch_mode then -- insert the cql to current entity table at first position insert(cql_batch, 1, {cql, args}) res, err = self.connector:batch(cql_batch, nil, "write", true) else res, err = self.connector:query(cql, args, nil, "write") end if not res then return nil, self.errors:database_error("could not execute insertion query: " .. err) end -- check for linearizable consistency (Paxos) -- in batch_mode, we currently don't know the APPLIED_COLUMN if not batch_mode and res[1][APPLIED_COLUMN] == false then -- lightweight transaction (IF NOT EXISTS) failed, -- retrieve PK values for the PK violation error primary_key = primary_key or schema:extract_pk_values(entity) return nil, self.errors:primary_key_violation(primary_key) end -- return foreign key as if they were fetched from :select() -- this means foreign relationship tables should only contain -- the primary key of the foreign entity clear_tab(res) for field_name, field in schema:each_field() do local value = entity[field_name] if field.type == "foreign" then if value ~= null and value ~= nil then value = field.schema:extract_pk_values(value) else value = null end end res[field_name] = value end if has_ws_id then res.ws_id = ws_id end return res end function _mt:select(primary_key, options) local schema = self.schema local _, ws_id, err = check_workspace(self, options, true) if err then return nil, err end local cql cql, err = get_query(self, "select") if err then return nil, err end local args = new_tab(#schema.primary_key, 0) -- serialize WHERE clause args for i, field_name, field in self.each_pk_field() do args[i] = serialize_arg(field, primary_key[field_name], ws_id) end -- execute query return _select(self, cql, args, ws_id) end function _mt:select_by_field(field_name, field_value, options) local has_ws_id, ws_id, err = check_workspace(self, options, true) if err then return nil, err end if has_ws_id and ws_id == nil and not self.schema.fields[field_name].unique_across_ws then -- fail with error: this is not a database failure, this is programmer error error("cannot select on field " .. field_name .. "without a workspace " .. "because it is not marked unique_across_ws") end local cql, err = get_query(self, "select_with_filter") if err then return nil, err end local field if field_name == "cache_key" then field = cache_key_field else field = self.schema.fields[field_name] if field and field.reference and self.foreign_keys_db_columns[field_name] and self.foreign_keys_db_columns[field_name][1] then field_name = self.foreign_keys_db_columns[field_name][1].col_name end end local select_cql = fmt(cql, field_name .. " = ?") local bind_args = new_tab(1, 0) bind_args[1] = serialize_arg(field, field_value, ws_id) return _select(self, select_cql, bind_args, ws_id) end do local function execute_page(self, cql, args, offset, opts) local rows, err = self.connector:query(cql, args, opts, "read") if not rows then if err:match("Invalid value for the paging state") then return nil, self.errors:invalid_offset(offset, err) end return nil, self.errors:database_error("could not execute page query: " .. err) end local next_offset if rows.meta and rows.meta.paging_state then next_offset = rows.meta.paging_state end rows.meta = nil rows.type = nil return rows, nil, next_offset end local function query_page(self, offset, foreign_key, foreign_key_db_columns, opts) local _, ws_id, err = check_workspace(self, opts, true) if err then return nil, err end local cql local args if not foreign_key then if ws_id then cql, err = get_query(self, "select_with_filter") if err then return nil, err end cql = fmt(cql, "ws_id = ?") args = { serialize_arg(ws_id_field, ws_id, ws_id) } else cql, err = get_query(self, "select_page") if err then return nil, err end end elseif foreign_key and foreign_key_db_columns then args = new_tab(#foreign_key_db_columns, 0) cql, err = get_query(self, "select_with_filter") if err then return nil, err end cql = fmt(cql, foreign_key_db_columns.args_names) serialize_foreign_pk(foreign_key_db_columns, args, nil, foreign_key, ws_id) else error("should provide both of: foreign_key, foreign_key_db_columns", 2) end local rows, err_t, next_offset = execute_page(self, cql, args, offset, opts) if err_t then return nil, err_t end for i = 1, #rows do rows[i] = self:deserialize_row(rows[i]) end return rows, nil, next_offset and encode_base64(next_offset) end --[[ Define the max rounds of queries we will send when filtering entity with tags For each "round" with AND we send queries at max to the number of tags provided and filter in Lua land with entities with all tags provided; for OR we send one request each round. Depending on the distribution of tags attached to entity, it might be possible that number of tags attached with one tag and doesn't has others is larger than the page size provided. In such case, we limit the "rounds" of such filtering and thus limit the total number of queries we send per paging request to be at most (number of tags) * PAGING_MAX_QUERY_ROUNDS Note the number here may not suite all conditions. If the paging request returns too less results, this limit can be bumped up. To archieve less latency for the Admin paging API, this limit can be decreased. ]]-- local PAGING_MAX_QUERY_ROUNDS = 20 -- helper function used in query_page_for_tags to translate -- a row with entity_id with an entity local function dereference_rows(self, entity_ids, entity_count) if not entity_ids then return {}, nil, nil end local entity_index = 0 entity_count = entity_count or #entity_ids local entities = new_tab(entity_count, 0) -- TODO: send one query using IN for i, row in ipairs(entity_ids) do -- TODO: pk name id is hardcoded local entity, err, err_t = self:select{ id = row.entity_id } if err then return nil, err, err_t end if entity then entity_index = entity_index + 1 entities[entity_index] = entity end end return entities, nil, nil end local function query_page_for_tags(self, size, offset, tags, cond, opts) -- TODO: if we don't sort, we can have a performance guidance to user -- to "always put tags with less entity at the front of query" table.sort(tags) local tags_count = #tags -- merge the condition of only one tags to be "and" condition local cond_or = cond == "or" and tags_count > 1 local cql local args local next_offset = opts.paging_state local tags_hash = new_tab(0, tags_count) local cond_and_cql_first, cond_and_cql_next if cond_or then cql = get_query(self, "select_tags_cond_or") args = { cassandra.list(tags) } for _, tag in ipairs(tags) do tags_hash[tag] = true end else cond_and_cql_first = get_query(self, "select_tags_cond_and_first_tag") cond_and_cql_next = get_query(self, "select_tags_cond_and_next_tags") end -- the entity_ids to return local entity_ids = new_tab(size, 0) local entity_count = 0 -- a temporary table for current query local current_entity_ids = new_tab(size, 0) local rows, err_t for _=1, PAGING_MAX_QUERY_ROUNDS, 1 do local current_next_offset local current_entity_count = 0 if cond_or then rows, err_t, next_offset = execute_page(self, cql, args, offset, opts) if err_t then return nil, err_t, nil end clear_tab(current_entity_ids) for _, row in ipairs(rows) do local row_tag = row.tag local duplicated = false for _, attached_tag in ipairs(row.other_tags) do -- To ensure we don't add same entity_id twice (during current -- admin api request or across different requests), we only add -- entity_id when it first appears in the result set. -- That means we don't add current row if row.tag -- 1. is a matching tag towards provided tags (that means this -- entity_id can potentially be duplicated), and -- 2. is not the alphabetically smallest tag among all its tags -- (as in row.other_tags) if tags_hash[attached_tag] and attached_tag < row_tag then duplicated = true break end end if not duplicated then current_entity_count = current_entity_count + 1 current_entity_ids[current_entity_count] = row.entity_id end end else for i=1, #tags, 1 do local tag = tags[i] if i == 1 then opts.paging_state = next_offset cql = cond_and_cql_first -- TODO: cache me args = { cassandra.text(tag) } else opts.paging_state = nil cql = cond_and_cql_next -- TODO: cache me args = { cassandra.text(tag), cassandra.list(current_entity_ids) } end rows, err_t, current_next_offset = execute_page(self, cql, args, nil, opts) if err_t then return nil, err_t, nil end if i == 1 then next_offset = current_next_offset end -- No rows left, stop filtering if not rows or #rows == 0 then current_entity_count = 0 break end clear_tab(current_entity_ids) current_entity_count = 0 for i, row in ipairs(rows) do current_entity_count = current_entity_count + 1 current_entity_ids[current_entity_count] = row.entity_id end end end if current_entity_count > 0 then for i=1, current_entity_count do entity_count = entity_count + 1 entity_ids[entity_count] = { entity_id = current_entity_ids[i] } if entity_count >= size then -- shouldn't be "larger than" actually break end end if entity_count < size then -- next time we only read what we left -- to return a row with length of `size` opts.page_size = size - entity_count end end -- break the loop either we read enough rows -- or no more data available in the datastore if entity_count >= size or not next_offset then break end end local entities, err_t = dereference_rows(self, entity_ids, entity_count) if err_t then return nil, err_t, nil end if next_offset and entity_count < size then -- Note: don't cache ngx.log so we can test in 02-intergration/07-tags_spec.lua ngx.log(ngx.WARN, "maximum ", PAGING_MAX_QUERY_ROUNDS, " rounds exceeded ", "without retrieving required size of rows, ", "consider lower the sparsity of tags, or increase the paging size per request" ) end return entities, nil, next_offset and encode_base64(next_offset) end function _mt:page(size, offset, options, foreign_key, foreign_key_db_columns) local opts = new_tab(0, 2) if not size then size = self.connector:get_page_size(options) end if offset then local offset_decoded = decode_base64(offset) if not offset_decoded then return nil, self.errors:invalid_offset(offset, "bad base64 encoding") end offset = offset_decoded end opts.page_size = size opts.paging_state = offset opts.workspace = options and options.workspace if not foreign_key and options and options.tags then return query_page_for_tags(self, size, offset, options.tags, options.tags_cond, opts) end return query_page(self, offset, foreign_key, foreign_key_db_columns, opts) end end do local function update(self, primary_key, entity, mode, options) local schema = self.schema local ttl = schema.ttl and options and options.ttl local has_ws_id, ws_id, err = check_workspace(self, options, false) if err then return nil, err end local cql_batch local batch_mode if schema.fields.tags then local rbw_entity, err_t = self:select(primary_key, { workspace = ws_id or null }) if err_t then return nil, err_t end cql_batch, err_t = build_tags_cql(primary_key, schema, entity["tags"], ttl, rbw_entity) if err_t then return nil, err_t end if cql_batch then -- Batch with conditions cannot span multiple tables -- Note this will also disables the APPLIED_COLUMN check mode = 'upsert' batch_mode = true end end local query_name if ttl then query_name = mode .. "_ttl" else query_name = mode end local cql, err = get_query(self, query_name) if err then return nil, err end local args = new_tab(#schema.fields, 0) local args_names = new_tab(#schema.fields, 0) -- serialize SET clause args for _, field_name, field in self.each_non_pk_field() do if entity[field_name] ~= nil then if field.unique and entity[field_name] ~= null then local _, err_t = check_unique(self, primary_key, entity, field_name, ws_id) if err_t then return nil, err_t end end if field.type == "foreign" then local foreign_pk = entity[field_name] if foreign_pk ~= null then -- if given, check if this foreign entity exists local exists, err_t = foreign_pk_exists(self, field_name, field, foreign_pk, ws_id) if not exists then return nil, err_t end end local db_columns = self.foreign_keys_db_columns[field_name] serialize_foreign_pk(db_columns, args, args_names, foreign_pk, ws_id) else insert(args, serialize_arg(field, entity[field_name], ws_id)) insert(args_names, field_name) end end end local has_composite_cache_key = schema.cache_key and #schema.cache_key > 1 if has_composite_cache_key then local _, err_t = check_unique(self, primary_key, entity, "cache_key", ws_id) if err_t then return nil, err_t end insert(args, serialize_arg(cache_key_field, entity["cache_key"], ws_id)) end if has_ws_id then insert(args, serialize_arg(ws_id_field, ws_id, ws_id)) end -- serialize WHERE clause args for i, field_name, field in self.each_pk_field() do insert(args, serialize_arg(field, primary_key[field_name], ws_id)) end -- inject SET clause bindings local n_args = #args_names local update_columns_binds = new_tab(n_args, 0) for i = 1, n_args do update_columns_binds[i] = args_names[i] .. " = ?" end if has_composite_cache_key then insert(update_columns_binds, "cache_key = ?") end if has_ws_id then insert(update_columns_binds, "ws_id = ?") end if ttl then cql = fmt(cql, ttl, concat(update_columns_binds, ", ")) else cql = fmt(cql, concat(update_columns_binds, ", ")) end -- execute query local res, err if batch_mode then -- insert the cql to current entity table at first position insert(cql_batch, 1, {cql, args}) res, err = self.connector:batch(cql_batch, nil, "write", true) else res, err = self.connector:query(cql, args, nil, "write") end if not res then return nil, self.errors:database_error("could not execute update query: " .. err) end if not batch_mode and mode == "update" and res[1][APPLIED_COLUMN] == false then return nil, self.errors:not_found(primary_key) end -- SELECT after write local row, err_t = self:select(primary_key, { workspace = ws_id or null }) if err_t then return nil, err_t end if not row then return nil, self.errors:not_found(primary_key) end return row end local function update_by_field(self, field_name, field_value, entity, mode, options) local row, err_t = self:select_by_field(field_name, field_value) if err_t then return nil, err_t end if not row then if mode == "upsert" then row = entity row[field_name] = field_value else return nil, self.errors:not_found_by_field({ [field_name] = field_value, }) end end local pk = self.schema:extract_pk_values(row) return self[mode](self, pk, entity, options) end function _mt:update(primary_key, entity, options) return update(self, primary_key, entity, "update", options) end function _mt:upsert(primary_key, entity, options) return update(self, primary_key, entity, "upsert", options) end function _mt:update_by_field(field_name, field_value, entity, options) return update_by_field(self, field_name, field_value, entity, "update", options) end function _mt:upsert_by_field(field_name, field_value, entity, options) return update_by_field(self, field_name, field_value, entity, "upsert", options) end end do local function select_by_foreign_key(self, foreign_schema, foreign_field_name, foreign_key, ws_id) local n_fields = #foreign_schema.fields local strategy = _M.new(self.connector, foreign_schema, self.errors) local cql, err = get_query(strategy, "select_with_filter") if err then return nil, err end local args = new_tab(n_fields, 0) local args_names = new_tab(n_fields, 0) if foreign_field_name then local db_columns = strategy.foreign_keys_db_columns[foreign_field_name] serialize_foreign_pk(db_columns, args, args_names, foreign_key, ws_id) else -- workspaces don't have a foreign_field_name -- and the query needs to be different than in "regular" foreign keys cql = fmt(cql, "ws_id = ?") args = { serialize_arg(ws_id_field, ws_id, ws_id) } end local n_args = #args_names local where_clause_binds = new_tab(n_args, 0) for i = 1, n_args do where_clause_binds[i] = args_names[i] .. " = ?" end cql = fmt(cql, concat(where_clause_binds, " AND ")) return _select(strategy, cql, args, ws_id) end function _mt:delete(primary_key, options) local _, ws_id, err = check_workspace(self, options, true) if err then return nil, err end if self.schema.name == "workspaces" then ws_id = primary_key.id end local schema = self.schema local cql, err = get_query(self, "delete") if err then return nil, err end local args = new_tab(#schema.primary_key, 0) local constraints = schema:get_constraints() for i = 1, #constraints do local constraint = constraints[i] -- foreign keys could be pointing to this entity -- this mimics the "ON DELETE" constraint of supported -- RDBMs (e.g. PostgreSQL) -- -- The possible behaviors on such a constraint are: -- * RESTRICT (default) -- * CASCADE (on_delete = "cascade", NYI) -- * SET NULL (NYI) local behavior = constraint.on_delete or "restrict" if behavior == "restrict" then local row, err_t = select_by_foreign_key(self, constraint.schema, constraint.field_name, primary_key, ws_id) if err_t then return nil, err_t end if row then -- a row is referring to this entity, we cannot delete it. -- deleting the parent entity would violate the foreign key -- constraint return nil, self.errors:foreign_key_violation_restricted(schema.name, constraint.schema.name) end elseif behavior == "cascade" then local strategy = _M.new(self.connector, constraint.schema, self.errors) local method = "page_for_" .. constraint.field_name local pager = function(size, offset) return strategy[method](strategy, primary_key, size, offset) end for row, err in iteration.by_row(self, pager) do if err then return nil, self.errors:database_error("could not gather " .. "associated entities " .. "for delete cascade: " .. err) end local row_pk = constraint.schema:extract_pk_values(row) local _ _, err = strategy:delete(row_pk) if err then return nil, self.errors:database_error("could not cascade " .. "delete entity: " .. err) end end end end -- serialize WHERE clause args for i, field_name, field in self.each_pk_field() do args[i] = serialize_arg(field, primary_key[field_name], ws_id) end local rbw_entity, err_t if schema.workspaceable or schema.fields.tags then rbw_entity, err_t = self:select(primary_key, { workspace = ws_id or null }) if err_t then return nil, err_t end if not rbw_entity then return true end end local cql_batch if schema.fields.tags then cql_batch, err_t = build_tags_cql(primary_key, self.schema, {}, nil, rbw_entity) if err_t then return nil, err_t end end -- execute query local res, err if cql_batch then insert(cql_batch, 1, {cql, args} ) res, err = self.connector:batch(cql_batch, nil, "write", true) else res, err = self.connector:query(cql, args, nil, "write") end if not res then return nil, self.errors:database_error("could not execute deletion query: " .. err) end return true end end function _mt:delete_by_field(field_name, field_value, options) local row, err_t = self:select_by_field(field_name, field_value) if err_t then return nil, err_t end if not row then return nil end local pk = self.schema:extract_pk_values(row) return self:delete(pk, options) end function _mt:truncate(options) return self.connector:truncate_table(self.schema.name, options) end return _M