diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua
index f09e249..97d67c7 100644
--- a/runner/lua/sandbox.lua
+++ b/runner/lua/sandbox.lua
@@ -13,50 +13,50 @@ __EXIT_SENTINEL = {} -- Unique object for exit identification
-- ======================================================================
function exit()
- error(__EXIT_SENTINEL)
+ error(__EXIT_SENTINEL)
end
-- Create environment inheriting from _G
function __create_env(ctx)
- local env = setmetatable({}, {__index = _G})
+ local env = setmetatable({}, {__index = _G})
- if ctx then
- env.ctx = ctx
- end
+ if ctx then
+ env.ctx = ctx
+ end
- if __setup_require then
- __setup_require(env)
- end
+ if __setup_require then
+ __setup_require(env)
+ end
- return env
+ return env
end
-- Execute script with clean environment
function __execute_script(fn, ctx)
- __http_response = nil
+ __http_response = nil
- local env = __create_env(ctx)
- env.exit = exit
- setfenv(fn, env)
+ local env = __create_env(ctx)
+ env.exit = exit
+ setfenv(fn, env)
- local ok, result = pcall(fn)
- if not ok then
- if result == __EXIT_SENTINEL then
- return
- end
+ local ok, result = pcall(fn)
+ if not ok then
+ if result == __EXIT_SENTINEL then
+ return
+ end
- error(result, 0)
- end
+ error(result, 0)
+ end
- return result
+ return result
end
-- Ensure __http_response exists, then return it
function __ensure_response()
- if not __http_response then
- __http_response = {}
- end
- return __http_response
+ if not __http_response then
+ __http_response = {}
+ end
+ return __http_response
end
-- ======================================================================
@@ -137,15 +137,15 @@ local http = {
for k, v in pairs(params) do
if type(v) == "table" then
for _, item in ipairs(v) do
- table.insert(query, k .. "=" .. tostring(item))
+ table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(item)))
end
else
- table.insert(query, k .. "=" .. tostring(v))
+ table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(v)))
end
end
if #query > 0 then
- if base_url:find("?") then
+ if string.contains(base_url, "?") then
return base_url .. "&" .. table.concat(query, "&")
else
return base_url .. "?" .. table.concat(query, "&")
@@ -198,89 +198,89 @@ end
-- ======================================================================
local cookie = {
- -- Set a cookie
- set = function(name, value, options)
- if type(name) ~= "string" then
- error("cookie.set: name must be a string", 2)
- end
+ -- Set a cookie
+ set = function(name, value, options)
+ if type(name) ~= "string" then
+ error("cookie.set: name must be a string", 2)
+ end
- local resp = __ensure_response()
- resp.cookies = resp.cookies or {}
+ local resp = __ensure_response()
+ resp.cookies = resp.cookies or {}
- local opts = options or {}
- local cookie = {
- name = name,
- value = value or "",
- path = opts.path or "/",
- domain = opts.domain
- }
+ local opts = options or {}
+ local cookie = {
+ name = name,
+ value = value or "",
+ path = opts.path or "/",
+ domain = opts.domain
+ }
- if opts.expires then
- if type(opts.expires) == "number" then
- if opts.expires > 0 then
- cookie.max_age = opts.expires
- local now = os.time()
- cookie.expires = now + opts.expires
- elseif opts.expires < 0 then
- cookie.expires = 1
- cookie.max_age = 0
- end
- -- opts.expires == 0: Session cookie (omitting both expires and max-age)
- end
- end
+ if opts.expires then
+ if type(opts.expires) == "number" then
+ if opts.expires > 0 then
+ cookie.max_age = opts.expires
+ local now = os.time()
+ cookie.expires = now + opts.expires
+ elseif opts.expires < 0 then
+ cookie.expires = 1
+ cookie.max_age = 0
+ end
+ -- opts.expires == 0: Session cookie (omitting both expires and max-age)
+ end
+ end
- cookie.secure = (opts.secure ~= false)
- cookie.http_only = (opts.http_only ~= false)
+ cookie.secure = (opts.secure ~= false)
+ cookie.http_only = (opts.http_only ~= false)
- if opts.same_site then
- local valid_values = {none = true, lax = true, strict = true}
- local same_site = string.lower(opts.same_site)
+ if opts.same_site then
+ local same_site = string.trim(opts.same_site):lower()
+ local valid_values = {none = true, lax = true, strict = true}
- if not valid_values[same_site] then
- error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
- end
+ if not valid_values[same_site] then
+ error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
+ end
- -- If SameSite=None, the cookie must be secure
- if same_site == "none" and not cookie.secure then
- cookie.secure = true
- end
+ -- If SameSite=None, the cookie must be secure
+ if same_site == "none" and not cookie.secure then
+ cookie.secure = true
+ end
- cookie.same_site = opts.same_site
- else
- cookie.same_site = "Lax"
- end
+ cookie.same_site = opts.same_site
+ else
+ cookie.same_site = "Lax"
+ end
- table.insert(resp.cookies, cookie)
- return true
- end,
+ table.insert(resp.cookies, cookie)
+ return true
+ end,
- -- Get a cookie value
- get = function(name)
- if type(name) ~= "string" then
- error("cookie.get: name must be a string", 2)
- end
+ -- Get a cookie value
+ get = function(name)
+ if type(name) ~= "string" then
+ error("cookie.get: name must be a string", 2)
+ end
- local env = getfenv(2)
+ local env = getfenv(2)
- if env.ctx and env.ctx.cookies then
- return env.ctx.cookies[name]
- end
+ if env.ctx and env.ctx.cookies then
+ return env.ctx.cookies[name]
+ end
- if env.ctx and env.ctx._request_cookies then
- return env.ctx._request_cookies[name]
- end
+ if env.ctx and env.ctx._request_cookies then
+ return env.ctx._request_cookies[name]
+ end
- return nil
- end,
+ return nil
+ end,
- -- Remove a cookie
- remove = function(name, path, domain)
- if type(name) ~= "string" then
- error("cookie.remove: name must be a string", 2)
- end
+ -- Remove a cookie
+ remove = function(name, path, domain)
+ if type(name) ~= "string" then
+ error("cookie.remove: name must be a string", 2)
+ end
- return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
- end
+ return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
+ end
}
-- ======================================================================
@@ -311,7 +311,7 @@ local session = {
end
local resp = __ensure_response()
- resp.session = resp.session or {}
+ resp.session = resp.session or {}
resp.session[key] = value
end,
@@ -360,7 +360,6 @@ local session = {
local resp = __ensure_response()
resp.session = {}
-
resp.session["__clear_all"] = true
end
}
@@ -370,51 +369,52 @@ local session = {
-- ======================================================================
local csrf = {
- generate = function()
- local token = util.generate_token(32)
- session.set("_csrf_token", token)
- return token
- end,
+ generate = function()
+ local token = util.generate_token(32)
+ session.set("_csrf_token", token)
+ return token
+ end,
- field = function()
- local token = session.get("_csrf_token")
- if not token then
- token = csrf.generate()
- end
- return string.format('', token)
- end,
+ field = function()
+ local token = session.get("_csrf_token")
+ if not token then
+ token = csrf.generate()
+ end
+ return string.format('',
+ util.html_special_chars(token))
+ end,
- validate = function()
+ validate = function()
local env = getfenv(2)
- local token = false
+ local token = false
if env.ctx and env.ctx.session and env.ctx.session.data then
token = env.ctx.session.data["_csrf_token"]
end
- if not token then
- http.set_status(403)
- __http_response.body = "CSRF validation failed"
- exit()
- end
+ if not token then
+ http.set_status(403)
+ __http_response.body = "CSRF validation failed"
+ exit()
+ end
- local request_token = nil
- if env.ctx and env.ctx.form then
- request_token = env.ctx.form._csrf_token
- end
+ local request_token = nil
+ if env.ctx and env.ctx.form then
+ request_token = env.ctx.form._csrf_token
+ end
- if not request_token and env.ctx and env.ctx._request_headers then
- request_token = env.ctx._request_headers["x-csrf-token"] or
- env.ctx._request_headers["csrf-token"]
- end
+ if not request_token and env.ctx and env.ctx._request_headers then
+ request_token = env.ctx._request_headers["x-csrf-token"] or
+ env.ctx._request_headers["csrf-token"]
+ end
- if not request_token or request_token ~= token then
- http.set_status(403)
- __http_response.body = "CSRF validation failed"
- exit()
- end
+ if not request_token or request_token ~= token then
+ http.set_status(403)
+ __http_response.body = "CSRF validation failed"
+ exit()
+ end
- return true
- end
+ return true
+ end
}
-- ======================================================================
@@ -423,11 +423,6 @@ local csrf = {
-- Template processing with code execution
_G.render = function(template_str, env)
- local function escape_html(s)
- local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''}
- return (s:gsub([=[["><'&]]=], entities))
- end
-
local function get_line(s, ln)
for line in s:gmatch("([^\n]*)\n?") do
if ln == 1 then return line end
@@ -516,7 +511,7 @@ _G.render = function(template_str, env)
setfenv(fn, runtime_env)
local output_buffer = {}
- fn(tostring, escape_html, output_buffer, 0)
+ fn(tostring, util.html_special_chars, output_buffer, 0)
return table.concat(output_buffer)
end
@@ -550,8 +545,7 @@ _G.parse = function(template_str, env)
local value = env[name]
local str = tostring(value or "")
if escaped then
- local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''}
- str = str:gsub([=[["><'&]]=], entities)
+ str = util.html_special_chars(str)
end
table.insert(output, str)
@@ -591,8 +585,7 @@ _G.iparse = function(template_str, values)
local value = values[value_index]
local str = tostring(value or "")
if escaped then
- local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''}
- str = str:gsub([=[["><'&]]=], entities)
+ str = util.html_special_chars(str)
end
table.insert(output, str)
diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua
index ccfe458..0a24190 100644
--- a/runner/lua/sqlite.lua
+++ b/runner/lua/sqlite.lua
@@ -1,70 +1,37 @@
--- Simplified SQLite wrapper
--- Connection is now lightweight with persistent connection tracking
-
--- Helper function to handle parameters
-local function handle_params(params, ...)
- -- If params is a table, use it for named parameters
- if type(params) == "table" then
- return params
- end
-
- -- If we have varargs, collect them for positional parameters
+local function normalize_params(params, ...)
+ if type(params) == "table" then return params end
local args = {...}
if #args > 0 or params ~= nil then
- -- Include the first param in the args
table.insert(args, 1, params)
return args
end
-
return nil
end
--- Connection metatable
local connection_mt = {
__index = {
- -- Execute a query and return results as a table
query = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:query: query must be a string", 2)
end
- -- Execute with proper connection tracking
- local results, token
- if params == nil and select('#', ...) == 0 then
- results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
- elseif type(params) == "table" then
- results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
- else
- local args = {params, ...}
- results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
- end
-
+ local normalized_params = normalize_params(params, ...)
+ local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token)
self.conn_token = token
return results
end,
- -- Execute a statement and return affected rows
exec = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:exec: query must be a string", 2)
end
- -- Execute with proper connection tracking
- local affected, token
- if params == nil and select('#', ...) == 0 then
- affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
- elseif type(params) == "table" then
- affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
- else
- local args = {params, ...}
- affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
- end
-
+ local normalized_params = normalize_params(params, ...)
+ local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token)
self.conn_token = token
return affected
end,
- -- Close the connection (release back to pool)
close = function(self)
if self.conn_token then
local success = __sqlite_close(self.conn_token)
@@ -74,50 +41,61 @@ local connection_mt = {
return false
end,
- -- Insert a row or multiple rows with a single query
insert = function(self, table_name, data, columns)
if type(data) ~= "table" then
error("connection:insert: data must be a table", 2)
end
- -- Case 1: Named columns with array data
- if columns and type(columns) == "table" then
- -- Check if we have multiple rows
- if #data > 0 and type(data[1]) == "table" then
- -- Build a single multi-value INSERT
- local placeholders = {}
- local values = {}
- local params = {}
- local param_index = 1
+ -- Single object: {col1=val1, col2=val2}
+ if data[1] == nil and next(data) ~= nil then
+ local cols = table.keys(data)
+ local placeholders = table.map(cols, function(_, i) return ":p" .. i end)
+ local params = {}
+ for i, col in ipairs(cols) do
+ params["p" .. i] = data[col]
+ end
- for i, row in ipairs(data) do
+ local query = string.format(
+ "INSERT INTO %s (%s) VALUES (%s)",
+ table_name,
+ table.concat(cols, ", "),
+ table.concat(placeholders, ", ")
+ )
+ return self:exec(query, params)
+ end
+
+ -- Array data with columns
+ if columns and type(columns) == "table" then
+ if #data > 0 and type(data[1]) == "table" then
+ -- Multiple rows
+ local value_groups = {}
+ local params = {}
+ local param_idx = 1
+
+ for _, row in ipairs(data) do
local row_placeholders = {}
- for j, _ in ipairs(columns) do
- local param_name = "p" .. param_index
+ for j = 1, #columns do
+ local param_name = "p" .. param_idx
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[j]
- param_index = param_index + 1
+ param_idx = param_idx + 1
end
- table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
+ table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(columns, ", "),
- table.concat(placeholders, ", ")
+ table.concat(value_groups, ", ")
)
-
return self:exec(query, params)
else
- -- Single row with defined columns
- local placeholders = {}
+ -- Single row array
+ local placeholders = table.map(columns, function(_, i) return ":p" .. i end)
local params = {}
-
- for i, col in ipairs(columns) do
- local param_name = "p" .. i
- table.insert(placeholders, ":" .. param_name)
- params[param_name] = data[i]
+ for i = 1, #columns do
+ params["p" .. i] = data[i]
end
local query = string.format(
@@ -126,161 +104,71 @@ local connection_mt = {
table.concat(columns, ", "),
table.concat(placeholders, ", ")
)
-
return self:exec(query, params)
end
end
- -- Case 2: Object-style single row {col1=val1, col2=val2}
- if data[1] == nil and next(data) ~= nil then
- local columns = {}
- local placeholders = {}
+ -- Array of objects
+ if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
+ local cols = table.keys(data[1])
+ local value_groups = {}
local params = {}
+ local param_idx = 1
- for col, val in pairs(data) do
- table.insert(columns, col)
- local param_name = "p" .. #columns
- table.insert(placeholders, ":" .. param_name)
- params[param_name] = val
+ for _, row in ipairs(data) do
+ local row_placeholders = {}
+ for _, col in ipairs(cols) do
+ local param_name = "p" .. param_idx
+ table.insert(row_placeholders, ":" .. param_name)
+ params[param_name] = row[col]
+ param_idx = param_idx + 1
+ end
+ table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
- "INSERT INTO %s (%s) VALUES (%s)",
+ "INSERT INTO %s (%s) VALUES %s",
table_name,
- table.concat(columns, ", "),
- table.concat(placeholders, ", ")
+ table.concat(cols, ", "),
+ table.concat(value_groups, ", ")
)
-
return self:exec(query, params)
end
- -- Case 3: Array of rows without predefined columns
- if #data > 0 and type(data[1]) == "table" then
- -- Extract columns from the first row
- local first_row = data[1]
- local inferred_columns = {}
-
- -- Determine if first row is array or object
- local is_array = first_row[1] ~= nil
-
- if is_array then
- -- Cannot infer column names from array
- error("connection:insert: column names required for array data", 2)
- else
- -- Get columns from object keys
- for col, _ in pairs(first_row) do
- table.insert(inferred_columns, col)
- end
-
- -- Build multi-value INSERT
- local placeholders = {}
- local params = {}
- local param_index = 1
-
- for _, row in ipairs(data) do
- local row_placeholders = {}
- for _, col in ipairs(inferred_columns) do
- local param_name = "p" .. param_index
- table.insert(row_placeholders, ":" .. param_name)
- params[param_name] = row[col]
- param_index = param_index + 1
- end
- table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
- end
-
- local query = string.format(
- "INSERT INTO %s (%s) VALUES %s",
- table_name,
- table.concat(inferred_columns, ", "),
- table.concat(placeholders, ", ")
- )
-
- return self:exec(query, params)
- end
- end
-
error("connection:insert: invalid data format", 2)
end,
- -- Update rows in a table
update = function(self, table_name, data, where, where_params, ...)
- if type(data) ~= "table" then
- error("connection:update: data must be a table", 2)
- end
-
- -- Fast path for when there's no data
- if next(data) == nil then
+ if type(data) ~= "table" or next(data) == nil then
return 0
end
local sets = {}
local params = {}
- local param_index = 1
+ local param_idx = 1
for col, val in pairs(data) do
- local param_name = "p" .. param_index
+ local param_name = "p" .. param_idx
table.insert(sets, col .. " = :" .. param_name)
params[param_name] = val
- param_index = param_index + 1
+ param_idx = param_idx + 1
end
- local query = string.format(
- "UPDATE %s SET %s",
- table_name,
- table.concat(sets, ", ")
- )
+ local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
if where then
query = query .. " WHERE " .. where
-
if where_params then
- if type(where_params) == "table" then
- -- Handle named parameters in WHERE clause
- for k, v in pairs(where_params) do
- local param_name
- if type(k) == "string" and k:sub(1, 1) == ":" then
- param_name = k:sub(2)
+ local normalized = normalize_params(where_params, ...)
+ if type(normalized) == "table" then
+ for k, v in pairs(normalized) do
+ if type(k) == "string" then
+ params[k] = v
else
- param_name = "w" .. param_index
- -- Replace the placeholder in the WHERE clause
- where = where:gsub(":" .. k, ":" .. param_name)
+ params["w" .. param_idx] = v
+ param_idx = param_idx + 1
end
- params[param_name] = v
- param_index = param_index + 1
end
- else
- -- Handle positional parameters (? placeholders)
- local args = {where_params, ...}
- local pos = 1
- local offset = 0
-
- -- Replace ? with named parameters
- while true do
- local start_pos, end_pos = where:find("?", pos)
- if not start_pos then break end
-
- local param_name = "w" .. param_index
- local replacement = ":" .. param_name
-
- where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1)
-
- if args[pos - offset] ~= nil then
- params[param_name] = args[pos - offset]
- else
- params[param_name] = nil
- end
-
- param_index = param_index + 1
- pos = start_pos + #replacement
- offset = offset + 1
- end
-
- query = string.format(
- "UPDATE %s SET %s WHERE %s",
- table_name,
- table.concat(sets, ", "),
- where
- )
end
end
end
@@ -288,129 +176,108 @@ local connection_mt = {
return self:exec(query, params)
end,
- -- Create a new table
create_table = function(self, table_name, ...)
- local columns = {}
- local indices = {}
+ local column_definitions = {}
+ local index_definitions = {}
- -- Process all arguments
- for _, def in ipairs({...}) do
- if type(def) == "string" then
- -- Check if it's an index definition
- local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
+ for _, def_string in ipairs({...}) do
+ if type(def_string) == "string" then
+ local is_unique = false
+ local index_def = def_string
- if index_def then
- -- Parse index definition
- local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
-
- if index_name and columns_str then
- -- Split columns by comma
- local index_columns = {}
- for col in columns_str:gmatch("[^,]+") do
- table.insert(index_columns, col:match("^%s*(.-)%s*$")) -- Trim whitespace
- end
-
- table.insert(indices, {
- name = index_name,
- columns = index_columns,
- unique = (index_type == "UNIQUE INDEX:")
- })
- end
+ if string.starts_with(def_string, "UNIQUE INDEX:") then
+ is_unique = true
+ index_def = string.trim(def_string:sub(14))
+ elseif string.starts_with(def_string, "INDEX:") then
+ index_def = string.trim(def_string:sub(7))
else
- -- Regular column definition
- table.insert(columns, def)
+ table.insert(column_definitions, def_string)
+ goto continue
+ end
+
+ local paren_pos = index_def:find("%(")
+ if not paren_pos then goto continue end
+
+ local index_name = string.trim(index_def:sub(1, paren_pos - 1))
+ local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$")
+ if not columns_part then goto continue end
+
+ local columns = table.map(string.split(columns_part, ","), string.trim)
+
+ if #columns > 0 then
+ table.insert(index_definitions, {
+ name = index_name,
+ columns = columns,
+ unique = is_unique
+ })
end
end
+ ::continue::
end
- if #columns == 0 then
- error("connection:create_table: no columns specified", 2)
+ if #column_definitions == 0 then
+ error("connection:create_table: no column definitions specified for table " .. table_name, 2)
end
- -- Build combined statement for table and indices
local statements = {}
- -- Add the CREATE TABLE statement
table.insert(statements, string.format(
"CREATE TABLE IF NOT EXISTS %s (%s)",
table_name,
- table.concat(columns, ", ")
+ table.concat(column_definitions, ", ")
))
- -- Add CREATE INDEX statements
- for _, idx in ipairs(indices) do
- local unique = idx.unique and "UNIQUE " or ""
-
+ for _, idx in ipairs(index_definitions) do
+ local unique_prefix = idx.unique and "UNIQUE " or ""
table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
- unique,
+ unique_prefix,
idx.name,
table_name,
table.concat(idx.columns, ", ")
))
end
- -- Execute all statements in a single transaction
- local combined_sql = table.concat(statements, ";\n")
- return self:exec(combined_sql)
+ return self:exec(table.concat(statements, ";\n"))
end,
- -- Delete rows
- delete = function(self, table_name, where, params)
+ delete = function(self, table_name, where, params, ...)
local query = "DELETE FROM " .. table_name
-
if where then
query = query .. " WHERE " .. where
end
-
- return self:exec(query, params)
+ return self:exec(query, normalize_params(params, ...))
end,
- -- Get one row efficiently
get_one = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
- -- Add LIMIT 1 to query if not already limited
local limited_query = query
- if not query:lower():match("limit%s+%d+") then
+ if not string.contains(query:lower(), "limit") then
limited_query = query .. " LIMIT 1"
end
- local results
- if select('#', ...) > 0 then
- results = self:query(limited_query, params, ...)
- else
- results = self:query(limited_query, params)
- end
-
+ local results = self:query(limited_query, normalize_params(params, ...))
return results[1]
end,
- -- Begin transaction
begin = function(self)
return self:exec("BEGIN TRANSACTION")
end,
- -- Commit transaction
commit = function(self)
return self:exec("COMMIT")
end,
- -- Rollback transaction
rollback = function(self)
return self:exec("ROLLBACK")
end,
- -- Transaction wrapper function
transaction = function(self, callback)
self:begin()
-
- local success, result = pcall(function()
- return callback(self)
- end)
-
+ local success, result = pcall(callback, self)
if success then
self:commit()
return result
@@ -422,16 +289,13 @@ local connection_mt = {
}
}
--- Create sqlite() function that returns a connection object
return function(db_name)
if type(db_name) ~= "string" then
error("sqlite: database name must be a string", 2)
end
- local conn = {
+ return setmetatable({
db_name = db_name,
- conn_token = nil -- Will be populated on first query/exec
- }
-
- return setmetatable(conn, connection_mt)
+ conn_token = nil
+ }, connection_mt)
end