update sqlite wrapper and sandbox with utils
This commit is contained in:
parent
5913bc4ba3
commit
ca85e735b0
@ -13,50 +13,50 @@ __EXIT_SENTINEL = {} -- Unique object for exit identification
|
|||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
function exit()
|
function exit()
|
||||||
error(__EXIT_SENTINEL)
|
error(__EXIT_SENTINEL)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Create environment inheriting from _G
|
-- Create environment inheriting from _G
|
||||||
function __create_env(ctx)
|
function __create_env(ctx)
|
||||||
local env = setmetatable({}, {__index = _G})
|
local env = setmetatable({}, {__index = _G})
|
||||||
|
|
||||||
if ctx then
|
if ctx then
|
||||||
env.ctx = ctx
|
env.ctx = ctx
|
||||||
end
|
end
|
||||||
|
|
||||||
if __setup_require then
|
if __setup_require then
|
||||||
__setup_require(env)
|
__setup_require(env)
|
||||||
end
|
end
|
||||||
|
|
||||||
return env
|
return env
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Execute script with clean environment
|
-- Execute script with clean environment
|
||||||
function __execute_script(fn, ctx)
|
function __execute_script(fn, ctx)
|
||||||
__http_response = nil
|
__http_response = nil
|
||||||
|
|
||||||
local env = __create_env(ctx)
|
local env = __create_env(ctx)
|
||||||
env.exit = exit
|
env.exit = exit
|
||||||
setfenv(fn, env)
|
setfenv(fn, env)
|
||||||
|
|
||||||
local ok, result = pcall(fn)
|
local ok, result = pcall(fn)
|
||||||
if not ok then
|
if not ok then
|
||||||
if result == __EXIT_SENTINEL then
|
if result == __EXIT_SENTINEL then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
error(result, 0)
|
error(result, 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Ensure __http_response exists, then return it
|
-- Ensure __http_response exists, then return it
|
||||||
function __ensure_response()
|
function __ensure_response()
|
||||||
if not __http_response then
|
if not __http_response then
|
||||||
__http_response = {}
|
__http_response = {}
|
||||||
end
|
end
|
||||||
return __http_response
|
return __http_response
|
||||||
end
|
end
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
@ -137,15 +137,15 @@ local http = {
|
|||||||
for k, v in pairs(params) do
|
for k, v in pairs(params) do
|
||||||
if type(v) == "table" then
|
if type(v) == "table" then
|
||||||
for _, item in ipairs(v) do
|
for _, item in ipairs(v) do
|
||||||
table.insert(query, k .. "=" .. tostring(item))
|
table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(item)))
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
table.insert(query, k .. "=" .. tostring(v))
|
table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(v)))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
if #query > 0 then
|
if #query > 0 then
|
||||||
if base_url:find("?") then
|
if string.contains(base_url, "?") then
|
||||||
return base_url .. "&" .. table.concat(query, "&")
|
return base_url .. "&" .. table.concat(query, "&")
|
||||||
else
|
else
|
||||||
return base_url .. "?" .. table.concat(query, "&")
|
return base_url .. "?" .. table.concat(query, "&")
|
||||||
@ -198,89 +198,89 @@ end
|
|||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
local cookie = {
|
local cookie = {
|
||||||
-- Set a cookie
|
-- Set a cookie
|
||||||
set = function(name, value, options)
|
set = function(name, value, options)
|
||||||
if type(name) ~= "string" then
|
if type(name) ~= "string" then
|
||||||
error("cookie.set: name must be a string", 2)
|
error("cookie.set: name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local resp = __ensure_response()
|
local resp = __ensure_response()
|
||||||
resp.cookies = resp.cookies or {}
|
resp.cookies = resp.cookies or {}
|
||||||
|
|
||||||
local opts = options or {}
|
local opts = options or {}
|
||||||
local cookie = {
|
local cookie = {
|
||||||
name = name,
|
name = name,
|
||||||
value = value or "",
|
value = value or "",
|
||||||
path = opts.path or "/",
|
path = opts.path or "/",
|
||||||
domain = opts.domain
|
domain = opts.domain
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.expires then
|
if opts.expires then
|
||||||
if type(opts.expires) == "number" then
|
if type(opts.expires) == "number" then
|
||||||
if opts.expires > 0 then
|
if opts.expires > 0 then
|
||||||
cookie.max_age = opts.expires
|
cookie.max_age = opts.expires
|
||||||
local now = os.time()
|
local now = os.time()
|
||||||
cookie.expires = now + opts.expires
|
cookie.expires = now + opts.expires
|
||||||
elseif opts.expires < 0 then
|
elseif opts.expires < 0 then
|
||||||
cookie.expires = 1
|
cookie.expires = 1
|
||||||
cookie.max_age = 0
|
cookie.max_age = 0
|
||||||
end
|
end
|
||||||
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
|
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
cookie.secure = (opts.secure ~= false)
|
cookie.secure = (opts.secure ~= false)
|
||||||
cookie.http_only = (opts.http_only ~= false)
|
cookie.http_only = (opts.http_only ~= false)
|
||||||
|
|
||||||
if opts.same_site then
|
if opts.same_site then
|
||||||
local valid_values = {none = true, lax = true, strict = true}
|
local same_site = string.trim(opts.same_site):lower()
|
||||||
local same_site = string.lower(opts.same_site)
|
local valid_values = {none = true, lax = true, strict = true}
|
||||||
|
|
||||||
if not valid_values[same_site] then
|
if not valid_values[same_site] then
|
||||||
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
|
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- If SameSite=None, the cookie must be secure
|
-- If SameSite=None, the cookie must be secure
|
||||||
if same_site == "none" and not cookie.secure then
|
if same_site == "none" and not cookie.secure then
|
||||||
cookie.secure = true
|
cookie.secure = true
|
||||||
end
|
end
|
||||||
|
|
||||||
cookie.same_site = opts.same_site
|
cookie.same_site = opts.same_site
|
||||||
else
|
else
|
||||||
cookie.same_site = "Lax"
|
cookie.same_site = "Lax"
|
||||||
end
|
end
|
||||||
|
|
||||||
table.insert(resp.cookies, cookie)
|
table.insert(resp.cookies, cookie)
|
||||||
return true
|
return true
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Get a cookie value
|
-- Get a cookie value
|
||||||
get = function(name)
|
get = function(name)
|
||||||
if type(name) ~= "string" then
|
if type(name) ~= "string" then
|
||||||
error("cookie.get: name must be a string", 2)
|
error("cookie.get: name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local env = getfenv(2)
|
local env = getfenv(2)
|
||||||
|
|
||||||
if env.ctx and env.ctx.cookies then
|
if env.ctx and env.ctx.cookies then
|
||||||
return env.ctx.cookies[name]
|
return env.ctx.cookies[name]
|
||||||
end
|
end
|
||||||
|
|
||||||
if env.ctx and env.ctx._request_cookies then
|
if env.ctx and env.ctx._request_cookies then
|
||||||
return env.ctx._request_cookies[name]
|
return env.ctx._request_cookies[name]
|
||||||
end
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Remove a cookie
|
-- Remove a cookie
|
||||||
remove = function(name, path, domain)
|
remove = function(name, path, domain)
|
||||||
if type(name) ~= "string" then
|
if type(name) ~= "string" then
|
||||||
error("cookie.remove: name must be a string", 2)
|
error("cookie.remove: name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
|
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
|
||||||
end
|
end
|
||||||
}
|
}
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
@ -311,7 +311,7 @@ local session = {
|
|||||||
end
|
end
|
||||||
|
|
||||||
local resp = __ensure_response()
|
local resp = __ensure_response()
|
||||||
resp.session = resp.session or {}
|
resp.session = resp.session or {}
|
||||||
resp.session[key] = value
|
resp.session[key] = value
|
||||||
end,
|
end,
|
||||||
|
|
||||||
@ -360,7 +360,6 @@ local session = {
|
|||||||
|
|
||||||
local resp = __ensure_response()
|
local resp = __ensure_response()
|
||||||
resp.session = {}
|
resp.session = {}
|
||||||
|
|
||||||
resp.session["__clear_all"] = true
|
resp.session["__clear_all"] = true
|
||||||
end
|
end
|
||||||
}
|
}
|
||||||
@ -370,51 +369,52 @@ local session = {
|
|||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
local csrf = {
|
local csrf = {
|
||||||
generate = function()
|
generate = function()
|
||||||
local token = util.generate_token(32)
|
local token = util.generate_token(32)
|
||||||
session.set("_csrf_token", token)
|
session.set("_csrf_token", token)
|
||||||
return token
|
return token
|
||||||
end,
|
end,
|
||||||
|
|
||||||
field = function()
|
field = function()
|
||||||
local token = session.get("_csrf_token")
|
local token = session.get("_csrf_token")
|
||||||
if not token then
|
if not token then
|
||||||
token = csrf.generate()
|
token = csrf.generate()
|
||||||
end
|
end
|
||||||
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', token)
|
return string.format('<input type="hidden" name="_csrf_token" value="%s" />',
|
||||||
end,
|
util.html_special_chars(token))
|
||||||
|
end,
|
||||||
|
|
||||||
validate = function()
|
validate = function()
|
||||||
local env = getfenv(2)
|
local env = getfenv(2)
|
||||||
local token = false
|
local token = false
|
||||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
if env.ctx and env.ctx.session and env.ctx.session.data then
|
||||||
token = env.ctx.session.data["_csrf_token"]
|
token = env.ctx.session.data["_csrf_token"]
|
||||||
end
|
end
|
||||||
|
|
||||||
if not token then
|
if not token then
|
||||||
http.set_status(403)
|
http.set_status(403)
|
||||||
__http_response.body = "CSRF validation failed"
|
__http_response.body = "CSRF validation failed"
|
||||||
exit()
|
exit()
|
||||||
end
|
end
|
||||||
|
|
||||||
local request_token = nil
|
local request_token = nil
|
||||||
if env.ctx and env.ctx.form then
|
if env.ctx and env.ctx.form then
|
||||||
request_token = env.ctx.form._csrf_token
|
request_token = env.ctx.form._csrf_token
|
||||||
end
|
end
|
||||||
|
|
||||||
if not request_token and env.ctx and env.ctx._request_headers then
|
if not request_token and env.ctx and env.ctx._request_headers then
|
||||||
request_token = env.ctx._request_headers["x-csrf-token"] or
|
request_token = env.ctx._request_headers["x-csrf-token"] or
|
||||||
env.ctx._request_headers["csrf-token"]
|
env.ctx._request_headers["csrf-token"]
|
||||||
end
|
end
|
||||||
|
|
||||||
if not request_token or request_token ~= token then
|
if not request_token or request_token ~= token then
|
||||||
http.set_status(403)
|
http.set_status(403)
|
||||||
__http_response.body = "CSRF validation failed"
|
__http_response.body = "CSRF validation failed"
|
||||||
exit()
|
exit()
|
||||||
end
|
end
|
||||||
|
|
||||||
return true
|
return true
|
||||||
end
|
end
|
||||||
}
|
}
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
@ -423,11 +423,6 @@ local csrf = {
|
|||||||
|
|
||||||
-- Template processing with code execution
|
-- Template processing with code execution
|
||||||
_G.render = function(template_str, env)
|
_G.render = function(template_str, env)
|
||||||
local function escape_html(s)
|
|
||||||
local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''}
|
|
||||||
return (s:gsub([=[["><'&]]=], entities))
|
|
||||||
end
|
|
||||||
|
|
||||||
local function get_line(s, ln)
|
local function get_line(s, ln)
|
||||||
for line in s:gmatch("([^\n]*)\n?") do
|
for line in s:gmatch("([^\n]*)\n?") do
|
||||||
if ln == 1 then return line end
|
if ln == 1 then return line end
|
||||||
@ -516,7 +511,7 @@ _G.render = function(template_str, env)
|
|||||||
setfenv(fn, runtime_env)
|
setfenv(fn, runtime_env)
|
||||||
|
|
||||||
local output_buffer = {}
|
local output_buffer = {}
|
||||||
fn(tostring, escape_html, output_buffer, 0)
|
fn(tostring, util.html_special_chars, output_buffer, 0)
|
||||||
return table.concat(output_buffer)
|
return table.concat(output_buffer)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -550,8 +545,7 @@ _G.parse = function(template_str, env)
|
|||||||
local value = env[name]
|
local value = env[name]
|
||||||
local str = tostring(value or "")
|
local str = tostring(value or "")
|
||||||
if escaped then
|
if escaped then
|
||||||
local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''}
|
str = util.html_special_chars(str)
|
||||||
str = str:gsub([=[["><'&]]=], entities)
|
|
||||||
end
|
end
|
||||||
table.insert(output, str)
|
table.insert(output, str)
|
||||||
|
|
||||||
@ -591,8 +585,7 @@ _G.iparse = function(template_str, values)
|
|||||||
local value = values[value_index]
|
local value = values[value_index]
|
||||||
local str = tostring(value or "")
|
local str = tostring(value or "")
|
||||||
if escaped then
|
if escaped then
|
||||||
local entities = {['&']='&', ['<']='<', ['>']='>', ['"']='"', ["'"]='''}
|
str = util.html_special_chars(str)
|
||||||
str = str:gsub([=[["><'&]]=], entities)
|
|
||||||
end
|
end
|
||||||
table.insert(output, str)
|
table.insert(output, str)
|
||||||
|
|
||||||
|
|||||||
@ -1,70 +1,37 @@
|
|||||||
-- Simplified SQLite wrapper
|
local function normalize_params(params, ...)
|
||||||
-- Connection is now lightweight with persistent connection tracking
|
if type(params) == "table" then return params end
|
||||||
|
|
||||||
-- Helper function to handle parameters
|
|
||||||
local function handle_params(params, ...)
|
|
||||||
-- If params is a table, use it for named parameters
|
|
||||||
if type(params) == "table" then
|
|
||||||
return params
|
|
||||||
end
|
|
||||||
|
|
||||||
-- If we have varargs, collect them for positional parameters
|
|
||||||
local args = {...}
|
local args = {...}
|
||||||
if #args > 0 or params ~= nil then
|
if #args > 0 or params ~= nil then
|
||||||
-- Include the first param in the args
|
|
||||||
table.insert(args, 1, params)
|
table.insert(args, 1, params)
|
||||||
return args
|
return args
|
||||||
end
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Connection metatable
|
|
||||||
local connection_mt = {
|
local connection_mt = {
|
||||||
__index = {
|
__index = {
|
||||||
-- Execute a query and return results as a table
|
|
||||||
query = function(self, query, params, ...)
|
query = function(self, query, params, ...)
|
||||||
if type(query) ~= "string" then
|
if type(query) ~= "string" then
|
||||||
error("connection:query: query must be a string", 2)
|
error("connection:query: query must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Execute with proper connection tracking
|
local normalized_params = normalize_params(params, ...)
|
||||||
local results, token
|
local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token)
|
||||||
if params == nil and select('#', ...) == 0 then
|
|
||||||
results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
|
|
||||||
elseif type(params) == "table" then
|
|
||||||
results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
|
|
||||||
else
|
|
||||||
local args = {params, ...}
|
|
||||||
results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
|
|
||||||
end
|
|
||||||
|
|
||||||
self.conn_token = token
|
self.conn_token = token
|
||||||
return results
|
return results
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Execute a statement and return affected rows
|
|
||||||
exec = function(self, query, params, ...)
|
exec = function(self, query, params, ...)
|
||||||
if type(query) ~= "string" then
|
if type(query) ~= "string" then
|
||||||
error("connection:exec: query must be a string", 2)
|
error("connection:exec: query must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Execute with proper connection tracking
|
local normalized_params = normalize_params(params, ...)
|
||||||
local affected, token
|
local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token)
|
||||||
if params == nil and select('#', ...) == 0 then
|
|
||||||
affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
|
|
||||||
elseif type(params) == "table" then
|
|
||||||
affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
|
|
||||||
else
|
|
||||||
local args = {params, ...}
|
|
||||||
affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
|
|
||||||
end
|
|
||||||
|
|
||||||
self.conn_token = token
|
self.conn_token = token
|
||||||
return affected
|
return affected
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Close the connection (release back to pool)
|
|
||||||
close = function(self)
|
close = function(self)
|
||||||
if self.conn_token then
|
if self.conn_token then
|
||||||
local success = __sqlite_close(self.conn_token)
|
local success = __sqlite_close(self.conn_token)
|
||||||
@ -74,50 +41,61 @@ local connection_mt = {
|
|||||||
return false
|
return false
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Insert a row or multiple rows with a single query
|
|
||||||
insert = function(self, table_name, data, columns)
|
insert = function(self, table_name, data, columns)
|
||||||
if type(data) ~= "table" then
|
if type(data) ~= "table" then
|
||||||
error("connection:insert: data must be a table", 2)
|
error("connection:insert: data must be a table", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Case 1: Named columns with array data
|
-- Single object: {col1=val1, col2=val2}
|
||||||
if columns and type(columns) == "table" then
|
if data[1] == nil and next(data) ~= nil then
|
||||||
-- Check if we have multiple rows
|
local cols = table.keys(data)
|
||||||
if #data > 0 and type(data[1]) == "table" then
|
local placeholders = table.map(cols, function(_, i) return ":p" .. i end)
|
||||||
-- Build a single multi-value INSERT
|
local params = {}
|
||||||
local placeholders = {}
|
for i, col in ipairs(cols) do
|
||||||
local values = {}
|
params["p" .. i] = data[col]
|
||||||
local params = {}
|
end
|
||||||
local param_index = 1
|
|
||||||
|
|
||||||
for i, row in ipairs(data) do
|
local query = string.format(
|
||||||
|
"INSERT INTO %s (%s) VALUES (%s)",
|
||||||
|
table_name,
|
||||||
|
table.concat(cols, ", "),
|
||||||
|
table.concat(placeholders, ", ")
|
||||||
|
)
|
||||||
|
return self:exec(query, params)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Array data with columns
|
||||||
|
if columns and type(columns) == "table" then
|
||||||
|
if #data > 0 and type(data[1]) == "table" then
|
||||||
|
-- Multiple rows
|
||||||
|
local value_groups = {}
|
||||||
|
local params = {}
|
||||||
|
local param_idx = 1
|
||||||
|
|
||||||
|
for _, row in ipairs(data) do
|
||||||
local row_placeholders = {}
|
local row_placeholders = {}
|
||||||
for j, _ in ipairs(columns) do
|
for j = 1, #columns do
|
||||||
local param_name = "p" .. param_index
|
local param_name = "p" .. param_idx
|
||||||
table.insert(row_placeholders, ":" .. param_name)
|
table.insert(row_placeholders, ":" .. param_name)
|
||||||
params[param_name] = row[j]
|
params[param_name] = row[j]
|
||||||
param_index = param_index + 1
|
param_idx = param_idx + 1
|
||||||
end
|
end
|
||||||
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.format(
|
local query = string.format(
|
||||||
"INSERT INTO %s (%s) VALUES %s",
|
"INSERT INTO %s (%s) VALUES %s",
|
||||||
table_name,
|
table_name,
|
||||||
table.concat(columns, ", "),
|
table.concat(columns, ", "),
|
||||||
table.concat(placeholders, ", ")
|
table.concat(value_groups, ", ")
|
||||||
)
|
)
|
||||||
|
|
||||||
return self:exec(query, params)
|
return self:exec(query, params)
|
||||||
else
|
else
|
||||||
-- Single row with defined columns
|
-- Single row array
|
||||||
local placeholders = {}
|
local placeholders = table.map(columns, function(_, i) return ":p" .. i end)
|
||||||
local params = {}
|
local params = {}
|
||||||
|
for i = 1, #columns do
|
||||||
for i, col in ipairs(columns) do
|
params["p" .. i] = data[i]
|
||||||
local param_name = "p" .. i
|
|
||||||
table.insert(placeholders, ":" .. param_name)
|
|
||||||
params[param_name] = data[i]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.format(
|
local query = string.format(
|
||||||
@ -126,161 +104,71 @@ local connection_mt = {
|
|||||||
table.concat(columns, ", "),
|
table.concat(columns, ", "),
|
||||||
table.concat(placeholders, ", ")
|
table.concat(placeholders, ", ")
|
||||||
)
|
)
|
||||||
|
|
||||||
return self:exec(query, params)
|
return self:exec(query, params)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Case 2: Object-style single row {col1=val1, col2=val2}
|
-- Array of objects
|
||||||
if data[1] == nil and next(data) ~= nil then
|
if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
|
||||||
local columns = {}
|
local cols = table.keys(data[1])
|
||||||
local placeholders = {}
|
local value_groups = {}
|
||||||
local params = {}
|
local params = {}
|
||||||
|
local param_idx = 1
|
||||||
|
|
||||||
for col, val in pairs(data) do
|
for _, row in ipairs(data) do
|
||||||
table.insert(columns, col)
|
local row_placeholders = {}
|
||||||
local param_name = "p" .. #columns
|
for _, col in ipairs(cols) do
|
||||||
table.insert(placeholders, ":" .. param_name)
|
local param_name = "p" .. param_idx
|
||||||
params[param_name] = val
|
table.insert(row_placeholders, ":" .. param_name)
|
||||||
|
params[param_name] = row[col]
|
||||||
|
param_idx = param_idx + 1
|
||||||
|
end
|
||||||
|
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.format(
|
local query = string.format(
|
||||||
"INSERT INTO %s (%s) VALUES (%s)",
|
"INSERT INTO %s (%s) VALUES %s",
|
||||||
table_name,
|
table_name,
|
||||||
table.concat(columns, ", "),
|
table.concat(cols, ", "),
|
||||||
table.concat(placeholders, ", ")
|
table.concat(value_groups, ", ")
|
||||||
)
|
)
|
||||||
|
|
||||||
return self:exec(query, params)
|
return self:exec(query, params)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Case 3: Array of rows without predefined columns
|
|
||||||
if #data > 0 and type(data[1]) == "table" then
|
|
||||||
-- Extract columns from the first row
|
|
||||||
local first_row = data[1]
|
|
||||||
local inferred_columns = {}
|
|
||||||
|
|
||||||
-- Determine if first row is array or object
|
|
||||||
local is_array = first_row[1] ~= nil
|
|
||||||
|
|
||||||
if is_array then
|
|
||||||
-- Cannot infer column names from array
|
|
||||||
error("connection:insert: column names required for array data", 2)
|
|
||||||
else
|
|
||||||
-- Get columns from object keys
|
|
||||||
for col, _ in pairs(first_row) do
|
|
||||||
table.insert(inferred_columns, col)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Build multi-value INSERT
|
|
||||||
local placeholders = {}
|
|
||||||
local params = {}
|
|
||||||
local param_index = 1
|
|
||||||
|
|
||||||
for _, row in ipairs(data) do
|
|
||||||
local row_placeholders = {}
|
|
||||||
for _, col in ipairs(inferred_columns) do
|
|
||||||
local param_name = "p" .. param_index
|
|
||||||
table.insert(row_placeholders, ":" .. param_name)
|
|
||||||
params[param_name] = row[col]
|
|
||||||
param_index = param_index + 1
|
|
||||||
end
|
|
||||||
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
|
||||||
end
|
|
||||||
|
|
||||||
local query = string.format(
|
|
||||||
"INSERT INTO %s (%s) VALUES %s",
|
|
||||||
table_name,
|
|
||||||
table.concat(inferred_columns, ", "),
|
|
||||||
table.concat(placeholders, ", ")
|
|
||||||
)
|
|
||||||
|
|
||||||
return self:exec(query, params)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
error("connection:insert: invalid data format", 2)
|
error("connection:insert: invalid data format", 2)
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Update rows in a table
|
|
||||||
update = function(self, table_name, data, where, where_params, ...)
|
update = function(self, table_name, data, where, where_params, ...)
|
||||||
if type(data) ~= "table" then
|
if type(data) ~= "table" or next(data) == nil then
|
||||||
error("connection:update: data must be a table", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Fast path for when there's no data
|
|
||||||
if next(data) == nil then
|
|
||||||
return 0
|
return 0
|
||||||
end
|
end
|
||||||
|
|
||||||
local sets = {}
|
local sets = {}
|
||||||
local params = {}
|
local params = {}
|
||||||
local param_index = 1
|
local param_idx = 1
|
||||||
|
|
||||||
for col, val in pairs(data) do
|
for col, val in pairs(data) do
|
||||||
local param_name = "p" .. param_index
|
local param_name = "p" .. param_idx
|
||||||
table.insert(sets, col .. " = :" .. param_name)
|
table.insert(sets, col .. " = :" .. param_name)
|
||||||
params[param_name] = val
|
params[param_name] = val
|
||||||
param_index = param_index + 1
|
param_idx = param_idx + 1
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.format(
|
local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
|
||||||
"UPDATE %s SET %s",
|
|
||||||
table_name,
|
|
||||||
table.concat(sets, ", ")
|
|
||||||
)
|
|
||||||
|
|
||||||
if where then
|
if where then
|
||||||
query = query .. " WHERE " .. where
|
query = query .. " WHERE " .. where
|
||||||
|
|
||||||
if where_params then
|
if where_params then
|
||||||
if type(where_params) == "table" then
|
local normalized = normalize_params(where_params, ...)
|
||||||
-- Handle named parameters in WHERE clause
|
if type(normalized) == "table" then
|
||||||
for k, v in pairs(where_params) do
|
for k, v in pairs(normalized) do
|
||||||
local param_name
|
if type(k) == "string" then
|
||||||
if type(k) == "string" and k:sub(1, 1) == ":" then
|
params[k] = v
|
||||||
param_name = k:sub(2)
|
|
||||||
else
|
else
|
||||||
param_name = "w" .. param_index
|
params["w" .. param_idx] = v
|
||||||
-- Replace the placeholder in the WHERE clause
|
param_idx = param_idx + 1
|
||||||
where = where:gsub(":" .. k, ":" .. param_name)
|
|
||||||
end
|
end
|
||||||
params[param_name] = v
|
|
||||||
param_index = param_index + 1
|
|
||||||
end
|
end
|
||||||
else
|
|
||||||
-- Handle positional parameters (? placeholders)
|
|
||||||
local args = {where_params, ...}
|
|
||||||
local pos = 1
|
|
||||||
local offset = 0
|
|
||||||
|
|
||||||
-- Replace ? with named parameters
|
|
||||||
while true do
|
|
||||||
local start_pos, end_pos = where:find("?", pos)
|
|
||||||
if not start_pos then break end
|
|
||||||
|
|
||||||
local param_name = "w" .. param_index
|
|
||||||
local replacement = ":" .. param_name
|
|
||||||
|
|
||||||
where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1)
|
|
||||||
|
|
||||||
if args[pos - offset] ~= nil then
|
|
||||||
params[param_name] = args[pos - offset]
|
|
||||||
else
|
|
||||||
params[param_name] = nil
|
|
||||||
end
|
|
||||||
|
|
||||||
param_index = param_index + 1
|
|
||||||
pos = start_pos + #replacement
|
|
||||||
offset = offset + 1
|
|
||||||
end
|
|
||||||
|
|
||||||
query = string.format(
|
|
||||||
"UPDATE %s SET %s WHERE %s",
|
|
||||||
table_name,
|
|
||||||
table.concat(sets, ", "),
|
|
||||||
where
|
|
||||||
)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -288,129 +176,108 @@ local connection_mt = {
|
|||||||
return self:exec(query, params)
|
return self:exec(query, params)
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Create a new table
|
|
||||||
create_table = function(self, table_name, ...)
|
create_table = function(self, table_name, ...)
|
||||||
local columns = {}
|
local column_definitions = {}
|
||||||
local indices = {}
|
local index_definitions = {}
|
||||||
|
|
||||||
-- Process all arguments
|
for _, def_string in ipairs({...}) do
|
||||||
for _, def in ipairs({...}) do
|
if type(def_string) == "string" then
|
||||||
if type(def) == "string" then
|
local is_unique = false
|
||||||
-- Check if it's an index definition
|
local index_def = def_string
|
||||||
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
|
|
||||||
|
|
||||||
if index_def then
|
if string.starts_with(def_string, "UNIQUE INDEX:") then
|
||||||
-- Parse index definition
|
is_unique = true
|
||||||
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
|
index_def = string.trim(def_string:sub(14))
|
||||||
|
elseif string.starts_with(def_string, "INDEX:") then
|
||||||
if index_name and columns_str then
|
index_def = string.trim(def_string:sub(7))
|
||||||
-- Split columns by comma
|
|
||||||
local index_columns = {}
|
|
||||||
for col in columns_str:gmatch("[^,]+") do
|
|
||||||
table.insert(index_columns, col:match("^%s*(.-)%s*$")) -- Trim whitespace
|
|
||||||
end
|
|
||||||
|
|
||||||
table.insert(indices, {
|
|
||||||
name = index_name,
|
|
||||||
columns = index_columns,
|
|
||||||
unique = (index_type == "UNIQUE INDEX:")
|
|
||||||
})
|
|
||||||
end
|
|
||||||
else
|
else
|
||||||
-- Regular column definition
|
table.insert(column_definitions, def_string)
|
||||||
table.insert(columns, def)
|
goto continue
|
||||||
|
end
|
||||||
|
|
||||||
|
local paren_pos = index_def:find("%(")
|
||||||
|
if not paren_pos then goto continue end
|
||||||
|
|
||||||
|
local index_name = string.trim(index_def:sub(1, paren_pos - 1))
|
||||||
|
local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$")
|
||||||
|
if not columns_part then goto continue end
|
||||||
|
|
||||||
|
local columns = table.map(string.split(columns_part, ","), string.trim)
|
||||||
|
|
||||||
|
if #columns > 0 then
|
||||||
|
table.insert(index_definitions, {
|
||||||
|
name = index_name,
|
||||||
|
columns = columns,
|
||||||
|
unique = is_unique
|
||||||
|
})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
::continue::
|
||||||
end
|
end
|
||||||
|
|
||||||
if #columns == 0 then
|
if #column_definitions == 0 then
|
||||||
error("connection:create_table: no columns specified", 2)
|
error("connection:create_table: no column definitions specified for table " .. table_name, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Build combined statement for table and indices
|
|
||||||
local statements = {}
|
local statements = {}
|
||||||
|
|
||||||
-- Add the CREATE TABLE statement
|
|
||||||
table.insert(statements, string.format(
|
table.insert(statements, string.format(
|
||||||
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||||
table_name,
|
table_name,
|
||||||
table.concat(columns, ", ")
|
table.concat(column_definitions, ", ")
|
||||||
))
|
))
|
||||||
|
|
||||||
-- Add CREATE INDEX statements
|
for _, idx in ipairs(index_definitions) do
|
||||||
for _, idx in ipairs(indices) do
|
local unique_prefix = idx.unique and "UNIQUE " or ""
|
||||||
local unique = idx.unique and "UNIQUE " or ""
|
|
||||||
|
|
||||||
table.insert(statements, string.format(
|
table.insert(statements, string.format(
|
||||||
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
||||||
unique,
|
unique_prefix,
|
||||||
idx.name,
|
idx.name,
|
||||||
table_name,
|
table_name,
|
||||||
table.concat(idx.columns, ", ")
|
table.concat(idx.columns, ", ")
|
||||||
))
|
))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Execute all statements in a single transaction
|
return self:exec(table.concat(statements, ";\n"))
|
||||||
local combined_sql = table.concat(statements, ";\n")
|
|
||||||
return self:exec(combined_sql)
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Delete rows
|
delete = function(self, table_name, where, params, ...)
|
||||||
delete = function(self, table_name, where, params)
|
|
||||||
local query = "DELETE FROM " .. table_name
|
local query = "DELETE FROM " .. table_name
|
||||||
|
|
||||||
if where then
|
if where then
|
||||||
query = query .. " WHERE " .. where
|
query = query .. " WHERE " .. where
|
||||||
end
|
end
|
||||||
|
return self:exec(query, normalize_params(params, ...))
|
||||||
return self:exec(query, params)
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Get one row efficiently
|
|
||||||
get_one = function(self, query, params, ...)
|
get_one = function(self, query, params, ...)
|
||||||
if type(query) ~= "string" then
|
if type(query) ~= "string" then
|
||||||
error("connection:get_one: query must be a string", 2)
|
error("connection:get_one: query must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Add LIMIT 1 to query if not already limited
|
|
||||||
local limited_query = query
|
local limited_query = query
|
||||||
if not query:lower():match("limit%s+%d+") then
|
if not string.contains(query:lower(), "limit") then
|
||||||
limited_query = query .. " LIMIT 1"
|
limited_query = query .. " LIMIT 1"
|
||||||
end
|
end
|
||||||
|
|
||||||
local results
|
local results = self:query(limited_query, normalize_params(params, ...))
|
||||||
if select('#', ...) > 0 then
|
|
||||||
results = self:query(limited_query, params, ...)
|
|
||||||
else
|
|
||||||
results = self:query(limited_query, params)
|
|
||||||
end
|
|
||||||
|
|
||||||
return results[1]
|
return results[1]
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Begin transaction
|
|
||||||
begin = function(self)
|
begin = function(self)
|
||||||
return self:exec("BEGIN TRANSACTION")
|
return self:exec("BEGIN TRANSACTION")
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Commit transaction
|
|
||||||
commit = function(self)
|
commit = function(self)
|
||||||
return self:exec("COMMIT")
|
return self:exec("COMMIT")
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Rollback transaction
|
|
||||||
rollback = function(self)
|
rollback = function(self)
|
||||||
return self:exec("ROLLBACK")
|
return self:exec("ROLLBACK")
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Transaction wrapper function
|
|
||||||
transaction = function(self, callback)
|
transaction = function(self, callback)
|
||||||
self:begin()
|
self:begin()
|
||||||
|
local success, result = pcall(callback, self)
|
||||||
local success, result = pcall(function()
|
|
||||||
return callback(self)
|
|
||||||
end)
|
|
||||||
|
|
||||||
if success then
|
if success then
|
||||||
self:commit()
|
self:commit()
|
||||||
return result
|
return result
|
||||||
@ -422,16 +289,13 @@ local connection_mt = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
-- Create sqlite() function that returns a connection object
|
|
||||||
return function(db_name)
|
return function(db_name)
|
||||||
if type(db_name) ~= "string" then
|
if type(db_name) ~= "string" then
|
||||||
error("sqlite: database name must be a string", 2)
|
error("sqlite: database name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local conn = {
|
return setmetatable({
|
||||||
db_name = db_name,
|
db_name = db_name,
|
||||||
conn_token = nil -- Will be populated on first query/exec
|
conn_token = nil
|
||||||
}
|
}, connection_mt)
|
||||||
|
|
||||||
return setmetatable(conn, connection_mt)
|
|
||||||
end
|
end
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user