update sqlite wrapper and sandbox with utils

This commit is contained in:
Sky Johnson 2025-05-23 17:10:12 -05:00
parent 5913bc4ba3
commit ca85e735b0
2 changed files with 250 additions and 393 deletions

View File

@ -13,50 +13,50 @@ __EXIT_SENTINEL = {} -- Unique object for exit identification
-- ====================================================================== -- ======================================================================
function exit() function exit()
error(__EXIT_SENTINEL) error(__EXIT_SENTINEL)
end end
-- Create environment inheriting from _G -- Create environment inheriting from _G
function __create_env(ctx) function __create_env(ctx)
local env = setmetatable({}, {__index = _G}) local env = setmetatable({}, {__index = _G})
if ctx then if ctx then
env.ctx = ctx env.ctx = ctx
end end
if __setup_require then if __setup_require then
__setup_require(env) __setup_require(env)
end end
return env return env
end end
-- Execute script with clean environment -- Execute script with clean environment
function __execute_script(fn, ctx) function __execute_script(fn, ctx)
__http_response = nil __http_response = nil
local env = __create_env(ctx) local env = __create_env(ctx)
env.exit = exit env.exit = exit
setfenv(fn, env) setfenv(fn, env)
local ok, result = pcall(fn) local ok, result = pcall(fn)
if not ok then if not ok then
if result == __EXIT_SENTINEL then if result == __EXIT_SENTINEL then
return return
end end
error(result, 0) error(result, 0)
end end
return result return result
end end
-- Ensure __http_response exists, then return it -- Ensure __http_response exists, then return it
function __ensure_response() function __ensure_response()
if not __http_response then if not __http_response then
__http_response = {} __http_response = {}
end end
return __http_response return __http_response
end end
-- ====================================================================== -- ======================================================================
@ -137,15 +137,15 @@ local http = {
for k, v in pairs(params) do for k, v in pairs(params) do
if type(v) == "table" then if type(v) == "table" then
for _, item in ipairs(v) do for _, item in ipairs(v) do
table.insert(query, k .. "=" .. tostring(item)) table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(item)))
end end
else else
table.insert(query, k .. "=" .. tostring(v)) table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(v)))
end end
end end
if #query > 0 then if #query > 0 then
if base_url:find("?") then if string.contains(base_url, "?") then
return base_url .. "&" .. table.concat(query, "&") return base_url .. "&" .. table.concat(query, "&")
else else
return base_url .. "?" .. table.concat(query, "&") return base_url .. "?" .. table.concat(query, "&")
@ -198,89 +198,89 @@ end
-- ====================================================================== -- ======================================================================
local cookie = { local cookie = {
-- Set a cookie -- Set a cookie
set = function(name, value, options) set = function(name, value, options)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.set: name must be a string", 2) error("cookie.set: name must be a string", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
resp.cookies = resp.cookies or {} resp.cookies = resp.cookies or {}
local opts = options or {} local opts = options or {}
local cookie = { local cookie = {
name = name, name = name,
value = value or "", value = value or "",
path = opts.path or "/", path = opts.path or "/",
domain = opts.domain domain = opts.domain
} }
if opts.expires then if opts.expires then
if type(opts.expires) == "number" then if type(opts.expires) == "number" then
if opts.expires > 0 then if opts.expires > 0 then
cookie.max_age = opts.expires cookie.max_age = opts.expires
local now = os.time() local now = os.time()
cookie.expires = now + opts.expires cookie.expires = now + opts.expires
elseif opts.expires < 0 then elseif opts.expires < 0 then
cookie.expires = 1 cookie.expires = 1
cookie.max_age = 0 cookie.max_age = 0
end end
-- opts.expires == 0: Session cookie (omitting both expires and max-age) -- opts.expires == 0: Session cookie (omitting both expires and max-age)
end end
end end
cookie.secure = (opts.secure ~= false) cookie.secure = (opts.secure ~= false)
cookie.http_only = (opts.http_only ~= false) cookie.http_only = (opts.http_only ~= false)
if opts.same_site then if opts.same_site then
local valid_values = {none = true, lax = true, strict = true} local same_site = string.trim(opts.same_site):lower()
local same_site = string.lower(opts.same_site) local valid_values = {none = true, lax = true, strict = true}
if not valid_values[same_site] then if not valid_values[same_site] then
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
end end
-- If SameSite=None, the cookie must be secure -- If SameSite=None, the cookie must be secure
if same_site == "none" and not cookie.secure then if same_site == "none" and not cookie.secure then
cookie.secure = true cookie.secure = true
end end
cookie.same_site = opts.same_site cookie.same_site = opts.same_site
else else
cookie.same_site = "Lax" cookie.same_site = "Lax"
end end
table.insert(resp.cookies, cookie) table.insert(resp.cookies, cookie)
return true return true
end, end,
-- Get a cookie value -- Get a cookie value
get = function(name) get = function(name)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.get: name must be a string", 2) error("cookie.get: name must be a string", 2)
end end
local env = getfenv(2) local env = getfenv(2)
if env.ctx and env.ctx.cookies then if env.ctx and env.ctx.cookies then
return env.ctx.cookies[name] return env.ctx.cookies[name]
end end
if env.ctx and env.ctx._request_cookies then if env.ctx and env.ctx._request_cookies then
return env.ctx._request_cookies[name] return env.ctx._request_cookies[name]
end end
return nil return nil
end, end,
-- Remove a cookie -- Remove a cookie
remove = function(name, path, domain) remove = function(name, path, domain)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.remove: name must be a string", 2) error("cookie.remove: name must be a string", 2)
end end
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
end end
} }
-- ====================================================================== -- ======================================================================
@ -311,7 +311,7 @@ local session = {
end end
local resp = __ensure_response() local resp = __ensure_response()
resp.session = resp.session or {} resp.session = resp.session or {}
resp.session[key] = value resp.session[key] = value
end, end,
@ -360,7 +360,6 @@ local session = {
local resp = __ensure_response() local resp = __ensure_response()
resp.session = {} resp.session = {}
resp.session["__clear_all"] = true resp.session["__clear_all"] = true
end end
} }
@ -370,51 +369,52 @@ local session = {
-- ====================================================================== -- ======================================================================
local csrf = { local csrf = {
generate = function() generate = function()
local token = util.generate_token(32) local token = util.generate_token(32)
session.set("_csrf_token", token) session.set("_csrf_token", token)
return token return token
end, end,
field = function() field = function()
local token = session.get("_csrf_token") local token = session.get("_csrf_token")
if not token then if not token then
token = csrf.generate() token = csrf.generate()
end end
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', token) return string.format('<input type="hidden" name="_csrf_token" value="%s" />',
end, util.html_special_chars(token))
end,
validate = function() validate = function()
local env = getfenv(2) local env = getfenv(2)
local token = false local token = false
if env.ctx and env.ctx.session and env.ctx.session.data then if env.ctx and env.ctx.session and env.ctx.session.data then
token = env.ctx.session.data["_csrf_token"] token = env.ctx.session.data["_csrf_token"]
end end
if not token then if not token then
http.set_status(403) http.set_status(403)
__http_response.body = "CSRF validation failed" __http_response.body = "CSRF validation failed"
exit() exit()
end end
local request_token = nil local request_token = nil
if env.ctx and env.ctx.form then if env.ctx and env.ctx.form then
request_token = env.ctx.form._csrf_token request_token = env.ctx.form._csrf_token
end end
if not request_token and env.ctx and env.ctx._request_headers then if not request_token and env.ctx and env.ctx._request_headers then
request_token = env.ctx._request_headers["x-csrf-token"] or request_token = env.ctx._request_headers["x-csrf-token"] or
env.ctx._request_headers["csrf-token"] env.ctx._request_headers["csrf-token"]
end end
if not request_token or request_token ~= token then if not request_token or request_token ~= token then
http.set_status(403) http.set_status(403)
__http_response.body = "CSRF validation failed" __http_response.body = "CSRF validation failed"
exit() exit()
end end
return true return true
end end
} }
-- ====================================================================== -- ======================================================================
@ -423,11 +423,6 @@ local csrf = {
-- Template processing with code execution -- Template processing with code execution
_G.render = function(template_str, env) _G.render = function(template_str, env)
local function escape_html(s)
local entities = {['&']='&amp;', ['<']='&lt;', ['>']='&gt;', ['"']='&quot;', ["'"]='&#039;'}
return (s:gsub([=[["><'&]]=], entities))
end
local function get_line(s, ln) local function get_line(s, ln)
for line in s:gmatch("([^\n]*)\n?") do for line in s:gmatch("([^\n]*)\n?") do
if ln == 1 then return line end if ln == 1 then return line end
@ -516,7 +511,7 @@ _G.render = function(template_str, env)
setfenv(fn, runtime_env) setfenv(fn, runtime_env)
local output_buffer = {} local output_buffer = {}
fn(tostring, escape_html, output_buffer, 0) fn(tostring, util.html_special_chars, output_buffer, 0)
return table.concat(output_buffer) return table.concat(output_buffer)
end end
@ -550,8 +545,7 @@ _G.parse = function(template_str, env)
local value = env[name] local value = env[name]
local str = tostring(value or "") local str = tostring(value or "")
if escaped then if escaped then
local entities = {['&']='&amp;', ['<']='&lt;', ['>']='&gt;', ['"']='&quot;', ["'"]='&#039;'} str = util.html_special_chars(str)
str = str:gsub([=[["><'&]]=], entities)
end end
table.insert(output, str) table.insert(output, str)
@ -591,8 +585,7 @@ _G.iparse = function(template_str, values)
local value = values[value_index] local value = values[value_index]
local str = tostring(value or "") local str = tostring(value or "")
if escaped then if escaped then
local entities = {['&']='&amp;', ['<']='&lt;', ['>']='&gt;', ['"']='&quot;', ["'"]='&#039;'} str = util.html_special_chars(str)
str = str:gsub([=[["><'&]]=], entities)
end end
table.insert(output, str) table.insert(output, str)

View File

@ -1,70 +1,37 @@
-- Simplified SQLite wrapper local function normalize_params(params, ...)
-- Connection is now lightweight with persistent connection tracking if type(params) == "table" then return params end
-- Helper function to handle parameters
local function handle_params(params, ...)
-- If params is a table, use it for named parameters
if type(params) == "table" then
return params
end
-- If we have varargs, collect them for positional parameters
local args = {...} local args = {...}
if #args > 0 or params ~= nil then if #args > 0 or params ~= nil then
-- Include the first param in the args
table.insert(args, 1, params) table.insert(args, 1, params)
return args return args
end end
return nil return nil
end end
-- Connection metatable
local connection_mt = { local connection_mt = {
__index = { __index = {
-- Execute a query and return results as a table
query = function(self, query, params, ...) query = function(self, query, params, ...)
if type(query) ~= "string" then if type(query) ~= "string" then
error("connection:query: query must be a string", 2) error("connection:query: query must be a string", 2)
end end
-- Execute with proper connection tracking local normalized_params = normalize_params(params, ...)
local results, token local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token)
if params == nil and select('#', ...) == 0 then
results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
elseif type(params) == "table" then
results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
else
local args = {params, ...}
results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
end
self.conn_token = token self.conn_token = token
return results return results
end, end,
-- Execute a statement and return affected rows
exec = function(self, query, params, ...) exec = function(self, query, params, ...)
if type(query) ~= "string" then if type(query) ~= "string" then
error("connection:exec: query must be a string", 2) error("connection:exec: query must be a string", 2)
end end
-- Execute with proper connection tracking local normalized_params = normalize_params(params, ...)
local affected, token local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token)
if params == nil and select('#', ...) == 0 then
affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
elseif type(params) == "table" then
affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
else
local args = {params, ...}
affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
end
self.conn_token = token self.conn_token = token
return affected return affected
end, end,
-- Close the connection (release back to pool)
close = function(self) close = function(self)
if self.conn_token then if self.conn_token then
local success = __sqlite_close(self.conn_token) local success = __sqlite_close(self.conn_token)
@ -74,50 +41,61 @@ local connection_mt = {
return false return false
end, end,
-- Insert a row or multiple rows with a single query
insert = function(self, table_name, data, columns) insert = function(self, table_name, data, columns)
if type(data) ~= "table" then if type(data) ~= "table" then
error("connection:insert: data must be a table", 2) error("connection:insert: data must be a table", 2)
end end
-- Case 1: Named columns with array data -- Single object: {col1=val1, col2=val2}
if columns and type(columns) == "table" then if data[1] == nil and next(data) ~= nil then
-- Check if we have multiple rows local cols = table.keys(data)
if #data > 0 and type(data[1]) == "table" then local placeholders = table.map(cols, function(_, i) return ":p" .. i end)
-- Build a single multi-value INSERT local params = {}
local placeholders = {} for i, col in ipairs(cols) do
local values = {} params["p" .. i] = data[col]
local params = {} end
local param_index = 1
for i, row in ipairs(data) do local query = string.format(
"INSERT INTO %s (%s) VALUES (%s)",
table_name,
table.concat(cols, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
end
-- Array data with columns
if columns and type(columns) == "table" then
if #data > 0 and type(data[1]) == "table" then
-- Multiple rows
local value_groups = {}
local params = {}
local param_idx = 1
for _, row in ipairs(data) do
local row_placeholders = {} local row_placeholders = {}
for j, _ in ipairs(columns) do for j = 1, #columns do
local param_name = "p" .. param_index local param_name = "p" .. param_idx
table.insert(row_placeholders, ":" .. param_name) table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[j] params[param_name] = row[j]
param_index = param_index + 1 param_idx = param_idx + 1
end end
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")") table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
end end
local query = string.format( local query = string.format(
"INSERT INTO %s (%s) VALUES %s", "INSERT INTO %s (%s) VALUES %s",
table_name, table_name,
table.concat(columns, ", "), table.concat(columns, ", "),
table.concat(placeholders, ", ") table.concat(value_groups, ", ")
) )
return self:exec(query, params) return self:exec(query, params)
else else
-- Single row with defined columns -- Single row array
local placeholders = {} local placeholders = table.map(columns, function(_, i) return ":p" .. i end)
local params = {} local params = {}
for i = 1, #columns do
for i, col in ipairs(columns) do params["p" .. i] = data[i]
local param_name = "p" .. i
table.insert(placeholders, ":" .. param_name)
params[param_name] = data[i]
end end
local query = string.format( local query = string.format(
@ -126,161 +104,71 @@ local connection_mt = {
table.concat(columns, ", "), table.concat(columns, ", "),
table.concat(placeholders, ", ") table.concat(placeholders, ", ")
) )
return self:exec(query, params) return self:exec(query, params)
end end
end end
-- Case 2: Object-style single row {col1=val1, col2=val2} -- Array of objects
if data[1] == nil and next(data) ~= nil then if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
local columns = {} local cols = table.keys(data[1])
local placeholders = {} local value_groups = {}
local params = {} local params = {}
local param_idx = 1
for col, val in pairs(data) do for _, row in ipairs(data) do
table.insert(columns, col) local row_placeholders = {}
local param_name = "p" .. #columns for _, col in ipairs(cols) do
table.insert(placeholders, ":" .. param_name) local param_name = "p" .. param_idx
params[param_name] = val table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[col]
param_idx = param_idx + 1
end
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
end end
local query = string.format( local query = string.format(
"INSERT INTO %s (%s) VALUES (%s)", "INSERT INTO %s (%s) VALUES %s",
table_name, table_name,
table.concat(columns, ", "), table.concat(cols, ", "),
table.concat(placeholders, ", ") table.concat(value_groups, ", ")
) )
return self:exec(query, params) return self:exec(query, params)
end end
-- Case 3: Array of rows without predefined columns
if #data > 0 and type(data[1]) == "table" then
-- Extract columns from the first row
local first_row = data[1]
local inferred_columns = {}
-- Determine if first row is array or object
local is_array = first_row[1] ~= nil
if is_array then
-- Cannot infer column names from array
error("connection:insert: column names required for array data", 2)
else
-- Get columns from object keys
for col, _ in pairs(first_row) do
table.insert(inferred_columns, col)
end
-- Build multi-value INSERT
local placeholders = {}
local params = {}
local param_index = 1
for _, row in ipairs(data) do
local row_placeholders = {}
for _, col in ipairs(inferred_columns) do
local param_name = "p" .. param_index
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[col]
param_index = param_index + 1
end
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(inferred_columns, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
end
end
error("connection:insert: invalid data format", 2) error("connection:insert: invalid data format", 2)
end, end,
-- Update rows in a table
update = function(self, table_name, data, where, where_params, ...) update = function(self, table_name, data, where, where_params, ...)
if type(data) ~= "table" then if type(data) ~= "table" or next(data) == nil then
error("connection:update: data must be a table", 2)
end
-- Fast path for when there's no data
if next(data) == nil then
return 0 return 0
end end
local sets = {} local sets = {}
local params = {} local params = {}
local param_index = 1 local param_idx = 1
for col, val in pairs(data) do for col, val in pairs(data) do
local param_name = "p" .. param_index local param_name = "p" .. param_idx
table.insert(sets, col .. " = :" .. param_name) table.insert(sets, col .. " = :" .. param_name)
params[param_name] = val params[param_name] = val
param_index = param_index + 1 param_idx = param_idx + 1
end end
local query = string.format( local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
"UPDATE %s SET %s",
table_name,
table.concat(sets, ", ")
)
if where then if where then
query = query .. " WHERE " .. where query = query .. " WHERE " .. where
if where_params then if where_params then
if type(where_params) == "table" then local normalized = normalize_params(where_params, ...)
-- Handle named parameters in WHERE clause if type(normalized) == "table" then
for k, v in pairs(where_params) do for k, v in pairs(normalized) do
local param_name if type(k) == "string" then
if type(k) == "string" and k:sub(1, 1) == ":" then params[k] = v
param_name = k:sub(2)
else else
param_name = "w" .. param_index params["w" .. param_idx] = v
-- Replace the placeholder in the WHERE clause param_idx = param_idx + 1
where = where:gsub(":" .. k, ":" .. param_name)
end end
params[param_name] = v
param_index = param_index + 1
end end
else
-- Handle positional parameters (? placeholders)
local args = {where_params, ...}
local pos = 1
local offset = 0
-- Replace ? with named parameters
while true do
local start_pos, end_pos = where:find("?", pos)
if not start_pos then break end
local param_name = "w" .. param_index
local replacement = ":" .. param_name
where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1)
if args[pos - offset] ~= nil then
params[param_name] = args[pos - offset]
else
params[param_name] = nil
end
param_index = param_index + 1
pos = start_pos + #replacement
offset = offset + 1
end
query = string.format(
"UPDATE %s SET %s WHERE %s",
table_name,
table.concat(sets, ", "),
where
)
end end
end end
end end
@ -288,129 +176,108 @@ local connection_mt = {
return self:exec(query, params) return self:exec(query, params)
end, end,
-- Create a new table
create_table = function(self, table_name, ...) create_table = function(self, table_name, ...)
local columns = {} local column_definitions = {}
local indices = {} local index_definitions = {}
-- Process all arguments for _, def_string in ipairs({...}) do
for _, def in ipairs({...}) do if type(def_string) == "string" then
if type(def) == "string" then local is_unique = false
-- Check if it's an index definition local index_def = def_string
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
if index_def then if string.starts_with(def_string, "UNIQUE INDEX:") then
-- Parse index definition is_unique = true
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)") index_def = string.trim(def_string:sub(14))
elseif string.starts_with(def_string, "INDEX:") then
if index_name and columns_str then index_def = string.trim(def_string:sub(7))
-- Split columns by comma
local index_columns = {}
for col in columns_str:gmatch("[^,]+") do
table.insert(index_columns, col:match("^%s*(.-)%s*$")) -- Trim whitespace
end
table.insert(indices, {
name = index_name,
columns = index_columns,
unique = (index_type == "UNIQUE INDEX:")
})
end
else else
-- Regular column definition table.insert(column_definitions, def_string)
table.insert(columns, def) goto continue
end
local paren_pos = index_def:find("%(")
if not paren_pos then goto continue end
local index_name = string.trim(index_def:sub(1, paren_pos - 1))
local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$")
if not columns_part then goto continue end
local columns = table.map(string.split(columns_part, ","), string.trim)
if #columns > 0 then
table.insert(index_definitions, {
name = index_name,
columns = columns,
unique = is_unique
})
end end
end end
::continue::
end end
if #columns == 0 then if #column_definitions == 0 then
error("connection:create_table: no columns specified", 2) error("connection:create_table: no column definitions specified for table " .. table_name, 2)
end end
-- Build combined statement for table and indices
local statements = {} local statements = {}
-- Add the CREATE TABLE statement
table.insert(statements, string.format( table.insert(statements, string.format(
"CREATE TABLE IF NOT EXISTS %s (%s)", "CREATE TABLE IF NOT EXISTS %s (%s)",
table_name, table_name,
table.concat(columns, ", ") table.concat(column_definitions, ", ")
)) ))
-- Add CREATE INDEX statements for _, idx in ipairs(index_definitions) do
for _, idx in ipairs(indices) do local unique_prefix = idx.unique and "UNIQUE " or ""
local unique = idx.unique and "UNIQUE " or ""
table.insert(statements, string.format( table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique, unique_prefix,
idx.name, idx.name,
table_name, table_name,
table.concat(idx.columns, ", ") table.concat(idx.columns, ", ")
)) ))
end end
-- Execute all statements in a single transaction return self:exec(table.concat(statements, ";\n"))
local combined_sql = table.concat(statements, ";\n")
return self:exec(combined_sql)
end, end,
-- Delete rows delete = function(self, table_name, where, params, ...)
delete = function(self, table_name, where, params)
local query = "DELETE FROM " .. table_name local query = "DELETE FROM " .. table_name
if where then if where then
query = query .. " WHERE " .. where query = query .. " WHERE " .. where
end end
return self:exec(query, normalize_params(params, ...))
return self:exec(query, params)
end, end,
-- Get one row efficiently
get_one = function(self, query, params, ...) get_one = function(self, query, params, ...)
if type(query) ~= "string" then if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2) error("connection:get_one: query must be a string", 2)
end end
-- Add LIMIT 1 to query if not already limited
local limited_query = query local limited_query = query
if not query:lower():match("limit%s+%d+") then if not string.contains(query:lower(), "limit") then
limited_query = query .. " LIMIT 1" limited_query = query .. " LIMIT 1"
end end
local results local results = self:query(limited_query, normalize_params(params, ...))
if select('#', ...) > 0 then
results = self:query(limited_query, params, ...)
else
results = self:query(limited_query, params)
end
return results[1] return results[1]
end, end,
-- Begin transaction
begin = function(self) begin = function(self)
return self:exec("BEGIN TRANSACTION") return self:exec("BEGIN TRANSACTION")
end, end,
-- Commit transaction
commit = function(self) commit = function(self)
return self:exec("COMMIT") return self:exec("COMMIT")
end, end,
-- Rollback transaction
rollback = function(self) rollback = function(self)
return self:exec("ROLLBACK") return self:exec("ROLLBACK")
end, end,
-- Transaction wrapper function
transaction = function(self, callback) transaction = function(self, callback)
self:begin() self:begin()
local success, result = pcall(callback, self)
local success, result = pcall(function()
return callback(self)
end)
if success then if success then
self:commit() self:commit()
return result return result
@ -422,16 +289,13 @@ local connection_mt = {
} }
} }
-- Create sqlite() function that returns a connection object
return function(db_name) return function(db_name)
if type(db_name) ~= "string" then if type(db_name) ~= "string" then
error("sqlite: database name must be a string", 2) error("sqlite: database name must be a string", 2)
end end
local conn = { return setmetatable({
db_name = db_name, db_name = db_name,
conn_token = nil -- Will be populated on first query/exec conn_token = nil
} }, connection_mt)
return setmetatable(conn, connection_mt)
end end