From ca85e735b0618f8246f00dd78f692bdf61503b3c Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Fri, 23 May 2025 17:10:12 -0500 Subject: [PATCH] update sqlite wrapper and sandbox with utils --- runner/lua/sandbox.lua | 275 +++++++++++++++--------------- runner/lua/sqlite.lua | 368 +++++++++++++---------------------------- 2 files changed, 250 insertions(+), 393 deletions(-) diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua index f09e249..97d67c7 100644 --- a/runner/lua/sandbox.lua +++ b/runner/lua/sandbox.lua @@ -13,50 +13,50 @@ __EXIT_SENTINEL = {} -- Unique object for exit identification -- ====================================================================== function exit() - error(__EXIT_SENTINEL) + error(__EXIT_SENTINEL) end -- Create environment inheriting from _G function __create_env(ctx) - local env = setmetatable({}, {__index = _G}) + local env = setmetatable({}, {__index = _G}) - if ctx then - env.ctx = ctx - end + if ctx then + env.ctx = ctx + end - if __setup_require then - __setup_require(env) - end + if __setup_require then + __setup_require(env) + end - return env + return env end -- Execute script with clean environment function __execute_script(fn, ctx) - __http_response = nil + __http_response = nil - local env = __create_env(ctx) - env.exit = exit - setfenv(fn, env) + local env = __create_env(ctx) + env.exit = exit + setfenv(fn, env) - local ok, result = pcall(fn) - if not ok then - if result == __EXIT_SENTINEL then - return - end + local ok, result = pcall(fn) + if not ok then + if result == __EXIT_SENTINEL then + return + end - error(result, 0) - end + error(result, 0) + end - return result + return result end -- Ensure __http_response exists, then return it function __ensure_response() - if not __http_response then - __http_response = {} - end - return __http_response + if not __http_response then + __http_response = {} + end + return __http_response end -- ====================================================================== @@ -137,15 +137,15 @@ local http = { for k, v in pairs(params) do if type(v) == "table" then for _, item in ipairs(v) do - table.insert(query, k .. "=" .. tostring(item)) + table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(item))) end else - table.insert(query, k .. "=" .. tostring(v)) + table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(v))) end end if #query > 0 then - if base_url:find("?") then + if string.contains(base_url, "?") then return base_url .. "&" .. table.concat(query, "&") else return base_url .. "?" .. table.concat(query, "&") @@ -198,89 +198,89 @@ end -- ====================================================================== local cookie = { - -- Set a cookie - set = function(name, value, options) - if type(name) ~= "string" then - error("cookie.set: name must be a string", 2) - end + -- Set a cookie + set = function(name, value, options) + if type(name) ~= "string" then + error("cookie.set: name must be a string", 2) + end - local resp = __ensure_response() - resp.cookies = resp.cookies or {} + local resp = __ensure_response() + resp.cookies = resp.cookies or {} - local opts = options or {} - local cookie = { - name = name, - value = value or "", - path = opts.path or "/", - domain = opts.domain - } + local opts = options or {} + local cookie = { + name = name, + value = value or "", + path = opts.path or "/", + domain = opts.domain + } - if opts.expires then - if type(opts.expires) == "number" then - if opts.expires > 0 then - cookie.max_age = opts.expires - local now = os.time() - cookie.expires = now + opts.expires - elseif opts.expires < 0 then - cookie.expires = 1 - cookie.max_age = 0 - end - -- opts.expires == 0: Session cookie (omitting both expires and max-age) - end - end + if opts.expires then + if type(opts.expires) == "number" then + if opts.expires > 0 then + cookie.max_age = opts.expires + local now = os.time() + cookie.expires = now + opts.expires + elseif opts.expires < 0 then + cookie.expires = 1 + cookie.max_age = 0 + end + -- opts.expires == 0: Session cookie (omitting both expires and max-age) + end + end - cookie.secure = (opts.secure ~= false) - cookie.http_only = (opts.http_only ~= false) + cookie.secure = (opts.secure ~= false) + cookie.http_only = (opts.http_only ~= false) - if opts.same_site then - local valid_values = {none = true, lax = true, strict = true} - local same_site = string.lower(opts.same_site) + if opts.same_site then + local same_site = string.trim(opts.same_site):lower() + local valid_values = {none = true, lax = true, strict = true} - if not valid_values[same_site] then - error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) - end + if not valid_values[same_site] then + error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) + end - -- If SameSite=None, the cookie must be secure - if same_site == "none" and not cookie.secure then - cookie.secure = true - end + -- If SameSite=None, the cookie must be secure + if same_site == "none" and not cookie.secure then + cookie.secure = true + end - cookie.same_site = opts.same_site - else - cookie.same_site = "Lax" - end + cookie.same_site = opts.same_site + else + cookie.same_site = "Lax" + end - table.insert(resp.cookies, cookie) - return true - end, + table.insert(resp.cookies, cookie) + return true + end, - -- Get a cookie value - get = function(name) - if type(name) ~= "string" then - error("cookie.get: name must be a string", 2) - end + -- Get a cookie value + get = function(name) + if type(name) ~= "string" then + error("cookie.get: name must be a string", 2) + end - local env = getfenv(2) + local env = getfenv(2) - if env.ctx and env.ctx.cookies then - return env.ctx.cookies[name] - end + if env.ctx and env.ctx.cookies then + return env.ctx.cookies[name] + end - if env.ctx and env.ctx._request_cookies then - return env.ctx._request_cookies[name] - end + if env.ctx and env.ctx._request_cookies then + return env.ctx._request_cookies[name] + end - return nil - end, + return nil + end, - -- Remove a cookie - remove = function(name, path, domain) - if type(name) ~= "string" then - error("cookie.remove: name must be a string", 2) - end + -- Remove a cookie + remove = function(name, path, domain) + if type(name) ~= "string" then + error("cookie.remove: name must be a string", 2) + end - return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) - end + return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) + end } -- ====================================================================== @@ -311,7 +311,7 @@ local session = { end local resp = __ensure_response() - resp.session = resp.session or {} + resp.session = resp.session or {} resp.session[key] = value end, @@ -360,7 +360,6 @@ local session = { local resp = __ensure_response() resp.session = {} - resp.session["__clear_all"] = true end } @@ -370,51 +369,52 @@ local session = { -- ====================================================================== local csrf = { - generate = function() - local token = util.generate_token(32) - session.set("_csrf_token", token) - return token - end, + generate = function() + local token = util.generate_token(32) + session.set("_csrf_token", token) + return token + end, - field = function() - local token = session.get("_csrf_token") - if not token then - token = csrf.generate() - end - return string.format('', token) - end, + field = function() + local token = session.get("_csrf_token") + if not token then + token = csrf.generate() + end + return string.format('', + util.html_special_chars(token)) + end, - validate = function() + validate = function() local env = getfenv(2) - local token = false + local token = false if env.ctx and env.ctx.session and env.ctx.session.data then token = env.ctx.session.data["_csrf_token"] end - if not token then - http.set_status(403) - __http_response.body = "CSRF validation failed" - exit() - end + if not token then + http.set_status(403) + __http_response.body = "CSRF validation failed" + exit() + end - local request_token = nil - if env.ctx and env.ctx.form then - request_token = env.ctx.form._csrf_token - end + local request_token = nil + if env.ctx and env.ctx.form then + request_token = env.ctx.form._csrf_token + end - if not request_token and env.ctx and env.ctx._request_headers then - request_token = env.ctx._request_headers["x-csrf-token"] or - env.ctx._request_headers["csrf-token"] - end + if not request_token and env.ctx and env.ctx._request_headers then + request_token = env.ctx._request_headers["x-csrf-token"] or + env.ctx._request_headers["csrf-token"] + end - if not request_token or request_token ~= token then - http.set_status(403) - __http_response.body = "CSRF validation failed" - exit() - end + if not request_token or request_token ~= token then + http.set_status(403) + __http_response.body = "CSRF validation failed" + exit() + end - return true - end + return true + end } -- ====================================================================== @@ -423,11 +423,6 @@ local csrf = { -- Template processing with code execution _G.render = function(template_str, env) - local function escape_html(s) - local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''} - return (s:gsub([=[["><'&]]=], entities)) - end - local function get_line(s, ln) for line in s:gmatch("([^\n]*)\n?") do if ln == 1 then return line end @@ -516,7 +511,7 @@ _G.render = function(template_str, env) setfenv(fn, runtime_env) local output_buffer = {} - fn(tostring, escape_html, output_buffer, 0) + fn(tostring, util.html_special_chars, output_buffer, 0) return table.concat(output_buffer) end @@ -550,8 +545,7 @@ _G.parse = function(template_str, env) local value = env[name] local str = tostring(value or "") if escaped then - local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''} - str = str:gsub([=[["><'&]]=], entities) + str = util.html_special_chars(str) end table.insert(output, str) @@ -591,8 +585,7 @@ _G.iparse = function(template_str, values) local value = values[value_index] local str = tostring(value or "") if escaped then - local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''} - str = str:gsub([=[["><'&]]=], entities) + str = util.html_special_chars(str) end table.insert(output, str) diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua index ccfe458..0a24190 100644 --- a/runner/lua/sqlite.lua +++ b/runner/lua/sqlite.lua @@ -1,70 +1,37 @@ --- Simplified SQLite wrapper --- Connection is now lightweight with persistent connection tracking - --- Helper function to handle parameters -local function handle_params(params, ...) - -- If params is a table, use it for named parameters - if type(params) == "table" then - return params - end - - -- If we have varargs, collect them for positional parameters +local function normalize_params(params, ...) + if type(params) == "table" then return params end local args = {...} if #args > 0 or params ~= nil then - -- Include the first param in the args table.insert(args, 1, params) return args end - return nil end --- Connection metatable local connection_mt = { __index = { - -- Execute a query and return results as a table query = function(self, query, params, ...) if type(query) ~= "string" then error("connection:query: query must be a string", 2) end - -- Execute with proper connection tracking - local results, token - if params == nil and select('#', ...) == 0 then - results, token = __sqlite_query(self.db_name, query, nil, self.conn_token) - elseif type(params) == "table" then - results, token = __sqlite_query(self.db_name, query, params, self.conn_token) - else - local args = {params, ...} - results, token = __sqlite_query(self.db_name, query, args, self.conn_token) - end - + local normalized_params = normalize_params(params, ...) + local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token) self.conn_token = token return results end, - -- Execute a statement and return affected rows exec = function(self, query, params, ...) if type(query) ~= "string" then error("connection:exec: query must be a string", 2) end - -- Execute with proper connection tracking - local affected, token - if params == nil and select('#', ...) == 0 then - affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token) - elseif type(params) == "table" then - affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token) - else - local args = {params, ...} - affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token) - end - + local normalized_params = normalize_params(params, ...) + local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token) self.conn_token = token return affected end, - -- Close the connection (release back to pool) close = function(self) if self.conn_token then local success = __sqlite_close(self.conn_token) @@ -74,50 +41,61 @@ local connection_mt = { return false end, - -- Insert a row or multiple rows with a single query insert = function(self, table_name, data, columns) if type(data) ~= "table" then error("connection:insert: data must be a table", 2) end - -- Case 1: Named columns with array data - if columns and type(columns) == "table" then - -- Check if we have multiple rows - if #data > 0 and type(data[1]) == "table" then - -- Build a single multi-value INSERT - local placeholders = {} - local values = {} - local params = {} - local param_index = 1 + -- Single object: {col1=val1, col2=val2} + if data[1] == nil and next(data) ~= nil then + local cols = table.keys(data) + local placeholders = table.map(cols, function(_, i) return ":p" .. i end) + local params = {} + for i, col in ipairs(cols) do + params["p" .. i] = data[col] + end - for i, row in ipairs(data) do + local query = string.format( + "INSERT INTO %s (%s) VALUES (%s)", + table_name, + table.concat(cols, ", "), + table.concat(placeholders, ", ") + ) + return self:exec(query, params) + end + + -- Array data with columns + if columns and type(columns) == "table" then + if #data > 0 and type(data[1]) == "table" then + -- Multiple rows + local value_groups = {} + local params = {} + local param_idx = 1 + + for _, row in ipairs(data) do local row_placeholders = {} - for j, _ in ipairs(columns) do - local param_name = "p" .. param_index + for j = 1, #columns do + local param_name = "p" .. param_idx table.insert(row_placeholders, ":" .. param_name) params[param_name] = row[j] - param_index = param_index + 1 + param_idx = param_idx + 1 end - table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")") + table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")") end local query = string.format( "INSERT INTO %s (%s) VALUES %s", table_name, table.concat(columns, ", "), - table.concat(placeholders, ", ") + table.concat(value_groups, ", ") ) - return self:exec(query, params) else - -- Single row with defined columns - local placeholders = {} + -- Single row array + local placeholders = table.map(columns, function(_, i) return ":p" .. i end) local params = {} - - for i, col in ipairs(columns) do - local param_name = "p" .. i - table.insert(placeholders, ":" .. param_name) - params[param_name] = data[i] + for i = 1, #columns do + params["p" .. i] = data[i] end local query = string.format( @@ -126,161 +104,71 @@ local connection_mt = { table.concat(columns, ", "), table.concat(placeholders, ", ") ) - return self:exec(query, params) end end - -- Case 2: Object-style single row {col1=val1, col2=val2} - if data[1] == nil and next(data) ~= nil then - local columns = {} - local placeholders = {} + -- Array of objects + if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then + local cols = table.keys(data[1]) + local value_groups = {} local params = {} + local param_idx = 1 - for col, val in pairs(data) do - table.insert(columns, col) - local param_name = "p" .. #columns - table.insert(placeholders, ":" .. param_name) - params[param_name] = val + for _, row in ipairs(data) do + local row_placeholders = {} + for _, col in ipairs(cols) do + local param_name = "p" .. param_idx + table.insert(row_placeholders, ":" .. param_name) + params[param_name] = row[col] + param_idx = param_idx + 1 + end + table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")") end local query = string.format( - "INSERT INTO %s (%s) VALUES (%s)", + "INSERT INTO %s (%s) VALUES %s", table_name, - table.concat(columns, ", "), - table.concat(placeholders, ", ") + table.concat(cols, ", "), + table.concat(value_groups, ", ") ) - return self:exec(query, params) end - -- Case 3: Array of rows without predefined columns - if #data > 0 and type(data[1]) == "table" then - -- Extract columns from the first row - local first_row = data[1] - local inferred_columns = {} - - -- Determine if first row is array or object - local is_array = first_row[1] ~= nil - - if is_array then - -- Cannot infer column names from array - error("connection:insert: column names required for array data", 2) - else - -- Get columns from object keys - for col, _ in pairs(first_row) do - table.insert(inferred_columns, col) - end - - -- Build multi-value INSERT - local placeholders = {} - local params = {} - local param_index = 1 - - for _, row in ipairs(data) do - local row_placeholders = {} - for _, col in ipairs(inferred_columns) do - local param_name = "p" .. param_index - table.insert(row_placeholders, ":" .. param_name) - params[param_name] = row[col] - param_index = param_index + 1 - end - table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")") - end - - local query = string.format( - "INSERT INTO %s (%s) VALUES %s", - table_name, - table.concat(inferred_columns, ", "), - table.concat(placeholders, ", ") - ) - - return self:exec(query, params) - end - end - error("connection:insert: invalid data format", 2) end, - -- Update rows in a table update = function(self, table_name, data, where, where_params, ...) - if type(data) ~= "table" then - error("connection:update: data must be a table", 2) - end - - -- Fast path for when there's no data - if next(data) == nil then + if type(data) ~= "table" or next(data) == nil then return 0 end local sets = {} local params = {} - local param_index = 1 + local param_idx = 1 for col, val in pairs(data) do - local param_name = "p" .. param_index + local param_name = "p" .. param_idx table.insert(sets, col .. " = :" .. param_name) params[param_name] = val - param_index = param_index + 1 + param_idx = param_idx + 1 end - local query = string.format( - "UPDATE %s SET %s", - table_name, - table.concat(sets, ", ") - ) + local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", ")) if where then query = query .. " WHERE " .. where - if where_params then - if type(where_params) == "table" then - -- Handle named parameters in WHERE clause - for k, v in pairs(where_params) do - local param_name - if type(k) == "string" and k:sub(1, 1) == ":" then - param_name = k:sub(2) + local normalized = normalize_params(where_params, ...) + if type(normalized) == "table" then + for k, v in pairs(normalized) do + if type(k) == "string" then + params[k] = v else - param_name = "w" .. param_index - -- Replace the placeholder in the WHERE clause - where = where:gsub(":" .. k, ":" .. param_name) + params["w" .. param_idx] = v + param_idx = param_idx + 1 end - params[param_name] = v - param_index = param_index + 1 end - else - -- Handle positional parameters (? placeholders) - local args = {where_params, ...} - local pos = 1 - local offset = 0 - - -- Replace ? with named parameters - while true do - local start_pos, end_pos = where:find("?", pos) - if not start_pos then break end - - local param_name = "w" .. param_index - local replacement = ":" .. param_name - - where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1) - - if args[pos - offset] ~= nil then - params[param_name] = args[pos - offset] - else - params[param_name] = nil - end - - param_index = param_index + 1 - pos = start_pos + #replacement - offset = offset + 1 - end - - query = string.format( - "UPDATE %s SET %s WHERE %s", - table_name, - table.concat(sets, ", "), - where - ) end end end @@ -288,129 +176,108 @@ local connection_mt = { return self:exec(query, params) end, - -- Create a new table create_table = function(self, table_name, ...) - local columns = {} - local indices = {} + local column_definitions = {} + local index_definitions = {} - -- Process all arguments - for _, def in ipairs({...}) do - if type(def) == "string" then - -- Check if it's an index definition - local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)") + for _, def_string in ipairs({...}) do + if type(def_string) == "string" then + local is_unique = false + local index_def = def_string - if index_def then - -- Parse index definition - local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)") - - if index_name and columns_str then - -- Split columns by comma - local index_columns = {} - for col in columns_str:gmatch("[^,]+") do - table.insert(index_columns, col:match("^%s*(.-)%s*$")) -- Trim whitespace - end - - table.insert(indices, { - name = index_name, - columns = index_columns, - unique = (index_type == "UNIQUE INDEX:") - }) - end + if string.starts_with(def_string, "UNIQUE INDEX:") then + is_unique = true + index_def = string.trim(def_string:sub(14)) + elseif string.starts_with(def_string, "INDEX:") then + index_def = string.trim(def_string:sub(7)) else - -- Regular column definition - table.insert(columns, def) + table.insert(column_definitions, def_string) + goto continue + end + + local paren_pos = index_def:find("%(") + if not paren_pos then goto continue end + + local index_name = string.trim(index_def:sub(1, paren_pos - 1)) + local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$") + if not columns_part then goto continue end + + local columns = table.map(string.split(columns_part, ","), string.trim) + + if #columns > 0 then + table.insert(index_definitions, { + name = index_name, + columns = columns, + unique = is_unique + }) end end + ::continue:: end - if #columns == 0 then - error("connection:create_table: no columns specified", 2) + if #column_definitions == 0 then + error("connection:create_table: no column definitions specified for table " .. table_name, 2) end - -- Build combined statement for table and indices local statements = {} - -- Add the CREATE TABLE statement table.insert(statements, string.format( "CREATE TABLE IF NOT EXISTS %s (%s)", table_name, - table.concat(columns, ", ") + table.concat(column_definitions, ", ") )) - -- Add CREATE INDEX statements - for _, idx in ipairs(indices) do - local unique = idx.unique and "UNIQUE " or "" - + for _, idx in ipairs(index_definitions) do + local unique_prefix = idx.unique and "UNIQUE " or "" table.insert(statements, string.format( "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", - unique, + unique_prefix, idx.name, table_name, table.concat(idx.columns, ", ") )) end - -- Execute all statements in a single transaction - local combined_sql = table.concat(statements, ";\n") - return self:exec(combined_sql) + return self:exec(table.concat(statements, ";\n")) end, - -- Delete rows - delete = function(self, table_name, where, params) + delete = function(self, table_name, where, params, ...) local query = "DELETE FROM " .. table_name - if where then query = query .. " WHERE " .. where end - - return self:exec(query, params) + return self:exec(query, normalize_params(params, ...)) end, - -- Get one row efficiently get_one = function(self, query, params, ...) if type(query) ~= "string" then error("connection:get_one: query must be a string", 2) end - -- Add LIMIT 1 to query if not already limited local limited_query = query - if not query:lower():match("limit%s+%d+") then + if not string.contains(query:lower(), "limit") then limited_query = query .. " LIMIT 1" end - local results - if select('#', ...) > 0 then - results = self:query(limited_query, params, ...) - else - results = self:query(limited_query, params) - end - + local results = self:query(limited_query, normalize_params(params, ...)) return results[1] end, - -- Begin transaction begin = function(self) return self:exec("BEGIN TRANSACTION") end, - -- Commit transaction commit = function(self) return self:exec("COMMIT") end, - -- Rollback transaction rollback = function(self) return self:exec("ROLLBACK") end, - -- Transaction wrapper function transaction = function(self, callback) self:begin() - - local success, result = pcall(function() - return callback(self) - end) - + local success, result = pcall(callback, self) if success then self:commit() return result @@ -422,16 +289,13 @@ local connection_mt = { } } --- Create sqlite() function that returns a connection object return function(db_name) if type(db_name) ~= "string" then error("sqlite: database name must be a string", 2) end - local conn = { + return setmetatable({ db_name = db_name, - conn_token = nil -- Will be populated on first query/exec - } - - return setmetatable(conn, connection_mt) + conn_token = nil + }, connection_mt) end