Compare commits
No commits in common. "2d43c457e16b6855e298b2dc5d5d204a901701bf" and "e45d63cf24d5693d31edd3ee774986f0e84652f1" have entirely different histories.
2d43c457e1
...
e45d63cf24
1
go.mod
1
go.mod
@ -14,7 +14,6 @@ require (
|
|||||||
require (
|
require (
|
||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/goccy/go-json v0.10.5 // indirect
|
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -11,8 +11,6 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
|
||||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
local json = require("json")
|
|
||||||
local http = {}
|
local http = {}
|
||||||
|
local json = require("json")
|
||||||
|
local string = require("string")
|
||||||
|
|
||||||
-- Global routing tables
|
-- Global routing tables
|
||||||
_G._http_routes = _G._http_routes or {}
|
_G._http_routes = _G._http_routes or {}
|
||||||
@ -21,7 +22,7 @@ local function parse_cookies(cookie_header)
|
|||||||
if string.is_empty(cookie_header) then
|
if string.is_empty(cookie_header) then
|
||||||
return cookies
|
return cookies
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Split by semicolon and parse each cookie
|
-- Split by semicolon and parse each cookie
|
||||||
local cookie_pairs = string.split(cookie_header, ";")
|
local cookie_pairs = string.split(cookie_header, ";")
|
||||||
for _, cookie_pair in ipairs(cookie_pairs) do
|
for _, cookie_pair in ipairs(cookie_pairs) do
|
||||||
@ -31,12 +32,12 @@ local function parse_cookies(cookie_header)
|
|||||||
if #parts >= 2 then
|
if #parts >= 2 then
|
||||||
local name = string.trim(parts[1])
|
local name = string.trim(parts[1])
|
||||||
local value = string.trim(parts[2])
|
local value = string.trim(parts[2])
|
||||||
|
|
||||||
-- URL decode the value
|
-- URL decode the value
|
||||||
local success, decoded = pcall(function()
|
local success, decoded = pcall(function()
|
||||||
return string.url_decode(value)
|
return string.url_decode(value)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
cookies[name] = success and decoded or value
|
cookies[name] = success and decoded or value
|
||||||
elseif #parts == 1 then
|
elseif #parts == 1 then
|
||||||
-- Cookie without value
|
-- Cookie without value
|
||||||
@ -44,7 +45,7 @@ local function parse_cookies(cookie_header)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return cookies
|
return cookies
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -56,29 +57,29 @@ local function split_path(path)
|
|||||||
if string.is_empty(path) or path == "/" then
|
if string.is_empty(path) or path == "/" then
|
||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Remove leading/trailing slashes and split
|
-- Remove leading/trailing slashes and split
|
||||||
local clean_path = string.trim(path, "/")
|
local clean_path = string.trim(path, "/")
|
||||||
if string.is_empty(clean_path) then
|
if string.is_empty(clean_path) then
|
||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
return string.split(clean_path, "/")
|
return string.split(clean_path, "/")
|
||||||
end
|
end
|
||||||
|
|
||||||
local function match_route(method, path)
|
local function match_route(method, path)
|
||||||
local path_segments = split_path(path)
|
local path_segments = split_path(path)
|
||||||
|
|
||||||
for _, route in ipairs(_G._http_routes) do
|
for _, route in ipairs(_G._http_routes) do
|
||||||
if route.method == method then
|
if route.method == method then
|
||||||
local params = {}
|
local params = {}
|
||||||
local route_segments = route.segments
|
local route_segments = route.segments
|
||||||
local match = true
|
local match = true
|
||||||
local i = 1
|
local i = 1
|
||||||
|
|
||||||
while i <= #route_segments and match do
|
while i <= #route_segments and match do
|
||||||
local route_seg = route_segments[i]
|
local route_seg = route_segments[i]
|
||||||
|
|
||||||
if route_seg == "*" then
|
if route_seg == "*" then
|
||||||
-- Wildcard captures everything remaining
|
-- Wildcard captures everything remaining
|
||||||
local remaining = {}
|
local remaining = {}
|
||||||
@ -103,7 +104,7 @@ local function match_route(method, path)
|
|||||||
end
|
end
|
||||||
i = i + 1
|
i = i + 1
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Must consume all segments unless wildcard
|
-- Must consume all segments unless wildcard
|
||||||
if match and (i > #path_segments or (route_segments[i-1] and route_segments[i-1] == "*")) then
|
if match and (i > #path_segments or (route_segments[i-1] and route_segments[i-1] == "*")) then
|
||||||
return route, params
|
return route, params
|
||||||
@ -116,23 +117,23 @@ end
|
|||||||
function _http_handle_request(req_table, res_table)
|
function _http_handle_request(req_table, res_table)
|
||||||
local req = Request.new(req_table)
|
local req = Request.new(req_table)
|
||||||
local res = Response.new(res_table)
|
local res = Response.new(res_table)
|
||||||
|
|
||||||
-- Execute middleware chain first
|
-- Execute middleware chain first
|
||||||
local function run_middleware(index)
|
local function run_middleware(index)
|
||||||
if index > #_G._http_middleware then
|
if index > #_G._http_middleware then
|
||||||
local route, params = match_route(req.method, req.path)
|
local route, params = match_route(req.method, req.path)
|
||||||
req.params = params
|
req.params = params
|
||||||
|
|
||||||
if not route then
|
if not route then
|
||||||
res:status(404):send("Not Found")
|
res:status(404):send("Not Found")
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Run route handler
|
-- Run route handler
|
||||||
route.handler(req, res)
|
route.handler(req, res)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
local mw = _G._http_middleware[index]
|
local mw = _G._http_middleware[index]
|
||||||
if mw.path == nil or string.starts_with(req.path, mw.path) then
|
if mw.path == nil or string.starts_with(req.path, mw.path) then
|
||||||
mw.handler(req, res, function()
|
mw.handler(req, res, function()
|
||||||
@ -142,7 +143,7 @@ function _http_handle_request(req_table, res_table)
|
|||||||
run_middleware(index + 1)
|
run_middleware(index + 1)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
run_middleware(1)
|
run_middleware(1)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -177,13 +178,13 @@ function http.server()
|
|||||||
local server = setmetatable({
|
local server = setmetatable({
|
||||||
_server_created = false
|
_server_created = false
|
||||||
}, Server)
|
}, Server)
|
||||||
|
|
||||||
local success, err = moonshark.http_create_server()
|
local success, err = moonshark.http_create_server()
|
||||||
if not success then
|
if not success then
|
||||||
error("Failed to create HTTP server: " .. (err or "unknown error"))
|
error("Failed to create HTTP server: " .. (err or "unknown error"))
|
||||||
end
|
end
|
||||||
server._server_created = true
|
server._server_created = true
|
||||||
|
|
||||||
return server
|
return server
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -196,7 +197,7 @@ function Server:use(...)
|
|||||||
else
|
else
|
||||||
error("Invalid arguments to use()")
|
error("Invalid arguments to use()")
|
||||||
end
|
end
|
||||||
|
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -205,16 +206,16 @@ function Server:_add_route(method, path, handler)
|
|||||||
if not string.starts_with(path, "/") then
|
if not string.starts_with(path, "/") then
|
||||||
path = "/" .. path
|
path = "/" .. path
|
||||||
end
|
end
|
||||||
|
|
||||||
local segments = split_path(path)
|
local segments = split_path(path)
|
||||||
|
|
||||||
table.insert(_G._http_routes, {
|
table.insert(_G._http_routes, {
|
||||||
method = method,
|
method = method,
|
||||||
path = path,
|
path = path,
|
||||||
segments = segments,
|
segments = segments,
|
||||||
handler = handler
|
handler = handler
|
||||||
})
|
})
|
||||||
|
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -261,31 +262,31 @@ function Server:listen(port, host, callback)
|
|||||||
if callback then callback() end
|
if callback then callback() end
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
|
|
||||||
if type(host) == "function" then
|
if type(host) == "function" then
|
||||||
callback = host
|
callback = host
|
||||||
host = "localhost"
|
host = "localhost"
|
||||||
end
|
end
|
||||||
|
|
||||||
host = host or "localhost"
|
host = host or "localhost"
|
||||||
local addr = host .. ":" .. tostring(port)
|
local addr = host .. ":" .. tostring(port)
|
||||||
|
|
||||||
-- Spawn workers first
|
-- Spawn workers first
|
||||||
local success, err = moonshark.http_spawn_workers()
|
local success, err = moonshark.http_spawn_workers()
|
||||||
if not success then
|
if not success then
|
||||||
error("Failed to spawn workers: " .. (err or "unknown error"))
|
error("Failed to spawn workers: " .. (err or "unknown error"))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Then start listening
|
-- Then start listening
|
||||||
success, err = moonshark.http_listen(addr)
|
success, err = moonshark.http_listen(addr)
|
||||||
if not success then
|
if not success then
|
||||||
error("Failed to start server: " .. (err or "unknown error"))
|
error("Failed to start server: " .. (err or "unknown error"))
|
||||||
end
|
end
|
||||||
|
|
||||||
if callback then
|
if callback then
|
||||||
callback()
|
callback()
|
||||||
end
|
end
|
||||||
|
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -310,12 +311,12 @@ function Request.new(req_table)
|
|||||||
body = req_table.body or "",
|
body = req_table.body or "",
|
||||||
cookies = {}
|
cookies = {}
|
||||||
}, Request)
|
}, Request)
|
||||||
|
|
||||||
local cookie_header = req.headers["Cookie"] or req.headers["cookie"]
|
local cookie_header = req.headers["Cookie"] or req.headers["cookie"]
|
||||||
if cookie_header then
|
if cookie_header then
|
||||||
req.cookies = parse_cookies(cookie_header)
|
req.cookies = parse_cookies(cookie_header)
|
||||||
end
|
end
|
||||||
|
|
||||||
return req
|
return req
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -383,11 +384,11 @@ function Request:json()
|
|||||||
if string.is_empty(self.body) then
|
if string.is_empty(self.body) then
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
local success, result = pcall(function()
|
local success, result = pcall(function()
|
||||||
return json.decode(self.body)
|
return json.decode(self.body)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
if success then
|
if success then
|
||||||
return result
|
return result
|
||||||
else
|
else
|
||||||
@ -442,7 +443,7 @@ function Response.new(res_table)
|
|||||||
_table = res_table,
|
_table = res_table,
|
||||||
_sent = false
|
_sent = false
|
||||||
}, Response)
|
}, Response)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -474,7 +475,7 @@ function Response:send(data)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Response already sent")
|
error("Response already sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
if type(data) == "table" then
|
if type(data) == "table" then
|
||||||
self:json(data)
|
self:json(data)
|
||||||
elseif type(data) == "number" then
|
elseif type(data) == "number" then
|
||||||
@ -483,7 +484,7 @@ function Response:send(data)
|
|||||||
else
|
else
|
||||||
self._table.body = tostring(data or "")
|
self._table.body = tostring(data or "")
|
||||||
end
|
end
|
||||||
|
|
||||||
self._sent = true
|
self._sent = true
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
@ -492,19 +493,19 @@ function Response:json(data)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Response already sent")
|
error("Response already sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
self:type("application/json; charset=utf-8")
|
self:type("application/json; charset=utf-8")
|
||||||
|
|
||||||
local success, json_str = pcall(function()
|
local success, json_str = pcall(function()
|
||||||
return json.encode(data)
|
return json.encode(data)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
if success then
|
if success then
|
||||||
self._table.body = json_str
|
self._table.body = json_str
|
||||||
else
|
else
|
||||||
error("Failed to encode JSON response")
|
error("Failed to encode JSON response")
|
||||||
end
|
end
|
||||||
|
|
||||||
self._sent = true
|
self._sent = true
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
@ -513,7 +514,7 @@ function Response:text(text)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Response already sent")
|
error("Response already sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
self:type("text/plain; charset=utf-8")
|
self:type("text/plain; charset=utf-8")
|
||||||
self._table.body = tostring(text or "")
|
self._table.body = tostring(text or "")
|
||||||
self._sent = true
|
self._sent = true
|
||||||
@ -524,7 +525,7 @@ function Response:html(html)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Response already sent")
|
error("Response already sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
self:type("text/html; charset=utf-8")
|
self:type("text/html; charset=utf-8")
|
||||||
self._table.body = tostring(html or "")
|
self._table.body = tostring(html or "")
|
||||||
self._sent = true
|
self._sent = true
|
||||||
@ -535,7 +536,7 @@ function Response:xml(xml)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Response already sent")
|
error("Response already sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
self:type("application/xml; charset=utf-8")
|
self:type("application/xml; charset=utf-8")
|
||||||
self._table.body = tostring(xml or "")
|
self._table.body = tostring(xml or "")
|
||||||
self._sent = true
|
self._sent = true
|
||||||
@ -546,7 +547,7 @@ function Response:redirect(url, status)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Response already sent")
|
error("Response already sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
status = status or 302
|
status = status or 302
|
||||||
self:status(status)
|
self:status(status)
|
||||||
self:header("Location", url)
|
self:header("Location", url)
|
||||||
@ -559,45 +560,45 @@ function Response:cookie(name, value, options)
|
|||||||
if self._sent then
|
if self._sent then
|
||||||
error("Cannot set cookies after response has been sent")
|
error("Cannot set cookies after response has been sent")
|
||||||
end
|
end
|
||||||
|
|
||||||
options = options or {}
|
options = options or {}
|
||||||
local cookie_value = tostring(value)
|
local cookie_value = tostring(value)
|
||||||
|
|
||||||
-- URL encode the cookie value if it contains special characters
|
-- URL encode the cookie value if it contains special characters
|
||||||
if string.match("[;,\\s]", cookie_value) then
|
if string.match("[;,\\s]", cookie_value) then
|
||||||
cookie_value = string.url_encode(cookie_value)
|
cookie_value = string.url_encode(cookie_value)
|
||||||
end
|
end
|
||||||
|
|
||||||
local cookie = name .. "=" .. cookie_value
|
local cookie = name .. "=" .. cookie_value
|
||||||
|
|
||||||
if options.expires then
|
if options.expires then
|
||||||
cookie = cookie .. "; Expires=" .. options.expires
|
cookie = cookie .. "; Expires=" .. options.expires
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.max_age then
|
if options.max_age then
|
||||||
cookie = cookie .. "; Max-Age=" .. tostring(options.max_age)
|
cookie = cookie .. "; Max-Age=" .. tostring(options.max_age)
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.domain then
|
if options.domain then
|
||||||
cookie = cookie .. "; Domain=" .. options.domain
|
cookie = cookie .. "; Domain=" .. options.domain
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.path then
|
if options.path then
|
||||||
cookie = cookie .. "; Path=" .. options.path
|
cookie = cookie .. "; Path=" .. options.path
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.secure then
|
if options.secure then
|
||||||
cookie = cookie .. "; Secure"
|
cookie = cookie .. "; Secure"
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.http_only then
|
if options.http_only then
|
||||||
cookie = cookie .. "; HttpOnly"
|
cookie = cookie .. "; HttpOnly"
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.same_site then
|
if options.same_site then
|
||||||
cookie = cookie .. "; SameSite=" .. options.same_site
|
cookie = cookie .. "; SameSite=" .. options.same_site
|
||||||
end
|
end
|
||||||
|
|
||||||
local existing = self._table.headers["Set-Cookie"]
|
local existing = self._table.headers["Set-Cookie"]
|
||||||
if existing then
|
if existing then
|
||||||
if type(existing) == "table" then
|
if type(existing) == "table" then
|
||||||
@ -608,7 +609,7 @@ function Response:cookie(name, value, options)
|
|||||||
else
|
else
|
||||||
self._table.headers["Set-Cookie"] = cookie
|
self._table.headers["Set-Cookie"] = cookie
|
||||||
end
|
end
|
||||||
|
|
||||||
return self
|
return self
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -632,7 +633,7 @@ function Response:download(data, filename, content_type)
|
|||||||
if filename then
|
if filename then
|
||||||
self:attachment(filename)
|
self:attachment(filename)
|
||||||
end
|
end
|
||||||
|
|
||||||
if content_type then
|
if content_type then
|
||||||
self:type(content_type)
|
self:type(content_type)
|
||||||
elseif filename then
|
elseif filename then
|
||||||
@ -651,7 +652,7 @@ function Response:download(data, filename, content_type)
|
|||||||
else
|
else
|
||||||
self:type("application/octet-stream")
|
self:type("application/octet-stream")
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:send(data)
|
return self:send(data)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -664,16 +665,16 @@ function http.cors(options)
|
|||||||
local origin = options.origin or "*"
|
local origin = options.origin or "*"
|
||||||
local methods = options.methods or "GET,HEAD,PUT,PATCH,POST,DELETE"
|
local methods = options.methods or "GET,HEAD,PUT,PATCH,POST,DELETE"
|
||||||
local headers = options.headers or "Content-Type,Authorization"
|
local headers = options.headers or "Content-Type,Authorization"
|
||||||
|
|
||||||
return function(req, res, next)
|
return function(req, res, next)
|
||||||
res:header("Access-Control-Allow-Origin", origin)
|
res:header("Access-Control-Allow-Origin", origin)
|
||||||
res:header("Access-Control-Allow-Methods", methods)
|
res:header("Access-Control-Allow-Methods", methods)
|
||||||
res:header("Access-Control-Allow-Headers", headers)
|
res:header("Access-Control-Allow-Headers", headers)
|
||||||
|
|
||||||
if options.credentials then
|
if options.credentials then
|
||||||
res:header("Access-Control-Allow-Credentials", "true")
|
res:header("Access-Control-Allow-Credentials", "true")
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.iequals(req.method, "OPTIONS") then
|
if string.iequals(req.method, "OPTIONS") then
|
||||||
res:status(204):send("")
|
res:status(204):send("")
|
||||||
else
|
else
|
||||||
@ -684,19 +685,19 @@ end
|
|||||||
|
|
||||||
function http.static(root_path, url_prefix)
|
function http.static(root_path, url_prefix)
|
||||||
url_prefix = url_prefix or "/"
|
url_prefix = url_prefix or "/"
|
||||||
|
|
||||||
-- Ensure prefix starts with /
|
-- Ensure prefix starts with /
|
||||||
if not string.starts_with(url_prefix, "/") then
|
if not string.starts_with(url_prefix, "/") then
|
||||||
url_prefix = "/" .. url_prefix
|
url_prefix = "/" .. url_prefix
|
||||||
end
|
end
|
||||||
|
|
||||||
if not _G.__IS_WORKER then
|
if not _G.__IS_WORKER then
|
||||||
local success, err = moonshark.http_register_static(url_prefix, root_path)
|
local success, err = moonshark.http_register_static(url_prefix, root_path)
|
||||||
if not success then
|
if not success then
|
||||||
error("Failed to register static handler: " .. (err or "unknown error"))
|
error("Failed to register static handler: " .. (err or "unknown error"))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Return no-op middleware
|
-- Return no-op middleware
|
||||||
return function(req, res, next)
|
return function(req, res, next)
|
||||||
next()
|
next()
|
||||||
@ -709,7 +710,7 @@ function http.json_parser()
|
|||||||
local success, data = pcall(function()
|
local success, data = pcall(function()
|
||||||
return req:json()
|
return req:json()
|
||||||
end)
|
end)
|
||||||
|
|
||||||
if success then
|
if success then
|
||||||
req.json_body = data
|
req.json_body = data
|
||||||
else
|
else
|
||||||
@ -723,15 +724,15 @@ end
|
|||||||
|
|
||||||
function http.logger(format)
|
function http.logger(format)
|
||||||
format = format or ":method :path :status :response-time ms"
|
format = format or ":method :path :status :response-time ms"
|
||||||
|
|
||||||
return function(req, res, next)
|
return function(req, res, next)
|
||||||
local start_time = os.clock()
|
local start_time = os.clock()
|
||||||
|
|
||||||
next()
|
next()
|
||||||
|
|
||||||
local duration = (os.clock() - start_time) * 1000
|
local duration = (os.clock() - start_time) * 1000
|
||||||
local status = res._table.status or 200
|
local status = res._table.status or 200
|
||||||
|
|
||||||
local log_message = string.template(format, {
|
local log_message = string.template(format, {
|
||||||
method = req.method,
|
method = req.method,
|
||||||
path = req.path,
|
path = req.path,
|
||||||
@ -740,7 +741,7 @@ function http.logger(format)
|
|||||||
["user-agent"] = req:user_agent(),
|
["user-agent"] = req:user_agent(),
|
||||||
ip = req:ip()
|
ip = req:ip()
|
||||||
})
|
})
|
||||||
|
|
||||||
print(log_message)
|
print(log_message)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -773,11 +774,11 @@ function http.rate_limit(options)
|
|||||||
local max_requests = options.max or 100
|
local max_requests = options.max or 100
|
||||||
local window_ms = options.window or 60000 -- 1 minute
|
local window_ms = options.window or 60000 -- 1 minute
|
||||||
local clients = {}
|
local clients = {}
|
||||||
|
|
||||||
return function(req, res, next)
|
return function(req, res, next)
|
||||||
local client_ip = req:ip()
|
local client_ip = req:ip()
|
||||||
local now = os.time() * 1000
|
local now = os.time() * 1000
|
||||||
|
|
||||||
if not clients[client_ip] then
|
if not clients[client_ip] then
|
||||||
clients[client_ip] = {count = 1, reset_time = now + window_ms}
|
clients[client_ip] = {count = 1, reset_time = now + window_ms}
|
||||||
else
|
else
|
||||||
@ -796,7 +797,7 @@ function http.rate_limit(options)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
next()
|
next()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -809,4 +810,4 @@ function http.create_server(callback)
|
|||||||
return app
|
return app
|
||||||
end
|
end
|
||||||
|
|
||||||
return http
|
return http
|
||||||
492
modules/kv/kv.go
492
modules/kv/kv.go
@ -1,492 +0,0 @@
|
|||||||
package kv
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
"github.com/goccy/go-json"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
stores = make(map[string]*Store)
|
|
||||||
mutex sync.RWMutex
|
|
||||||
)
|
|
||||||
|
|
||||||
type Store struct {
|
|
||||||
data map[string]string
|
|
||||||
expires map[string]int64
|
|
||||||
filename string
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetFunctionList() map[string]luajit.GoFunction {
|
|
||||||
return map[string]luajit.GoFunction{
|
|
||||||
"kv_open": kv_open,
|
|
||||||
"kv_get": kv_get,
|
|
||||||
"kv_set": kv_set,
|
|
||||||
"kv_delete": kv_delete,
|
|
||||||
"kv_clear": kv_clear,
|
|
||||||
"kv_has": kv_has,
|
|
||||||
"kv_size": kv_size,
|
|
||||||
"kv_keys": kv_keys,
|
|
||||||
"kv_values": kv_values,
|
|
||||||
"kv_save": kv_save,
|
|
||||||
"kv_close": kv_close,
|
|
||||||
"kv_increment": kv_increment,
|
|
||||||
"kv_append": kv_append,
|
|
||||||
"kv_expire": kv_expire,
|
|
||||||
"kv_cleanup_expired": kv_cleanup_expired,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_open(name, filename) -> boolean
|
|
||||||
func kv_open(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
filename := s.ToString(2)
|
|
||||||
|
|
||||||
if name == "" {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
mutex.Lock()
|
|
||||||
defer mutex.Unlock()
|
|
||||||
|
|
||||||
if store, exists := stores[name]; exists {
|
|
||||||
if filename != "" && store.filename != filename {
|
|
||||||
store.filename = filename
|
|
||||||
}
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store := &Store{
|
|
||||||
data: make(map[string]string),
|
|
||||||
expires: make(map[string]int64),
|
|
||||||
filename: filename,
|
|
||||||
}
|
|
||||||
|
|
||||||
if filename != "" {
|
|
||||||
store.load()
|
|
||||||
}
|
|
||||||
|
|
||||||
stores[name] = store
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_get(name, key, default) -> value or default
|
|
||||||
func kv_get(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
hasDefault := s.GetTop() >= 3
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
if hasDefault {
|
|
||||||
s.PushCopy(3)
|
|
||||||
} else {
|
|
||||||
s.PushNil()
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.RLock()
|
|
||||||
value, found := store.data[key]
|
|
||||||
store.mutex.RUnlock()
|
|
||||||
|
|
||||||
if found {
|
|
||||||
s.PushString(value)
|
|
||||||
} else if hasDefault {
|
|
||||||
s.PushCopy(3)
|
|
||||||
} else {
|
|
||||||
s.PushNil()
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_set(name, key, value) -> boolean
|
|
||||||
func kv_set(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
value := s.ToString(3)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
store.data[key] = value
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_delete(name, key) -> boolean
|
|
||||||
func kv_delete(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
_, existed := store.data[key]
|
|
||||||
delete(store.data, key)
|
|
||||||
delete(store.expires, key)
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushBoolean(existed)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_clear(name) -> boolean
|
|
||||||
func kv_clear(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
store.data = make(map[string]string)
|
|
||||||
store.expires = make(map[string]int64)
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_has(name, key) -> boolean
|
|
||||||
func kv_has(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.RLock()
|
|
||||||
_, found := store.data[key]
|
|
||||||
store.mutex.RUnlock()
|
|
||||||
|
|
||||||
s.PushBoolean(found)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_size(name) -> number
|
|
||||||
func kv_size(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushNumber(0)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.RLock()
|
|
||||||
size := len(store.data)
|
|
||||||
store.mutex.RUnlock()
|
|
||||||
|
|
||||||
s.PushNumber(float64(size))
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_keys(name) -> table
|
|
||||||
func kv_keys(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.NewTable()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.RLock()
|
|
||||||
keys := make([]string, 0, len(store.data))
|
|
||||||
for k := range store.data {
|
|
||||||
keys = append(keys, k)
|
|
||||||
}
|
|
||||||
store.mutex.RUnlock()
|
|
||||||
|
|
||||||
s.PushValue(keys)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_values(name) -> table
|
|
||||||
func kv_values(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.NewTable()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.RLock()
|
|
||||||
values := make([]string, 0, len(store.data))
|
|
||||||
for _, v := range store.data {
|
|
||||||
values = append(values, v)
|
|
||||||
}
|
|
||||||
store.mutex.RUnlock()
|
|
||||||
|
|
||||||
s.PushValue(values)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_save(name) -> boolean
|
|
||||||
func kv_save(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists || store.filename == "" {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
err := store.save()
|
|
||||||
s.PushBoolean(err == nil)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_close(name) -> boolean
|
|
||||||
func kv_close(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.Lock()
|
|
||||||
defer mutex.Unlock()
|
|
||||||
|
|
||||||
store, exists := stores[name]
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if store.filename != "" {
|
|
||||||
store.save()
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(stores, name)
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_increment(name, key, delta) -> number
|
|
||||||
func kv_increment(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
delta := 1.0
|
|
||||||
if s.GetTop() >= 3 {
|
|
||||||
delta = s.ToNumber(3)
|
|
||||||
}
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushNumber(0)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
current, _ := strconv.ParseFloat(store.data[key], 64)
|
|
||||||
newValue := current + delta
|
|
||||||
store.data[key] = strconv.FormatFloat(newValue, 'g', -1, 64)
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushNumber(newValue)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_append(name, key, value, separator) -> boolean
|
|
||||||
func kv_append(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
value := s.ToString(3)
|
|
||||||
separator := ""
|
|
||||||
if s.GetTop() >= 4 {
|
|
||||||
separator = s.ToString(4)
|
|
||||||
}
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
current := store.data[key]
|
|
||||||
if current == "" {
|
|
||||||
store.data[key] = value
|
|
||||||
} else {
|
|
||||||
store.data[key] = current + separator + value
|
|
||||||
}
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_expire(name, key, ttl) -> boolean
|
|
||||||
func kv_expire(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
key := s.ToString(2)
|
|
||||||
ttl := s.ToNumber(3)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
store.expires[key] = time.Now().Unix() + int64(ttl)
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushBoolean(true)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// kv_cleanup_expired(name) -> number
|
|
||||||
func kv_cleanup_expired(s *luajit.State) int {
|
|
||||||
name := s.ToString(1)
|
|
||||||
|
|
||||||
mutex.RLock()
|
|
||||||
store, exists := stores[name]
|
|
||||||
mutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
s.PushNumber(0)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
|
||||||
deleted := 0
|
|
||||||
|
|
||||||
store.mutex.Lock()
|
|
||||||
for key, expireTime := range store.expires {
|
|
||||||
if currentTime >= expireTime {
|
|
||||||
delete(store.data, key)
|
|
||||||
delete(store.expires, key)
|
|
||||||
deleted++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
store.mutex.Unlock()
|
|
||||||
|
|
||||||
s.PushNumber(float64(deleted))
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (store *Store) load() error {
|
|
||||||
if store.filename == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := os.Open(store.filename)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
if strings.HasSuffix(store.filename, ".json") {
|
|
||||||
decoder := json.NewDecoder(file)
|
|
||||||
return decoder.Decode(&store.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if line == "" || strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.SplitN(line, "=", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
store.data[parts[0]] = parts[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (store *Store) save() error {
|
|
||||||
if store.filename == "" {
|
|
||||||
return fmt.Errorf("no filename specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := os.Create(store.filename)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
store.mutex.RLock()
|
|
||||||
defer store.mutex.RUnlock()
|
|
||||||
|
|
||||||
if strings.HasSuffix(store.filename, ".json") {
|
|
||||||
encoder := json.NewEncoder(file)
|
|
||||||
encoder.SetIndent("", "\t")
|
|
||||||
return encoder.Encode(store.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value := range store.data {
|
|
||||||
if _, err := fmt.Fprintf(file, "%s=%s\n", key, value); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,196 +0,0 @@
|
|||||||
local kv = {}
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- BASIC KEY-VALUE OPERATIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
function kv.open(name, filename)
|
|
||||||
if type(name) ~= "string" then error("kv.open: store name must be a string", 2) end
|
|
||||||
if filename ~= nil and type(filename) ~= "string" then error("kv.open: filename must be a string", 2) end
|
|
||||||
|
|
||||||
filename = filename or ""
|
|
||||||
return moonshark.kv_open(name, filename)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.get(name, key, default)
|
|
||||||
if type(name) ~= "string" then error("kv.get: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.get: key must be a string", 2) end
|
|
||||||
|
|
||||||
if default ~= nil then
|
|
||||||
return moonshark.kv_get(name, key, default)
|
|
||||||
else
|
|
||||||
return moonshark.kv_get(name, key)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.set(name, key, value)
|
|
||||||
if type(name) ~= "string" then error("kv.set: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.set: key must be a string", 2) end
|
|
||||||
if type(value) ~= "string" then error("kv.set: value must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_set(name, key, value)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.delete(name, key)
|
|
||||||
if type(name) ~= "string" then error("kv.delete: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.delete: key must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_delete(name, key)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.clear(name)
|
|
||||||
if type(name) ~= "string" then error("kv.clear: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_clear(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.has(name, key)
|
|
||||||
if type(name) ~= "string" then error("kv.has: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.has: key must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_has(name, key)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.size(name)
|
|
||||||
if type(name) ~= "string" then error("kv.size: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_size(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.keys(name)
|
|
||||||
if type(name) ~= "string" then error("kv.keys: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_keys(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.values(name)
|
|
||||||
if type(name) ~= "string" then error("kv.values: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_values(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.save(name)
|
|
||||||
if type(name) ~= "string" then error("kv.save: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_save(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.close(name)
|
|
||||||
if type(name) ~= "string" then error("kv.close: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_close(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- UTILITY FUNCTIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
function kv.increment(name, key, delta)
|
|
||||||
if type(name) ~= "string" then error("kv.increment: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.increment: key must be a string", 2) end
|
|
||||||
delta = delta or 1
|
|
||||||
if type(delta) ~= "number" then error("kv.increment: delta must be a number", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_increment(name, key, delta)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.append(name, key, value, separator)
|
|
||||||
if type(name) ~= "string" then error("kv.append: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.append: key must be a string", 2) end
|
|
||||||
if type(value) ~= "string" then error("kv.append: value must be a string", 2) end
|
|
||||||
separator = separator or ""
|
|
||||||
if type(separator) ~= "string" then error("kv.append: separator must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_append(name, key, value, separator)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.expire(name, key, ttl)
|
|
||||||
if type(name) ~= "string" then error("kv.expire: store name must be a string", 2) end
|
|
||||||
if type(key) ~= "string" then error("kv.expire: key must be a string", 2) end
|
|
||||||
if type(ttl) ~= "number" or ttl <= 0 then error("kv.expire: TTL must be a positive number", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_expire(name, key, ttl)
|
|
||||||
end
|
|
||||||
|
|
||||||
function kv.cleanup_expired(name)
|
|
||||||
if type(name) ~= "string" then error("kv.cleanup_expired: store name must be a string", 2) end
|
|
||||||
|
|
||||||
return moonshark.kv_cleanup_expired(name)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- OBJECT-ORIENTED INTERFACE
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
local Store = {}
|
|
||||||
Store.__index = Store
|
|
||||||
|
|
||||||
function kv.create(name, filename)
|
|
||||||
if type(name) ~= "string" then error("kv.create: store name must be a string", 2) end
|
|
||||||
if filename ~= nil and type(filename) ~= "string" then error("kv.create: filename must be a string", 2) end
|
|
||||||
|
|
||||||
local success = kv.open(name, filename)
|
|
||||||
if not success then
|
|
||||||
error("kv.create: failed to open store '" .. name .. "'", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
return setmetatable({name = name}, Store)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:get(key, default)
|
|
||||||
return kv.get(self.name, key, default)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:set(key, value)
|
|
||||||
return kv.set(self.name, key, value)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:delete(key)
|
|
||||||
return kv.delete(self.name, key)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:clear()
|
|
||||||
return kv.clear(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:has(key)
|
|
||||||
return kv.has(self.name, key)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:size()
|
|
||||||
return kv.size(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:keys()
|
|
||||||
return kv.keys(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:values()
|
|
||||||
return kv.values(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:save()
|
|
||||||
return kv.save(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:close()
|
|
||||||
return kv.close(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:increment(key, delta)
|
|
||||||
return kv.increment(self.name, key, delta)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:append(key, value, separator)
|
|
||||||
return kv.append(self.name, key, value, separator)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:expire(key, ttl)
|
|
||||||
return kv.expire(self.name, key, ttl)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Store:cleanup_expired()
|
|
||||||
return kv.cleanup_expired(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
return kv
|
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
local str = require("string")
|
||||||
local tbl = require("table")
|
local tbl = require("table")
|
||||||
local mysql = {}
|
local mysql = {}
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ function Connection:query(query_str, ...)
|
|||||||
if not self._id then
|
if not self._id then
|
||||||
error("Connection is closed")
|
error("Connection is closed")
|
||||||
end
|
end
|
||||||
query_str = string.normalize_whitespace(query_str)
|
query_str = str.normalize_whitespace(query_str)
|
||||||
return moonshark.sql_query(self._id, query_str, ...)
|
return moonshark.sql_query(self._id, query_str, ...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ function Connection:exec(query_str, ...)
|
|||||||
if not self._id then
|
if not self._id then
|
||||||
error("Connection is closed")
|
error("Connection is closed")
|
||||||
end
|
end
|
||||||
query_str = string.normalize_whitespace(query_str)
|
query_str = str.normalize_whitespace(query_str)
|
||||||
return moonshark.sql_exec(self._id, query_str, ...)
|
return moonshark.sql_exec(self._id, query_str, ...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -84,20 +85,20 @@ function Connection:begin()
|
|||||||
if not tx.active then
|
if not tx.active then
|
||||||
error("Transaction is not active")
|
error("Transaction is not active")
|
||||||
end
|
end
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Savepoint name cannot be empty")
|
error("Savepoint name cannot be empty")
|
||||||
end
|
end
|
||||||
return tx.conn:exec(string.template("SAVEPOINT ${name}", {name = name}))
|
return tx.conn:exec(str.template("SAVEPOINT ${name}", {name = name}))
|
||||||
end,
|
end,
|
||||||
|
|
||||||
rollback_to = function(tx, name)
|
rollback_to = function(tx, name)
|
||||||
if not tx.active then
|
if not tx.active then
|
||||||
error("Transaction is not active")
|
error("Transaction is not active")
|
||||||
end
|
end
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Savepoint name cannot be empty")
|
error("Savepoint name cannot be empty")
|
||||||
end
|
end
|
||||||
return tx.conn:exec(string.template("ROLLBACK TO SAVEPOINT ${name}", {name = name}))
|
return tx.conn:exec(str.template("ROLLBACK TO SAVEPOINT ${name}", {name = name}))
|
||||||
end,
|
end,
|
||||||
|
|
||||||
query = function(tx, query_str, ...)
|
query = function(tx, query_str, ...)
|
||||||
@ -134,7 +135,7 @@ end
|
|||||||
|
|
||||||
-- Simplified MySQL-specific query builder helpers
|
-- Simplified MySQL-specific query builder helpers
|
||||||
function Connection:insert(table_name, data)
|
function Connection:insert(table_name, data)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -142,7 +143,7 @@ function Connection:insert(table_name, data)
|
|||||||
local values = tbl.values(data)
|
local values = tbl.values(data)
|
||||||
local placeholders = tbl.map(keys, function() return "?" end)
|
local placeholders = tbl.map(keys, function() return "?" end)
|
||||||
|
|
||||||
local query = string.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", ")
|
placeholders = tbl.concat(placeholders, ", ")
|
||||||
@ -152,7 +153,7 @@ function Connection:insert(table_name, data)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:upsert(table_name, data, update_data)
|
function Connection:upsert(table_name, data, update_data)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -163,10 +164,10 @@ function Connection:upsert(table_name, data, update_data)
|
|||||||
-- Use update_data if provided, otherwise update with same data
|
-- Use update_data if provided, otherwise update with same data
|
||||||
local update_source = update_data or data
|
local update_source = update_data or data
|
||||||
local updates = tbl.map(tbl.keys(update_source), function(key)
|
local updates = tbl.map(tbl.keys(update_source), function(key)
|
||||||
return string.template("${key} = VALUES(${key})", {key = key})
|
return str.template("${key} = VALUES(${key})", {key = key})
|
||||||
end)
|
end)
|
||||||
|
|
||||||
local query = string.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON DUPLICATE KEY UPDATE ${updates}", {
|
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON DUPLICATE KEY UPDATE ${updates}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", "),
|
placeholders = tbl.concat(placeholders, ", "),
|
||||||
@ -177,7 +178,7 @@ function Connection:upsert(table_name, data, update_data)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:replace(table_name, data)
|
function Connection:replace(table_name, data)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -185,7 +186,7 @@ function Connection:replace(table_name, data)
|
|||||||
local values = tbl.values(data)
|
local values = tbl.values(data)
|
||||||
local placeholders = tbl.map(keys, function() return "?" end)
|
local placeholders = tbl.map(keys, function() return "?" end)
|
||||||
|
|
||||||
local query = string.template("REPLACE INTO ${table} (${columns}) VALUES (${placeholders})", {
|
local query = str.template("REPLACE INTO ${table} (${columns}) VALUES (${placeholders})", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", ")
|
placeholders = tbl.concat(placeholders, ", ")
|
||||||
@ -195,20 +196,20 @@ function Connection:replace(table_name, data)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:update(table_name, data, where_clause, ...)
|
function Connection:update(table_name, data, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
if string.is_blank(where_clause) then
|
if str.is_blank(where_clause) then
|
||||||
error("WHERE clause cannot be empty for UPDATE")
|
error("WHERE clause cannot be empty for UPDATE")
|
||||||
end
|
end
|
||||||
|
|
||||||
local keys = tbl.keys(data)
|
local keys = tbl.keys(data)
|
||||||
local values = tbl.values(data)
|
local values = tbl.values(data)
|
||||||
local sets = tbl.map(keys, function(key)
|
local sets = tbl.map(keys, function(key)
|
||||||
return string.template("${key} = ?", {key = key})
|
return str.template("${key} = ?", {key = key})
|
||||||
end)
|
end)
|
||||||
|
|
||||||
local query = string.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
sets = tbl.concat(sets, ", "),
|
sets = tbl.concat(sets, ", "),
|
||||||
where = where_clause
|
where = where_clause
|
||||||
@ -222,14 +223,14 @@ function Connection:update(table_name, data, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:delete(table_name, where_clause, ...)
|
function Connection:delete(table_name, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
if string.is_blank(where_clause) then
|
if str.is_blank(where_clause) then
|
||||||
error("WHERE clause cannot be empty for DELETE")
|
error("WHERE clause cannot be empty for DELETE")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DELETE FROM ${table} WHERE ${where}", {
|
local query = str.template("DELETE FROM ${table} WHERE ${where}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
where = where_clause
|
where = where_clause
|
||||||
})
|
})
|
||||||
@ -237,7 +238,7 @@ function Connection:delete(table_name, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:select(table_name, columns, where_clause, ...)
|
function Connection:select(table_name, columns, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -247,15 +248,15 @@ function Connection:select(table_name, columns, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local query
|
local query
|
||||||
if where_clause and not string.is_blank(where_clause) then
|
if where_clause and not str.is_blank(where_clause) then
|
||||||
query = string.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
||||||
columns = columns,
|
columns = columns,
|
||||||
table = table_name,
|
table = table_name,
|
||||||
where = where_clause
|
where = where_clause
|
||||||
})
|
})
|
||||||
return self:query(query, ...)
|
return self:query(query, ...)
|
||||||
else
|
else
|
||||||
query = string.template("SELECT ${columns} FROM ${table}", {
|
query = str.template("SELECT ${columns} FROM ${table}", {
|
||||||
columns = columns,
|
columns = columns,
|
||||||
table = table_name
|
table = table_name
|
||||||
})
|
})
|
||||||
@ -265,19 +266,19 @@ end
|
|||||||
|
|
||||||
-- MySQL schema helpers
|
-- MySQL schema helpers
|
||||||
function Connection:database_exists(database_name)
|
function Connection:database_exists(database_name)
|
||||||
if string.is_blank(database_name) then
|
if str.is_blank(database_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
local result = self:query_value(
|
local result = self:query_value(
|
||||||
"SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?",
|
"SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?",
|
||||||
string.trim(database_name)
|
str.trim(database_name)
|
||||||
)
|
)
|
||||||
return result ~= nil
|
return result ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:table_exists(table_name, database_name)
|
function Connection:table_exists(table_name, database_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -288,13 +289,13 @@ function Connection:table_exists(table_name, database_name)
|
|||||||
|
|
||||||
local result = self:query_value(
|
local result = self:query_value(
|
||||||
"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
|
"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
|
||||||
string.trim(database_name), string.trim(table_name)
|
str.trim(database_name), str.trim(table_name)
|
||||||
)
|
)
|
||||||
return result ~= nil
|
return result ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:column_exists(table_name, column_name, database_name)
|
function Connection:column_exists(table_name, column_name, database_name)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_name) then
|
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -306,19 +307,19 @@ function Connection:column_exists(table_name, column_name, database_name)
|
|||||||
local result = self:query_value([[
|
local result = self:query_value([[
|
||||||
SELECT COLUMN_NAME FROM information_schema.COLUMNS
|
SELECT COLUMN_NAME FROM information_schema.COLUMNS
|
||||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?
|
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?
|
||||||
]], string.trim(database_name), string.trim(table_name), string.trim(column_name))
|
]], str.trim(database_name), str.trim(table_name), str.trim(column_name))
|
||||||
return result ~= nil
|
return result ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_database(database_name, charset, collation)
|
function Connection:create_database(database_name, charset, collation)
|
||||||
if string.is_blank(database_name) then
|
if str.is_blank(database_name) then
|
||||||
error("Database name cannot be empty")
|
error("Database name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local charset_clause = charset and string.template(" CHARACTER SET ${charset}", {charset = charset}) or ""
|
local charset_clause = charset and str.template(" CHARACTER SET ${charset}", {charset = charset}) or ""
|
||||||
local collation_clause = collation and string.template(" COLLATE ${collation}", {collation = collation}) or ""
|
local collation_clause = collation and str.template(" COLLATE ${collation}", {collation = collation}) or ""
|
||||||
|
|
||||||
local query = string.template("CREATE DATABASE IF NOT EXISTS ${database}${charset}${collation}", {
|
local query = str.template("CREATE DATABASE IF NOT EXISTS ${database}${charset}${collation}", {
|
||||||
database = database_name,
|
database = database_name,
|
||||||
charset = charset_clause,
|
charset = charset_clause,
|
||||||
collation = collation_clause
|
collation = collation_clause
|
||||||
@ -327,25 +328,25 @@ function Connection:create_database(database_name, charset, collation)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_database(database_name)
|
function Connection:drop_database(database_name)
|
||||||
if string.is_blank(database_name) then
|
if str.is_blank(database_name) then
|
||||||
error("Database name cannot be empty")
|
error("Database name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DROP DATABASE IF EXISTS ${database}", {database = database_name})
|
local query = str.template("DROP DATABASE IF EXISTS ${database}", {database = database_name})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_table(table_name, schema, engine, charset)
|
function Connection:create_table(table_name, schema, engine, charset)
|
||||||
if string.is_blank(table_name) or string.is_blank(schema) then
|
if str.is_blank(table_name) or str.is_blank(schema) then
|
||||||
error("Table name and schema cannot be empty")
|
error("Table name and schema cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local engine_clause = engine and string.template(" ENGINE=${engine}", {engine = string.upper(engine)}) or ""
|
local engine_clause = engine and str.template(" ENGINE=${engine}", {engine = str.upper(engine)}) or ""
|
||||||
local charset_clause = charset and string.template(" CHARACTER SET ${charset}", {charset = charset}) or ""
|
local charset_clause = charset and str.template(" CHARACTER SET ${charset}", {charset = charset}) or ""
|
||||||
|
|
||||||
local query = string.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})${engine}${charset}", {
|
local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})${engine}${charset}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
schema = string.trim(schema),
|
schema = str.trim(schema),
|
||||||
engine = engine_clause,
|
engine = engine_clause,
|
||||||
charset = charset_clause
|
charset = charset_clause
|
||||||
})
|
})
|
||||||
@ -353,34 +354,34 @@ function Connection:create_table(table_name, schema, engine, charset)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_table(table_name)
|
function Connection:drop_table(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DROP TABLE IF EXISTS ${table}", {table = table_name})
|
local query = str.template("DROP TABLE IF EXISTS ${table}", {table = table_name})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:add_column(table_name, column_def, position)
|
function Connection:add_column(table_name, column_def, position)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_def) then
|
if str.is_blank(table_name) or str.is_blank(column_def) then
|
||||||
error("Table name and column definition cannot be empty")
|
error("Table name and column definition cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local position_clause = position and string.template(" ${position}", {position = position}) or ""
|
local position_clause = position and str.template(" ${position}", {position = position}) or ""
|
||||||
local query = string.template("ALTER TABLE ${table} ADD COLUMN ${column}${position}", {
|
local query = str.template("ALTER TABLE ${table} ADD COLUMN ${column}${position}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
column = string.trim(column_def),
|
column = str.trim(column_def),
|
||||||
position = position_clause
|
position = position_clause
|
||||||
})
|
})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_column(table_name, column_name)
|
function Connection:drop_column(table_name, column_name)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_name) then
|
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||||
error("Table name and column name cannot be empty")
|
error("Table name and column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("ALTER TABLE ${table} DROP COLUMN ${column}", {
|
local query = str.template("ALTER TABLE ${table} DROP COLUMN ${column}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
column = column_name
|
column = column_name
|
||||||
})
|
})
|
||||||
@ -388,23 +389,23 @@ function Connection:drop_column(table_name, column_name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:modify_column(table_name, column_def)
|
function Connection:modify_column(table_name, column_def)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_def) then
|
if str.is_blank(table_name) or str.is_blank(column_def) then
|
||||||
error("Table name and column definition cannot be empty")
|
error("Table name and column definition cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("ALTER TABLE ${table} MODIFY COLUMN ${column}", {
|
local query = str.template("ALTER TABLE ${table} MODIFY COLUMN ${column}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
column = string.trim(column_def)
|
column = str.trim(column_def)
|
||||||
})
|
})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:rename_table(old_name, new_name)
|
function Connection:rename_table(old_name, new_name)
|
||||||
if string.is_blank(old_name) or string.is_blank(new_name) then
|
if str.is_blank(old_name) or str.is_blank(new_name) then
|
||||||
error("Old and new table names cannot be empty")
|
error("Old and new table names cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("RENAME TABLE ${old} TO ${new}", {
|
local query = str.template("RENAME TABLE ${old} TO ${new}", {
|
||||||
old = old_name,
|
old = old_name,
|
||||||
new = new_name
|
new = new_name
|
||||||
})
|
})
|
||||||
@ -412,15 +413,15 @@ function Connection:rename_table(old_name, new_name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_index(index_name, table_name, columns, unique, type)
|
function Connection:create_index(index_name, table_name, columns, unique, type)
|
||||||
if string.is_blank(index_name) or string.is_blank(table_name) then
|
if str.is_blank(index_name) or str.is_blank(table_name) then
|
||||||
error("Index name and table name cannot be empty")
|
error("Index name and table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local unique_clause = unique and "UNIQUE " or ""
|
local unique_clause = unique and "UNIQUE " or ""
|
||||||
local type_clause = type and string.template(" USING ${type}", {type = string.upper(type)}) or ""
|
local type_clause = type and str.template(" USING ${type}", {type = str.upper(type)}) or ""
|
||||||
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
|
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
|
||||||
|
|
||||||
local query = string.template("CREATE ${unique}INDEX ${index} ON ${table} (${columns})${type}", {
|
local query = str.template("CREATE ${unique}INDEX ${index} ON ${table} (${columns})${type}", {
|
||||||
unique = unique_clause,
|
unique = unique_clause,
|
||||||
index = index_name,
|
index = index_name,
|
||||||
table = table_name,
|
table = table_name,
|
||||||
@ -431,11 +432,11 @@ function Connection:create_index(index_name, table_name, columns, unique, type)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_index(index_name, table_name)
|
function Connection:drop_index(index_name, table_name)
|
||||||
if string.is_blank(index_name) or string.is_blank(table_name) then
|
if str.is_blank(index_name) or str.is_blank(table_name) then
|
||||||
error("Index name and table name cannot be empty")
|
error("Index name and table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DROP INDEX ${index} ON ${table}", {
|
local query = str.template("DROP INDEX ${index} ON ${table}", {
|
||||||
index = index_name,
|
index = index_name,
|
||||||
table = table_name
|
table = table_name
|
||||||
})
|
})
|
||||||
@ -444,51 +445,51 @@ end
|
|||||||
|
|
||||||
-- MySQL maintenance functions
|
-- MySQL maintenance functions
|
||||||
function Connection:optimize(table_name)
|
function Connection:optimize(table_name)
|
||||||
local table_clause = table_name and string.template(" ${table}", {table = table_name}) or ""
|
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
|
||||||
return self:query(string.template("OPTIMIZE TABLE${table}", {table = table_clause}))
|
return self:query(str.template("OPTIMIZE TABLE${table}", {table = table_clause}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:repair(table_name)
|
function Connection:repair(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty for REPAIR")
|
error("Table name cannot be empty for REPAIR")
|
||||||
end
|
end
|
||||||
return self:query(string.template("REPAIR TABLE ${table}", {table = table_name}))
|
return self:query(str.template("REPAIR TABLE ${table}", {table = table_name}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:check_table(table_name, options)
|
function Connection:check_table(table_name, options)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty for CHECK")
|
error("Table name cannot be empty for CHECK")
|
||||||
end
|
end
|
||||||
|
|
||||||
local options_clause = ""
|
local options_clause = ""
|
||||||
if options then
|
if options then
|
||||||
local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"}
|
local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"}
|
||||||
local options_upper = string.upper(options)
|
local options_upper = str.upper(options)
|
||||||
|
|
||||||
if tbl.contains(valid_options, options_upper) then
|
if tbl.contains(valid_options, options_upper) then
|
||||||
options_clause = string.template(" ${options}", {options = options_upper})
|
options_clause = str.template(" ${options}", {options = options_upper})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:query(string.template("CHECK TABLE ${table}${options}", {
|
return self:query(str.template("CHECK TABLE ${table}${options}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
options = options_clause
|
options = options_clause
|
||||||
}))
|
}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:analyze_table(table_name)
|
function Connection:analyze_table(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty for ANALYZE")
|
error("Table name cannot be empty for ANALYZE")
|
||||||
end
|
end
|
||||||
return self:query(string.template("ANALYZE TABLE ${table}", {table = table_name}))
|
return self:query(str.template("ANALYZE TABLE ${table}", {table = table_name}))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- MySQL settings and introspection
|
-- MySQL settings and introspection
|
||||||
function Connection:show(what)
|
function Connection:show(what)
|
||||||
if string.is_blank(what) then
|
if str.is_blank(what) then
|
||||||
error("SHOW parameter cannot be empty")
|
error("SHOW parameter cannot be empty")
|
||||||
end
|
end
|
||||||
return self:query(string.template("SHOW ${what}", {what = string.upper(what)}))
|
return self:query(str.template("SHOW ${what}", {what = str.upper(what)}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:current_database()
|
function Connection:current_database()
|
||||||
@ -508,36 +509,36 @@ function Connection:list_databases()
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:list_tables(database_name)
|
function Connection:list_tables(database_name)
|
||||||
if database_name and not string.is_blank(database_name) then
|
if database_name and not str.is_blank(database_name) then
|
||||||
return self:query(string.template("SHOW TABLES FROM ${database}", {database = database_name}))
|
return self:query(str.template("SHOW TABLES FROM ${database}", {database = database_name}))
|
||||||
else
|
else
|
||||||
return self:query("SHOW TABLES")
|
return self:query("SHOW TABLES")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:describe_table(table_name)
|
function Connection:describe_table(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
return self:query(string.template("DESCRIBE ${table}", {table = table_name}))
|
return self:query(str.template("DESCRIBE ${table}", {table = table_name}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:show_create_table(table_name)
|
function Connection:show_create_table(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
return self:query(string.template("SHOW CREATE TABLE ${table}", {table = table_name}))
|
return self:query(str.template("SHOW CREATE TABLE ${table}", {table = table_name}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:show_indexes(table_name)
|
function Connection:show_indexes(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
return self:query(string.template("SHOW INDEXES FROM ${table}", {table = table_name}))
|
return self:query(str.template("SHOW INDEXES FROM ${table}", {table = table_name}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:show_table_status(table_name)
|
function Connection:show_table_status(table_name)
|
||||||
if table_name and not string.is_blank(table_name) then
|
if table_name and not str.is_blank(table_name) then
|
||||||
return self:query("SHOW TABLE STATUS LIKE ?", table_name)
|
return self:query("SHOW TABLE STATUS LIKE ?", table_name)
|
||||||
else
|
else
|
||||||
return self:query("SHOW TABLE STATUS")
|
return self:query("SHOW TABLE STATUS")
|
||||||
@ -546,12 +547,12 @@ end
|
|||||||
|
|
||||||
-- MySQL user and privilege management
|
-- MySQL user and privilege management
|
||||||
function Connection:create_user(username, password, host)
|
function Connection:create_user(username, password, host)
|
||||||
if string.is_blank(username) or string.is_blank(password) then
|
if str.is_blank(username) or str.is_blank(password) then
|
||||||
error("Username and password cannot be empty")
|
error("Username and password cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
host = host or "%"
|
host = host or "%"
|
||||||
local query = string.template("CREATE USER '${username}'@'${host}' IDENTIFIED BY ?", {
|
local query = str.template("CREATE USER '${username}'@'${host}' IDENTIFIED BY ?", {
|
||||||
username = username,
|
username = username,
|
||||||
host = host
|
host = host
|
||||||
})
|
})
|
||||||
@ -559,12 +560,12 @@ function Connection:create_user(username, password, host)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_user(username, host)
|
function Connection:drop_user(username, host)
|
||||||
if string.is_blank(username) then
|
if str.is_blank(username) then
|
||||||
error("Username cannot be empty")
|
error("Username cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
host = host or "%"
|
host = host or "%"
|
||||||
local query = string.template("DROP USER IF EXISTS '${username}'@'${host}'", {
|
local query = str.template("DROP USER IF EXISTS '${username}'@'${host}'", {
|
||||||
username = username,
|
username = username,
|
||||||
host = host
|
host = host
|
||||||
})
|
})
|
||||||
@ -572,16 +573,16 @@ function Connection:drop_user(username, host)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:grant(privileges, database, table_name, username, host)
|
function Connection:grant(privileges, database, table_name, username, host)
|
||||||
if string.is_blank(privileges) or string.is_blank(database) or string.is_blank(username) then
|
if str.is_blank(privileges) or str.is_blank(database) or str.is_blank(username) then
|
||||||
error("Privileges, database, and username cannot be empty")
|
error("Privileges, database, and username cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
host = host or "%"
|
host = host or "%"
|
||||||
table_name = table_name or "*"
|
table_name = table_name or "*"
|
||||||
local object = string.template("${database}.${table}", {database = database, table = table_name})
|
local object = str.template("${database}.${table}", {database = database, table = table_name})
|
||||||
|
|
||||||
local query = string.template("GRANT ${privileges} ON ${object} TO '${username}'@'${host}'", {
|
local query = str.template("GRANT ${privileges} ON ${object} TO '${username}'@'${host}'", {
|
||||||
privileges = string.upper(privileges),
|
privileges = str.upper(privileges),
|
||||||
object = object,
|
object = object,
|
||||||
username = username,
|
username = username,
|
||||||
host = host
|
host = host
|
||||||
@ -590,16 +591,16 @@ function Connection:grant(privileges, database, table_name, username, host)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:revoke(privileges, database, table_name, username, host)
|
function Connection:revoke(privileges, database, table_name, username, host)
|
||||||
if string.is_blank(privileges) or string.is_blank(database) or string.is_blank(username) then
|
if str.is_blank(privileges) or str.is_blank(database) or str.is_blank(username) then
|
||||||
error("Privileges, database, and username cannot be empty")
|
error("Privileges, database, and username cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
host = host or "%"
|
host = host or "%"
|
||||||
table_name = table_name or "*"
|
table_name = table_name or "*"
|
||||||
local object = string.template("${database}.${table}", {database = database, table = table_name})
|
local object = str.template("${database}.${table}", {database = database, table = table_name})
|
||||||
|
|
||||||
local query = string.template("REVOKE ${privileges} ON ${object} FROM '${username}'@'${host}'", {
|
local query = str.template("REVOKE ${privileges} ON ${object} FROM '${username}'@'${host}'", {
|
||||||
privileges = string.upper(privileges),
|
privileges = str.upper(privileges),
|
||||||
object = object,
|
object = object,
|
||||||
username = username,
|
username = username,
|
||||||
host = host
|
host = host
|
||||||
@ -613,31 +614,31 @@ end
|
|||||||
|
|
||||||
-- MySQL variables and configuration
|
-- MySQL variables and configuration
|
||||||
function Connection:set_variable(name, value, global)
|
function Connection:set_variable(name, value, global)
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Variable name cannot be empty")
|
error("Variable name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local scope = global and "GLOBAL " or "SESSION "
|
local scope = global and "GLOBAL " or "SESSION "
|
||||||
return self:exec(string.template("SET ${scope}${name} = ?", {
|
return self:exec(str.template("SET ${scope}${name} = ?", {
|
||||||
scope = scope,
|
scope = scope,
|
||||||
name = name
|
name = name
|
||||||
}), value)
|
}), value)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:get_variable(name, global)
|
function Connection:get_variable(name, global)
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Variable name cannot be empty")
|
error("Variable name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local scope = global and "global." or "session."
|
local scope = global and "global." or "session."
|
||||||
return self:query_value(string.template("SELECT @@${scope}${name}", {
|
return self:query_value(str.template("SELECT @@${scope}${name}", {
|
||||||
scope = scope,
|
scope = scope,
|
||||||
name = name
|
name = name
|
||||||
}))
|
}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:show_variables(pattern)
|
function Connection:show_variables(pattern)
|
||||||
if pattern and not string.is_blank(pattern) then
|
if pattern and not str.is_blank(pattern) then
|
||||||
return self:query("SHOW VARIABLES LIKE ?", pattern)
|
return self:query("SHOW VARIABLES LIKE ?", pattern)
|
||||||
else
|
else
|
||||||
return self:query("SHOW VARIABLES")
|
return self:query("SHOW VARIABLES")
|
||||||
@ -645,7 +646,7 @@ function Connection:show_variables(pattern)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:show_status(pattern)
|
function Connection:show_status(pattern)
|
||||||
if pattern and not string.is_blank(pattern) then
|
if pattern and not str.is_blank(pattern) then
|
||||||
return self:query("SHOW STATUS LIKE ?", pattern)
|
return self:query("SHOW STATUS LIKE ?", pattern)
|
||||||
else
|
else
|
||||||
return self:query("SHOW STATUS")
|
return self:query("SHOW STATUS")
|
||||||
@ -654,11 +655,11 @@ end
|
|||||||
|
|
||||||
-- Connection management
|
-- Connection management
|
||||||
function mysql.connect(dsn)
|
function mysql.connect(dsn)
|
||||||
if string.is_blank(dsn) then
|
if str.is_blank(dsn) then
|
||||||
error("DSN cannot be empty")
|
error("DSN cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local conn_id = moonshark.sql_connect("mysql", string.trim(dsn))
|
local conn_id = moonshark.sql_connect("mysql", str.trim(dsn))
|
||||||
if conn_id then
|
if conn_id then
|
||||||
local conn = {_id = conn_id}
|
local conn = {_id = conn_id}
|
||||||
setmetatable(conn, Connection)
|
setmetatable(conn, Connection)
|
||||||
@ -718,8 +719,8 @@ function mysql.migrate(dsn, migrations, database_name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
-- Use specified database if provided
|
-- Use specified database if provided
|
||||||
if database_name and not string.is_blank(database_name) then
|
if database_name and not str.is_blank(database_name) then
|
||||||
conn:exec(string.template("USE ${database}", {database = database_name}))
|
conn:exec(str.template("USE ${database}", {database = database_name}))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Create migrations table
|
-- Create migrations table
|
||||||
@ -736,7 +737,7 @@ function mysql.migrate(dsn, migrations, database_name)
|
|||||||
local error_msg = ""
|
local error_msg = ""
|
||||||
|
|
||||||
for _, migration in ipairs(migrations) do
|
for _, migration in ipairs(migrations) do
|
||||||
if not migration.name or string.is_blank(migration.name) then
|
if not migration.name or str.is_blank(migration.name) then
|
||||||
error_msg = "Migration must have a non-empty name"
|
error_msg = "Migration must have a non-empty name"
|
||||||
success = false
|
success = false
|
||||||
break
|
break
|
||||||
@ -744,7 +745,7 @@ function mysql.migrate(dsn, migrations, database_name)
|
|||||||
|
|
||||||
-- Check if migration already applied
|
-- Check if migration already applied
|
||||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
|
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
|
||||||
string.trim(migration.name))
|
str.trim(migration.name))
|
||||||
if not existing then
|
if not existing then
|
||||||
local ok, err = pcall(function()
|
local ok, err = pcall(function()
|
||||||
if type(migration.up) == "string" then
|
if type(migration.up) == "string" then
|
||||||
@ -757,11 +758,11 @@ function mysql.migrate(dsn, migrations, database_name)
|
|||||||
end)
|
end)
|
||||||
|
|
||||||
if ok then
|
if ok then
|
||||||
conn:exec("INSERT INTO _migrations (name) VALUES (?)", string.trim(migration.name))
|
conn:exec("INSERT INTO _migrations (name) VALUES (?)", str.trim(migration.name))
|
||||||
print(string.template("Applied migration: ${name}", {name = migration.name}))
|
print(str.template("Applied migration: ${name}", {name = migration.name}))
|
||||||
else
|
else
|
||||||
success = false
|
success = false
|
||||||
error_msg = string.template("Migration '${name}' failed: ${error}", {
|
error_msg = str.template("Migration '${name}' failed: ${error}", {
|
||||||
name = migration.name,
|
name = migration.name,
|
||||||
error = err or "unknown error"
|
error = err or "unknown error"
|
||||||
})
|
})
|
||||||
@ -788,7 +789,7 @@ function mysql.to_array(results, column_name)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(column_name) then
|
if str.is_blank(column_name) then
|
||||||
error("Column name cannot be empty")
|
error("Column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -800,7 +801,7 @@ function mysql.to_map(results, key_column, value_column)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(key_column) then
|
if str.is_blank(key_column) then
|
||||||
error("Key column name cannot be empty")
|
error("Key column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -817,7 +818,7 @@ function mysql.group_by(results, column_name)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(column_name) then
|
if str.is_blank(column_name) then
|
||||||
error("Column name cannot be empty")
|
error("Column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -837,19 +838,19 @@ function mysql.print_results(results)
|
|||||||
-- Calculate column widths
|
-- Calculate column widths
|
||||||
local widths = {}
|
local widths = {}
|
||||||
for _, col in ipairs(columns) do
|
for _, col in ipairs(columns) do
|
||||||
widths[col] = string.length(col)
|
widths[col] = str.length(col)
|
||||||
end
|
end
|
||||||
|
|
||||||
for _, row in ipairs(results) do
|
for _, row in ipairs(results) do
|
||||||
for _, col in ipairs(columns) do
|
for _, col in ipairs(columns) do
|
||||||
local value = tostring(row[col] or "")
|
local value = tostring(row[col] or "")
|
||||||
widths[col] = math.max(widths[col], string.length(value))
|
widths[col] = math.max(widths[col], str.length(value))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Print header
|
-- Print header
|
||||||
local header_parts = tbl.map(columns, function(col) return string.pad_right(col, widths[col]) end)
|
local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
|
||||||
local separator_parts = tbl.map(columns, function(col) return string.repeat_("-", widths[col]) end)
|
local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
|
||||||
|
|
||||||
print(tbl.concat(header_parts, " | "))
|
print(tbl.concat(header_parts, " | "))
|
||||||
print(tbl.concat(separator_parts, "-+-"))
|
print(tbl.concat(separator_parts, "-+-"))
|
||||||
@ -858,7 +859,7 @@ function mysql.print_results(results)
|
|||||||
for _, row in ipairs(results) do
|
for _, row in ipairs(results) do
|
||||||
local value_parts = tbl.map(columns, function(col)
|
local value_parts = tbl.map(columns, function(col)
|
||||||
local value = tostring(row[col] or "")
|
local value = tostring(row[col] or "")
|
||||||
return string.pad_right(value, widths[col])
|
return str.pad_right(value, widths[col])
|
||||||
end)
|
end)
|
||||||
print(tbl.concat(value_parts, " | "))
|
print(tbl.concat(value_parts, " | "))
|
||||||
end
|
end
|
||||||
@ -869,14 +870,14 @@ function mysql.escape_string(str_val)
|
|||||||
if type(str_val) ~= "string" then
|
if type(str_val) ~= "string" then
|
||||||
return tostring(str_val)
|
return tostring(str_val)
|
||||||
end
|
end
|
||||||
return string.replace(str_val, "'", "\\'")
|
return str.replace(str_val, "'", "\\'")
|
||||||
end
|
end
|
||||||
|
|
||||||
function mysql.escape_identifier(name)
|
function mysql.escape_identifier(name)
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Identifier name cannot be empty")
|
error("Identifier name cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("`${name}`", {name = string.replace(name, "`", "``")})
|
return str.template("`${name}`", {name = str.replace(name, "`", "``")})
|
||||||
end
|
end
|
||||||
|
|
||||||
-- DSN builder helper
|
-- DSN builder helper
|
||||||
@ -887,10 +888,10 @@ function mysql.build_dsn(options)
|
|||||||
|
|
||||||
local parts = {}
|
local parts = {}
|
||||||
|
|
||||||
if options.username and not string.is_blank(options.username) then
|
if options.username and not str.is_blank(options.username) then
|
||||||
tbl.insert(parts, options.username)
|
tbl.insert(parts, options.username)
|
||||||
if options.password and not string.is_blank(options.password) then
|
if options.password and not str.is_blank(options.password) then
|
||||||
parts[#parts] = string.template("${user}:${pass}", {
|
parts[#parts] = str.template("${user}:${pass}", {
|
||||||
user = parts[#parts],
|
user = parts[#parts],
|
||||||
pass = options.password
|
pass = options.password
|
||||||
})
|
})
|
||||||
@ -898,22 +899,22 @@ function mysql.build_dsn(options)
|
|||||||
parts[#parts] = parts[#parts] .. "@"
|
parts[#parts] = parts[#parts] .. "@"
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.protocol and not string.is_blank(options.protocol) then
|
if options.protocol and not str.is_blank(options.protocol) then
|
||||||
tbl.insert(parts, string.template("${protocol}(", {protocol = options.protocol}))
|
tbl.insert(parts, str.template("${protocol}(", {protocol = options.protocol}))
|
||||||
if options.host and not string.is_blank(options.host) then
|
if options.host and not str.is_blank(options.host) then
|
||||||
tbl.insert(parts, options.host)
|
tbl.insert(parts, options.host)
|
||||||
if options.port then
|
if options.port then
|
||||||
parts[#parts] = string.template("${host}:${port}", {
|
parts[#parts] = str.template("${host}:${port}", {
|
||||||
host = parts[#parts],
|
host = parts[#parts],
|
||||||
port = tostring(options.port)
|
port = tostring(options.port)
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
parts[#parts] = parts[#parts] .. ")"
|
parts[#parts] = parts[#parts] .. ")"
|
||||||
elseif options.host and not string.is_blank(options.host) then
|
elseif options.host and not str.is_blank(options.host) then
|
||||||
local host_part = string.template("tcp(${host}", {host = options.host})
|
local host_part = str.template("tcp(${host}", {host = options.host})
|
||||||
if options.port then
|
if options.port then
|
||||||
host_part = string.template("${host}:${port}", {
|
host_part = str.template("${host}:${port}", {
|
||||||
host = host_part,
|
host = host_part,
|
||||||
port = tostring(options.port)
|
port = tostring(options.port)
|
||||||
})
|
})
|
||||||
@ -921,27 +922,27 @@ function mysql.build_dsn(options)
|
|||||||
tbl.insert(parts, host_part .. ")")
|
tbl.insert(parts, host_part .. ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
if options.database and not string.is_blank(options.database) then
|
if options.database and not str.is_blank(options.database) then
|
||||||
tbl.insert(parts, string.template("/${database}", {database = options.database}))
|
tbl.insert(parts, str.template("/${database}", {database = options.database}))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Add parameters
|
-- Add parameters
|
||||||
local params = {}
|
local params = {}
|
||||||
if options.charset and not string.is_blank(options.charset) then
|
if options.charset and not str.is_blank(options.charset) then
|
||||||
tbl.insert(params, string.template("charset=${charset}", {charset = options.charset}))
|
tbl.insert(params, str.template("charset=${charset}", {charset = options.charset}))
|
||||||
end
|
end
|
||||||
if options.parseTime ~= nil then
|
if options.parseTime ~= nil then
|
||||||
tbl.insert(params, string.template("parseTime=${parse}", {parse = tostring(options.parseTime)}))
|
tbl.insert(params, str.template("parseTime=${parse}", {parse = tostring(options.parseTime)}))
|
||||||
end
|
end
|
||||||
if options.timeout and not string.is_blank(options.timeout) then
|
if options.timeout and not str.is_blank(options.timeout) then
|
||||||
tbl.insert(params, string.template("timeout=${timeout}", {timeout = options.timeout}))
|
tbl.insert(params, str.template("timeout=${timeout}", {timeout = options.timeout}))
|
||||||
end
|
end
|
||||||
if options.tls and not string.is_blank(options.tls) then
|
if options.tls and not str.is_blank(options.tls) then
|
||||||
tbl.insert(params, string.template("tls=${tls}", {tls = options.tls}))
|
tbl.insert(params, str.template("tls=${tls}", {tls = options.tls}))
|
||||||
end
|
end
|
||||||
|
|
||||||
if #params > 0 then
|
if #params > 0 then
|
||||||
tbl.insert(parts, string.template("?${params}", {params = tbl.concat(params, "&")}))
|
tbl.insert(parts, str.template("?${params}", {params = tbl.concat(params, "&")}))
|
||||||
end
|
end
|
||||||
|
|
||||||
return tbl.concat(parts, "")
|
return tbl.concat(parts, "")
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
local str = require("string")
|
||||||
local tbl = require("table")
|
local tbl = require("table")
|
||||||
local postgres = {}
|
local postgres = {}
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ function Connection:query(query_str, ...)
|
|||||||
if not self._id then
|
if not self._id then
|
||||||
error("Connection is closed")
|
error("Connection is closed")
|
||||||
end
|
end
|
||||||
query_str = string.normalize_whitespace(query_str)
|
query_str = str.normalize_whitespace(query_str)
|
||||||
return moonshark.sql_query(self._id, query_str, ...)
|
return moonshark.sql_query(self._id, query_str, ...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ function Connection:exec(query_str, ...)
|
|||||||
if not self._id then
|
if not self._id then
|
||||||
error("Connection is closed")
|
error("Connection is closed")
|
||||||
end
|
end
|
||||||
query_str = string.normalize_whitespace(query_str)
|
query_str = str.normalize_whitespace(query_str)
|
||||||
return moonshark.sql_exec(self._id, query_str, ...)
|
return moonshark.sql_exec(self._id, query_str, ...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -84,20 +85,20 @@ function Connection:begin()
|
|||||||
if not tx.active then
|
if not tx.active then
|
||||||
error("Transaction is not active")
|
error("Transaction is not active")
|
||||||
end
|
end
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Savepoint name cannot be empty")
|
error("Savepoint name cannot be empty")
|
||||||
end
|
end
|
||||||
return tx.conn:exec(string.template("SAVEPOINT ${name}", {name = name}))
|
return tx.conn:exec(str.template("SAVEPOINT ${name}", {name = name}))
|
||||||
end,
|
end,
|
||||||
|
|
||||||
rollback_to = function(tx, name)
|
rollback_to = function(tx, name)
|
||||||
if not tx.active then
|
if not tx.active then
|
||||||
error("Transaction is not active")
|
error("Transaction is not active")
|
||||||
end
|
end
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Savepoint name cannot be empty")
|
error("Savepoint name cannot be empty")
|
||||||
end
|
end
|
||||||
return tx.conn:exec(string.template("ROLLBACK TO SAVEPOINT ${name}", {name = name}))
|
return tx.conn:exec(str.template("ROLLBACK TO SAVEPOINT ${name}", {name = name}))
|
||||||
end,
|
end,
|
||||||
|
|
||||||
query = function(tx, query_str, ...)
|
query = function(tx, query_str, ...)
|
||||||
@ -139,7 +140,7 @@ local function build_postgres_params(data)
|
|||||||
local placeholders = {}
|
local placeholders = {}
|
||||||
|
|
||||||
for i = 1, #keys do
|
for i = 1, #keys do
|
||||||
tbl.insert(placeholders, string.template("$${num}", {num = tostring(i)}))
|
tbl.insert(placeholders, str.template("$${num}", {num = tostring(i)}))
|
||||||
end
|
end
|
||||||
|
|
||||||
return keys, values, placeholders, #keys
|
return keys, values, placeholders, #keys
|
||||||
@ -147,20 +148,20 @@ end
|
|||||||
|
|
||||||
-- Simplified query builders using table utilities
|
-- Simplified query builders using table utilities
|
||||||
function Connection:insert(table_name, data, returning)
|
function Connection:insert(table_name, data, returning)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local keys, values, placeholders = build_postgres_params(data)
|
local keys, values, placeholders = build_postgres_params(data)
|
||||||
|
|
||||||
local query = string.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", ")
|
placeholders = tbl.concat(placeholders, ", ")
|
||||||
})
|
})
|
||||||
|
|
||||||
if returning and not string.is_blank(returning) then
|
if returning and not str.is_blank(returning) then
|
||||||
query = string.template("${query} RETURNING ${returning}", {
|
query = str.template("${query} RETURNING ${returning}", {
|
||||||
query = query,
|
query = query,
|
||||||
returning = returning
|
returning = returning
|
||||||
})
|
})
|
||||||
@ -171,25 +172,25 @@ function Connection:insert(table_name, data, returning)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:upsert(table_name, data, conflict_columns, returning)
|
function Connection:upsert(table_name, data, conflict_columns, returning)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local keys, values, placeholders = build_postgres_params(data)
|
local keys, values, placeholders = build_postgres_params(data)
|
||||||
local updates = tbl.map(keys, function(key)
|
local updates = tbl.map(keys, function(key)
|
||||||
return string.template("${key} = EXCLUDED.${key}", {key = key})
|
return str.template("${key} = EXCLUDED.${key}", {key = key})
|
||||||
end)
|
end)
|
||||||
|
|
||||||
local conflict_clause = ""
|
local conflict_clause = ""
|
||||||
if conflict_columns then
|
if conflict_columns then
|
||||||
if type(conflict_columns) == "string" then
|
if type(conflict_columns) == "string" then
|
||||||
conflict_clause = string.template("(${columns})", {columns = conflict_columns})
|
conflict_clause = str.template("(${columns})", {columns = conflict_columns})
|
||||||
else
|
else
|
||||||
conflict_clause = string.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
|
conflict_clause = str.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
|
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", "),
|
placeholders = tbl.concat(placeholders, ", "),
|
||||||
@ -197,8 +198,8 @@ function Connection:upsert(table_name, data, conflict_columns, returning)
|
|||||||
updates = tbl.concat(updates, ", ")
|
updates = tbl.concat(updates, ", ")
|
||||||
})
|
})
|
||||||
|
|
||||||
if returning and not string.is_blank(returning) then
|
if returning and not str.is_blank(returning) then
|
||||||
query = string.template("${query} RETURNING ${returning}", {
|
query = str.template("${query} RETURNING ${returning}", {
|
||||||
query = query,
|
query = query,
|
||||||
returning = returning
|
returning = returning
|
||||||
})
|
})
|
||||||
@ -209,10 +210,10 @@ function Connection:upsert(table_name, data, conflict_columns, returning)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:update(table_name, data, where_clause, returning, ...)
|
function Connection:update(table_name, data, where_clause, returning, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
if string.is_blank(where_clause) then
|
if str.is_blank(where_clause) then
|
||||||
error("WHERE clause cannot be empty for UPDATE")
|
error("WHERE clause cannot be empty for UPDATE")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -222,7 +223,7 @@ function Connection:update(table_name, data, where_clause, returning, ...)
|
|||||||
|
|
||||||
local sets = {}
|
local sets = {}
|
||||||
for i, key in ipairs(keys) do
|
for i, key in ipairs(keys) do
|
||||||
tbl.insert(sets, string.template("${key} = $${num}", {
|
tbl.insert(sets, str.template("${key} = $${num}", {
|
||||||
key = key,
|
key = key,
|
||||||
num = tostring(i)
|
num = tostring(i)
|
||||||
}))
|
}))
|
||||||
@ -234,18 +235,18 @@ function Connection:update(table_name, data, where_clause, returning, ...)
|
|||||||
for i = 1, #where_args do
|
for i = 1, #where_args do
|
||||||
param_count = param_count + 1
|
param_count = param_count + 1
|
||||||
tbl.insert(values, where_args[i])
|
tbl.insert(values, where_args[i])
|
||||||
where_clause_with_params = string.replace(where_clause_with_params, "?",
|
where_clause_with_params = str.replace(where_clause_with_params, "?",
|
||||||
string.template("$${num}", {num = tostring(param_count)}), 1)
|
str.template("$${num}", {num = tostring(param_count)}), 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
sets = tbl.concat(sets, ", "),
|
sets = tbl.concat(sets, ", "),
|
||||||
where = where_clause_with_params
|
where = where_clause_with_params
|
||||||
})
|
})
|
||||||
|
|
||||||
if returning and not string.is_blank(returning) then
|
if returning and not str.is_blank(returning) then
|
||||||
query = string.template("${query} RETURNING ${returning}", {
|
query = str.template("${query} RETURNING ${returning}", {
|
||||||
query = query,
|
query = query,
|
||||||
returning = returning
|
returning = returning
|
||||||
})
|
})
|
||||||
@ -256,10 +257,10 @@ function Connection:update(table_name, data, where_clause, returning, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:delete(table_name, where_clause, returning, ...)
|
function Connection:delete(table_name, where_clause, returning, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
if string.is_blank(where_clause) then
|
if str.is_blank(where_clause) then
|
||||||
error("WHERE clause cannot be empty for DELETE")
|
error("WHERE clause cannot be empty for DELETE")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -269,17 +270,17 @@ function Connection:delete(table_name, where_clause, returning, ...)
|
|||||||
local where_clause_with_params = where_clause
|
local where_clause_with_params = where_clause
|
||||||
for i = 1, #where_args do
|
for i = 1, #where_args do
|
||||||
tbl.insert(values, where_args[i])
|
tbl.insert(values, where_args[i])
|
||||||
where_clause_with_params = string.replace(where_clause_with_params, "?",
|
where_clause_with_params = str.replace(where_clause_with_params, "?",
|
||||||
string.template("$${num}", {num = tostring(i)}), 1)
|
str.template("$${num}", {num = tostring(i)}), 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DELETE FROM ${table} WHERE ${where}", {
|
local query = str.template("DELETE FROM ${table} WHERE ${where}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
where = where_clause_with_params
|
where = where_clause_with_params
|
||||||
})
|
})
|
||||||
|
|
||||||
if returning and not string.is_blank(returning) then
|
if returning and not str.is_blank(returning) then
|
||||||
query = string.template("${query} RETURNING ${returning}", {
|
query = str.template("${query} RETURNING ${returning}", {
|
||||||
query = query,
|
query = query,
|
||||||
returning = returning
|
returning = returning
|
||||||
})
|
})
|
||||||
@ -290,7 +291,7 @@ function Connection:delete(table_name, where_clause, returning, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:select(table_name, columns, where_clause, ...)
|
function Connection:select(table_name, columns, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -300,25 +301,25 @@ function Connection:select(table_name, columns, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local query
|
local query
|
||||||
if where_clause and not string.is_blank(where_clause) then
|
if where_clause and not str.is_blank(where_clause) then
|
||||||
-- Handle WHERE clause parameters
|
-- Handle WHERE clause parameters
|
||||||
local where_args = {...}
|
local where_args = {...}
|
||||||
local values = {}
|
local values = {}
|
||||||
local where_clause_with_params = where_clause
|
local where_clause_with_params = where_clause
|
||||||
for i = 1, #where_args do
|
for i = 1, #where_args do
|
||||||
tbl.insert(values, where_args[i])
|
tbl.insert(values, where_args[i])
|
||||||
where_clause_with_params = string.replace(where_clause_with_params, "?",
|
where_clause_with_params = str.replace(where_clause_with_params, "?",
|
||||||
string.template("$${num}", {num = tostring(i)}), 1)
|
str.template("$${num}", {num = tostring(i)}), 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
query = string.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
||||||
columns = columns,
|
columns = columns,
|
||||||
table = table_name,
|
table = table_name,
|
||||||
where = where_clause_with_params
|
where = where_clause_with_params
|
||||||
})
|
})
|
||||||
return self:query(query, unpack(values))
|
return self:query(query, unpack(values))
|
||||||
else
|
else
|
||||||
query = string.template("SELECT ${columns} FROM ${table}", {
|
query = str.template("SELECT ${columns} FROM ${table}", {
|
||||||
columns = columns,
|
columns = columns,
|
||||||
table = table_name
|
table = table_name
|
||||||
})
|
})
|
||||||
@ -328,20 +329,20 @@ end
|
|||||||
|
|
||||||
-- Enhanced PostgreSQL schema helpers
|
-- Enhanced PostgreSQL schema helpers
|
||||||
function Connection:table_exists(table_name, schema_name)
|
function Connection:table_exists(table_name, schema_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
schema_name = schema_name or "public"
|
schema_name = schema_name or "public"
|
||||||
local result = self:query_value(
|
local result = self:query_value(
|
||||||
"SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2",
|
"SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2",
|
||||||
string.trim(schema_name), string.trim(table_name)
|
str.trim(schema_name), str.trim(table_name)
|
||||||
)
|
)
|
||||||
return result ~= nil
|
return result ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:column_exists(table_name, column_name, schema_name)
|
function Connection:column_exists(table_name, column_name, schema_name)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_name) then
|
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -349,29 +350,29 @@ function Connection:column_exists(table_name, column_name, schema_name)
|
|||||||
local result = self:query_value([[
|
local result = self:query_value([[
|
||||||
SELECT column_name FROM information_schema.columns
|
SELECT column_name FROM information_schema.columns
|
||||||
WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
|
WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
|
||||||
]], string.trim(schema_name), string.trim(table_name), string.trim(column_name))
|
]], str.trim(schema_name), str.trim(table_name), str.trim(column_name))
|
||||||
return result ~= nil
|
return result ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_table(table_name, schema)
|
function Connection:create_table(table_name, schema)
|
||||||
if string.is_blank(table_name) or string.is_blank(schema) then
|
if str.is_blank(table_name) or str.is_blank(schema) then
|
||||||
error("Table name and schema cannot be empty")
|
error("Table name and schema cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", {
|
local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
schema = string.trim(schema)
|
schema = str.trim(schema)
|
||||||
})
|
})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_table(table_name, cascade)
|
function Connection:drop_table(table_name, cascade)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local cascade_clause = cascade and " CASCADE" or ""
|
local cascade_clause = cascade and " CASCADE" or ""
|
||||||
local query = string.template("DROP TABLE IF EXISTS ${table}${cascade}", {
|
local query = str.template("DROP TABLE IF EXISTS ${table}${cascade}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
cascade = cascade_clause
|
cascade = cascade_clause
|
||||||
})
|
})
|
||||||
@ -379,24 +380,24 @@ function Connection:drop_table(table_name, cascade)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:add_column(table_name, column_def)
|
function Connection:add_column(table_name, column_def)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_def) then
|
if str.is_blank(table_name) or str.is_blank(column_def) then
|
||||||
error("Table name and column definition cannot be empty")
|
error("Table name and column definition cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("ALTER TABLE ${table} ADD COLUMN IF NOT EXISTS ${column}", {
|
local query = str.template("ALTER TABLE ${table} ADD COLUMN IF NOT EXISTS ${column}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
column = string.trim(column_def)
|
column = str.trim(column_def)
|
||||||
})
|
})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_column(table_name, column_name, cascade)
|
function Connection:drop_column(table_name, column_name, cascade)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_name) then
|
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||||
error("Table name and column name cannot be empty")
|
error("Table name and column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local cascade_clause = cascade and " CASCADE" or ""
|
local cascade_clause = cascade and " CASCADE" or ""
|
||||||
local query = string.template("ALTER TABLE ${table} DROP COLUMN IF EXISTS ${column}${cascade}", {
|
local query = str.template("ALTER TABLE ${table} DROP COLUMN IF EXISTS ${column}${cascade}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
column = column_name,
|
column = column_name,
|
||||||
cascade = cascade_clause
|
cascade = cascade_clause
|
||||||
@ -405,15 +406,15 @@ function Connection:drop_column(table_name, column_name, cascade)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_index(index_name, table_name, columns, unique, method)
|
function Connection:create_index(index_name, table_name, columns, unique, method)
|
||||||
if string.is_blank(index_name) or string.is_blank(table_name) then
|
if str.is_blank(index_name) or str.is_blank(table_name) then
|
||||||
error("Index name and table name cannot be empty")
|
error("Index name and table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local unique_clause = unique and "UNIQUE " or ""
|
local unique_clause = unique and "UNIQUE " or ""
|
||||||
local method_clause = method and string.template(" USING ${method}", {method = string.upper(method)}) or ""
|
local method_clause = method and str.template(" USING ${method}", {method = str.upper(method)}) or ""
|
||||||
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
|
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
|
||||||
|
|
||||||
local query = string.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", {
|
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", {
|
||||||
unique = unique_clause,
|
unique = unique_clause,
|
||||||
index = index_name,
|
index = index_name,
|
||||||
table = table_name,
|
table = table_name,
|
||||||
@ -424,12 +425,12 @@ function Connection:create_index(index_name, table_name, columns, unique, method
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_index(index_name, cascade)
|
function Connection:drop_index(index_name, cascade)
|
||||||
if string.is_blank(index_name) then
|
if str.is_blank(index_name) then
|
||||||
error("Index name cannot be empty")
|
error("Index name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local cascade_clause = cascade and " CASCADE" or ""
|
local cascade_clause = cascade and " CASCADE" or ""
|
||||||
local query = string.template("DROP INDEX IF EXISTS ${index}${cascade}", {
|
local query = str.template("DROP INDEX IF EXISTS ${index}${cascade}", {
|
||||||
index = index_name,
|
index = index_name,
|
||||||
cascade = cascade_clause
|
cascade = cascade_clause
|
||||||
})
|
})
|
||||||
@ -439,32 +440,32 @@ end
|
|||||||
-- PostgreSQL-specific functions
|
-- PostgreSQL-specific functions
|
||||||
function Connection:vacuum(table_name, analyze)
|
function Connection:vacuum(table_name, analyze)
|
||||||
local analyze_clause = analyze and " ANALYZE" or ""
|
local analyze_clause = analyze and " ANALYZE" or ""
|
||||||
local table_clause = table_name and string.template(" ${table}", {table = table_name}) or ""
|
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
|
||||||
return self:exec(string.template("VACUUM${analyze}${table}", {
|
return self:exec(str.template("VACUUM${analyze}${table}", {
|
||||||
analyze = analyze_clause,
|
analyze = analyze_clause,
|
||||||
table = table_clause
|
table = table_clause
|
||||||
}))
|
}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:analyze(table_name)
|
function Connection:analyze(table_name)
|
||||||
local table_clause = table_name and string.template(" ${table}", {table = table_name}) or ""
|
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
|
||||||
return self:exec(string.template("ANALYZE${table}", {table = table_clause}))
|
return self:exec(str.template("ANALYZE${table}", {table = table_clause}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:reindex(name, type)
|
function Connection:reindex(name, type)
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Name cannot be empty for REINDEX")
|
error("Name cannot be empty for REINDEX")
|
||||||
end
|
end
|
||||||
|
|
||||||
type = type or "INDEX"
|
type = type or "INDEX"
|
||||||
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
|
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
|
||||||
local type_upper = string.upper(type)
|
local type_upper = str.upper(type)
|
||||||
|
|
||||||
if not tbl.contains(valid_types, type_upper) then
|
if not tbl.contains(valid_types, type_upper) then
|
||||||
error(string.template("Invalid REINDEX type: ${type}", {type = type}))
|
error(str.template("Invalid REINDEX type: ${type}", {type = type}))
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:exec(string.template("REINDEX ${type} ${name}", {
|
return self:exec(str.template("REINDEX ${type} ${name}", {
|
||||||
type = type_upper,
|
type = type_upper,
|
||||||
name = name
|
name = name
|
||||||
}))
|
}))
|
||||||
@ -472,17 +473,17 @@ end
|
|||||||
|
|
||||||
-- PostgreSQL settings and introspection
|
-- PostgreSQL settings and introspection
|
||||||
function Connection:show(setting)
|
function Connection:show(setting)
|
||||||
if string.is_blank(setting) then
|
if str.is_blank(setting) then
|
||||||
error("Setting name cannot be empty")
|
error("Setting name cannot be empty")
|
||||||
end
|
end
|
||||||
return self:query_value(string.template("SHOW ${setting}", {setting = setting}))
|
return self:query_value(str.template("SHOW ${setting}", {setting = setting}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:set(setting, value)
|
function Connection:set(setting, value)
|
||||||
if string.is_blank(setting) then
|
if str.is_blank(setting) then
|
||||||
error("Setting name cannot be empty")
|
error("Setting name cannot be empty")
|
||||||
end
|
end
|
||||||
return self:exec(string.template("SET ${setting} = ${value}", {
|
return self:exec(str.template("SET ${setting} = ${value}", {
|
||||||
setting = setting,
|
setting = setting,
|
||||||
value = tostring(value)
|
value = tostring(value)
|
||||||
}))
|
}))
|
||||||
@ -507,11 +508,11 @@ end
|
|||||||
function Connection:list_tables(schema_name)
|
function Connection:list_tables(schema_name)
|
||||||
schema_name = schema_name or "public"
|
schema_name = schema_name or "public"
|
||||||
return self:query("SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
|
return self:query("SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename",
|
||||||
string.trim(schema_name))
|
str.trim(schema_name))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:describe_table(table_name, schema_name)
|
function Connection:describe_table(table_name, schema_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -521,64 +522,64 @@ function Connection:describe_table(table_name, schema_name)
|
|||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
WHERE table_schema = $1 AND table_name = $2
|
WHERE table_schema = $1 AND table_name = $2
|
||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
]], string.trim(schema_name), string.trim(table_name))
|
]], str.trim(schema_name), str.trim(table_name))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- JSON/JSONB helpers
|
-- JSON/JSONB helpers
|
||||||
function Connection:json_extract(column, path)
|
function Connection:json_extract(column, path)
|
||||||
if string.is_blank(column) or string.is_blank(path) then
|
if str.is_blank(column) or str.is_blank(path) then
|
||||||
error("Column and path cannot be empty")
|
error("Column and path cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("${column}->'${path}'", {column = column, path = path})
|
return str.template("${column}->'${path}'", {column = column, path = path})
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:json_extract_text(column, path)
|
function Connection:json_extract_text(column, path)
|
||||||
if string.is_blank(column) or string.is_blank(path) then
|
if str.is_blank(column) or str.is_blank(path) then
|
||||||
error("Column and path cannot be empty")
|
error("Column and path cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("${column}->>'${path}'", {column = column, path = path})
|
return str.template("${column}->>'${path}'", {column = column, path = path})
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:jsonb_contains(column, value)
|
function Connection:jsonb_contains(column, value)
|
||||||
if string.is_blank(column) or string.is_blank(value) then
|
if str.is_blank(column) or str.is_blank(value) then
|
||||||
error("Column and value cannot be empty")
|
error("Column and value cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("${column} @> '${value}'", {column = column, value = value})
|
return str.template("${column} @> '${value}'", {column = column, value = value})
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:jsonb_contained_by(column, value)
|
function Connection:jsonb_contained_by(column, value)
|
||||||
if string.is_blank(column) or string.is_blank(value) then
|
if str.is_blank(column) or str.is_blank(value) then
|
||||||
error("Column and value cannot be empty")
|
error("Column and value cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("${column} <@ '${value}'", {column = column, value = value})
|
return str.template("${column} <@ '${value}'", {column = column, value = value})
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Array helpers
|
-- Array helpers
|
||||||
function Connection:array_contains(column, value)
|
function Connection:array_contains(column, value)
|
||||||
if string.is_blank(column) then
|
if str.is_blank(column) then
|
||||||
error("Column cannot be empty")
|
error("Column cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("$1 = ANY(${column})", {column = column})
|
return str.template("$1 = ANY(${column})", {column = column})
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:array_length(column)
|
function Connection:array_length(column)
|
||||||
if string.is_blank(column) then
|
if str.is_blank(column) then
|
||||||
error("Column cannot be empty")
|
error("Column cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template("array_length(${column}, 1)", {column = column})
|
return str.template("array_length(${column}, 1)", {column = column})
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Connection management
|
-- Connection management
|
||||||
function postgres.parse_dsn(dsn)
|
function postgres.parse_dsn(dsn)
|
||||||
if string.is_blank(dsn) then
|
if str.is_blank(dsn) then
|
||||||
return nil, "DSN cannot be empty"
|
return nil, "DSN cannot be empty"
|
||||||
end
|
end
|
||||||
|
|
||||||
local parts = {}
|
local parts = {}
|
||||||
for pair in string.trim(dsn):gmatch("[^%s]+") do
|
for pair in str.trim(dsn):gmatch("[^%s]+") do
|
||||||
local key, value = pair:match("([^=]+)=(.+)")
|
local key, value = pair:match("([^=]+)=(.+)")
|
||||||
if key and value then
|
if key and value then
|
||||||
parts[string.trim(key)] = string.trim(value)
|
parts[str.trim(key)] = str.trim(value)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -586,11 +587,11 @@ function postgres.parse_dsn(dsn)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function postgres.connect(dsn)
|
function postgres.connect(dsn)
|
||||||
if string.is_blank(dsn) then
|
if str.is_blank(dsn) then
|
||||||
error("DSN cannot be empty")
|
error("DSN cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local conn_id = moonshark.sql_connect("postgres", string.trim(dsn))
|
local conn_id = moonshark.sql_connect("postgres", str.trim(dsn))
|
||||||
if conn_id then
|
if conn_id then
|
||||||
local conn = {_id = conn_id}
|
local conn = {_id = conn_id}
|
||||||
setmetatable(conn, Connection)
|
setmetatable(conn, Connection)
|
||||||
@ -663,14 +664,14 @@ function postgres.migrate(dsn, migrations, schema)
|
|||||||
local error_msg = ""
|
local error_msg = ""
|
||||||
|
|
||||||
for _, migration in ipairs(migrations) do
|
for _, migration in ipairs(migrations) do
|
||||||
if not migration.name or string.is_blank(migration.name) then
|
if not migration.name or str.is_blank(migration.name) then
|
||||||
error_msg = "Migration must have a non-empty name"
|
error_msg = "Migration must have a non-empty name"
|
||||||
success = false
|
success = false
|
||||||
break
|
break
|
||||||
end
|
end
|
||||||
|
|
||||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1",
|
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1",
|
||||||
string.trim(migration.name))
|
str.trim(migration.name))
|
||||||
if not existing then
|
if not existing then
|
||||||
local ok, err = pcall(function()
|
local ok, err = pcall(function()
|
||||||
if type(migration.up) == "string" then
|
if type(migration.up) == "string" then
|
||||||
@ -683,11 +684,11 @@ function postgres.migrate(dsn, migrations, schema)
|
|||||||
end)
|
end)
|
||||||
|
|
||||||
if ok then
|
if ok then
|
||||||
conn:exec("INSERT INTO _migrations (name) VALUES ($1)", string.trim(migration.name))
|
conn:exec("INSERT INTO _migrations (name) VALUES ($1)", str.trim(migration.name))
|
||||||
print(string.template("Applied migration: ${name}", {name = migration.name}))
|
print(str.template("Applied migration: ${name}", {name = migration.name}))
|
||||||
else
|
else
|
||||||
success = false
|
success = false
|
||||||
error_msg = string.template("Migration '${name}' failed: ${error}", {
|
error_msg = str.template("Migration '${name}' failed: ${error}", {
|
||||||
name = migration.name,
|
name = migration.name,
|
||||||
error = err or "unknown error"
|
error = err or "unknown error"
|
||||||
})
|
})
|
||||||
@ -714,7 +715,7 @@ function postgres.to_array(results, column_name)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(column_name) then
|
if str.is_blank(column_name) then
|
||||||
error("Column name cannot be empty")
|
error("Column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -726,7 +727,7 @@ function postgres.to_map(results, key_column, value_column)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(key_column) then
|
if str.is_blank(key_column) then
|
||||||
error("Key column name cannot be empty")
|
error("Key column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -743,7 +744,7 @@ function postgres.group_by(results, column_name)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(column_name) then
|
if str.is_blank(column_name) then
|
||||||
error("Column name cannot be empty")
|
error("Column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -763,19 +764,19 @@ function postgres.print_results(results)
|
|||||||
-- Calculate column widths
|
-- Calculate column widths
|
||||||
local widths = {}
|
local widths = {}
|
||||||
for _, col in ipairs(columns) do
|
for _, col in ipairs(columns) do
|
||||||
widths[col] = string.length(col)
|
widths[col] = str.length(col)
|
||||||
end
|
end
|
||||||
|
|
||||||
for _, row in ipairs(results) do
|
for _, row in ipairs(results) do
|
||||||
for _, col in ipairs(columns) do
|
for _, col in ipairs(columns) do
|
||||||
local value = tostring(row[col] or "")
|
local value = tostring(row[col] or "")
|
||||||
widths[col] = math.max(widths[col], string.length(value))
|
widths[col] = math.max(widths[col], str.length(value))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Print header
|
-- Print header
|
||||||
local header_parts = tbl.map(columns, function(col) return string.pad_right(col, widths[col]) end)
|
local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
|
||||||
local separator_parts = tbl.map(columns, function(col) return string.repeat_("-", widths[col]) end)
|
local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
|
||||||
|
|
||||||
print(tbl.concat(header_parts, " | "))
|
print(tbl.concat(header_parts, " | "))
|
||||||
print(tbl.concat(separator_parts, "-+-"))
|
print(tbl.concat(separator_parts, "-+-"))
|
||||||
@ -784,22 +785,22 @@ function postgres.print_results(results)
|
|||||||
for _, row in ipairs(results) do
|
for _, row in ipairs(results) do
|
||||||
local value_parts = tbl.map(columns, function(col)
|
local value_parts = tbl.map(columns, function(col)
|
||||||
local value = tostring(row[col] or "")
|
local value = tostring(row[col] or "")
|
||||||
return string.pad_right(value, widths[col])
|
return str.pad_right(value, widths[col])
|
||||||
end)
|
end)
|
||||||
print(tbl.concat(value_parts, " | "))
|
print(tbl.concat(value_parts, " | "))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function postgres.escape_identifier(name)
|
function postgres.escape_identifier(name)
|
||||||
if string.is_blank(name) then
|
if str.is_blank(name) then
|
||||||
error("Identifier name cannot be empty")
|
error("Identifier name cannot be empty")
|
||||||
end
|
end
|
||||||
return string.template('"${name}"', {name = string.replace(name, '"', '""')})
|
return str.template('"${name}"', {name = str.replace(name, '"', '""')})
|
||||||
end
|
end
|
||||||
|
|
||||||
function postgres.escape_literal(value)
|
function postgres.escape_literal(value)
|
||||||
if type(value) == "string" then
|
if type(value) == "string" then
|
||||||
return string.template("'${value}'", {value = string.replace(value, "'", "''")})
|
return str.template("'${value}'", {value = str.replace(value, "'", "''")})
|
||||||
end
|
end
|
||||||
return tostring(value)
|
return tostring(value)
|
||||||
end
|
end
|
||||||
|
|||||||
@ -4,50 +4,51 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"Moonshark/modules/crypto"
|
"Moonshark/modules/crypto"
|
||||||
"Moonshark/modules/fs"
|
"Moonshark/modules/fs"
|
||||||
"Moonshark/modules/http"
|
"Moonshark/modules/http"
|
||||||
"Moonshark/modules/kv"
|
|
||||||
"Moonshark/modules/math"
|
"Moonshark/modules/math"
|
||||||
"Moonshark/modules/sql"
|
"Moonshark/modules/sql"
|
||||||
lua_string "Moonshark/modules/string+"
|
lua_string "Moonshark/modules/string"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Global registry instance
|
||||||
var Global *Registry
|
var Global *Registry
|
||||||
|
|
||||||
//go:embed **/*.lua
|
//go:embed **/*.lua
|
||||||
var embeddedModules embed.FS
|
var embeddedModules embed.FS
|
||||||
|
|
||||||
|
// Registry manages all Lua modules and Go functions
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
modules map[string]string
|
modules map[string]string
|
||||||
globalModules map[string]string // globalName -> moduleSource
|
goFuncs map[string]luajit.GoFunction
|
||||||
goFuncs map[string]luajit.GoFunction
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New creates a new registry with all modules loaded
|
||||||
func New() *Registry {
|
func New() *Registry {
|
||||||
r := &Registry{
|
r := &Registry{
|
||||||
modules: make(map[string]string),
|
modules: make(map[string]string),
|
||||||
globalModules: make(map[string]string),
|
goFuncs: make(map[string]luajit.GoFunction),
|
||||||
goFuncs: make(map[string]luajit.GoFunction),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load all Go functions
|
||||||
maps.Copy(r.goFuncs, lua_string.GetFunctionList())
|
maps.Copy(r.goFuncs, lua_string.GetFunctionList())
|
||||||
maps.Copy(r.goFuncs, math.GetFunctionList())
|
maps.Copy(r.goFuncs, math.GetFunctionList())
|
||||||
maps.Copy(r.goFuncs, crypto.GetFunctionList())
|
maps.Copy(r.goFuncs, crypto.GetFunctionList())
|
||||||
maps.Copy(r.goFuncs, fs.GetFunctionList())
|
maps.Copy(r.goFuncs, fs.GetFunctionList())
|
||||||
maps.Copy(r.goFuncs, http.GetFunctionList())
|
maps.Copy(r.goFuncs, http.GetFunctionList())
|
||||||
maps.Copy(r.goFuncs, sql.GetFunctionList())
|
maps.Copy(r.goFuncs, sql.GetFunctionList())
|
||||||
maps.Copy(r.goFuncs, kv.GetFunctionList())
|
|
||||||
|
|
||||||
r.loadEmbeddedModules()
|
r.loadEmbeddedModules()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadEmbeddedModules discovers and loads all .lua files
|
||||||
func (r *Registry) loadEmbeddedModules() {
|
func (r *Registry) loadEmbeddedModules() {
|
||||||
|
// Discover all directories from embed
|
||||||
dirs, _ := embeddedModules.ReadDir(".")
|
dirs, _ := embeddedModules.ReadDir(".")
|
||||||
|
|
||||||
for _, dir := range dirs {
|
for _, dir := range dirs {
|
||||||
@ -55,27 +56,15 @@ func (r *Registry) loadEmbeddedModules() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
dirName := dir.Name()
|
// Assume one module file per directory: dirname/dirname.lua
|
||||||
isGlobal := strings.HasSuffix(dirName, "+")
|
modulePath := fmt.Sprintf("%s/%s.lua", dir.Name(), dir.Name())
|
||||||
|
|
||||||
var moduleName, globalName string
|
|
||||||
if isGlobal {
|
|
||||||
moduleName = strings.TrimSuffix(dirName, "+")
|
|
||||||
globalName = moduleName
|
|
||||||
} else {
|
|
||||||
moduleName = dirName
|
|
||||||
}
|
|
||||||
|
|
||||||
modulePath := fmt.Sprintf("%s/%s.lua", dirName, moduleName)
|
|
||||||
if source, err := embeddedModules.ReadFile(modulePath); err == nil {
|
if source, err := embeddedModules.ReadFile(modulePath); err == nil {
|
||||||
r.modules[moduleName] = string(source)
|
r.modules[dir.Name()] = string(source)
|
||||||
if isGlobal {
|
|
||||||
r.globalModules[globalName] = string(source)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InstallInState sets up the complete module system in a Lua state
|
||||||
func (r *Registry) InstallInState(state *luajit.State) error {
|
func (r *Registry) InstallInState(state *luajit.State) error {
|
||||||
// Create moonshark global table with Go functions
|
// Create moonshark global table with Go functions
|
||||||
state.NewTable()
|
state.NewTable()
|
||||||
@ -87,13 +76,6 @@ func (r *Registry) InstallInState(state *luajit.State) error {
|
|||||||
}
|
}
|
||||||
state.SetGlobal("moonshark")
|
state.SetGlobal("moonshark")
|
||||||
|
|
||||||
// Auto-enhance all global modules
|
|
||||||
for globalName, source := range r.globalModules {
|
|
||||||
if err := r.enhanceGlobal(state, globalName, source); err != nil {
|
|
||||||
return fmt.Errorf("failed to enhance %s global: %w", globalName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backup original require and install custom one
|
// Backup original require and install custom one
|
||||||
state.GetGlobal("require")
|
state.GetGlobal("require")
|
||||||
state.SetGlobal("_require_original")
|
state.SetGlobal("_require_original")
|
||||||
@ -108,13 +90,7 @@ func (r *Registry) InstallInState(state *luajit.State) error {
|
|||||||
return s.PushError("require: module name must be a string")
|
return s.PushError("require: module name must be a string")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return global if this module enhances a global
|
// Check built-in modules first
|
||||||
if _, isGlobal := r.globalModules[moduleName]; isGlobal {
|
|
||||||
s.GetGlobal(moduleName)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check built-in modules
|
|
||||||
if source, exists := r.modules[moduleName]; exists {
|
if source, exists := r.modules[moduleName]; exists {
|
||||||
if err := s.LoadString(source); err != nil {
|
if err := s.LoadString(source); err != nil {
|
||||||
return s.PushError("require: failed to load module '%s': %v", moduleName, err)
|
return s.PushError("require: failed to load module '%s': %v", moduleName, err)
|
||||||
@ -139,18 +115,7 @@ func (r *Registry) InstallInState(state *luajit.State) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Registry) enhanceGlobal(state *luajit.State, globalName, source string) error {
|
// Initialize sets up the global registry
|
||||||
// Execute the module - it directly modifies the global
|
|
||||||
if err := state.LoadString(source); err != nil {
|
|
||||||
return fmt.Errorf("failed to load %s module: %w", globalName, err)
|
|
||||||
}
|
|
||||||
if err := state.Call(0, 0); err != nil { // 0 results expected
|
|
||||||
return fmt.Errorf("failed to execute %s module: %w", globalName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Initialize() error {
|
func Initialize() error {
|
||||||
Global = New()
|
Global = New()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -1,212 +0,0 @@
|
|||||||
local kv = require("kv")
|
|
||||||
local crypto = require("crypto")
|
|
||||||
local json = require("json")
|
|
||||||
|
|
||||||
local sessions = {}
|
|
||||||
local stores = {}
|
|
||||||
local default_store = nil
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- CORE FUNCTIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
function sessions.init(store_name, filename)
|
|
||||||
store_name = store_name or "sessions"
|
|
||||||
if not kv.open(store_name, filename) then return false end
|
|
||||||
stores[store_name] = true
|
|
||||||
if not default_store then default_store = store_name end
|
|
||||||
return true
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.create(session_id, data, store_name)
|
|
||||||
if type(session_id) ~= "string" then error("session ID must be a string", 2) end
|
|
||||||
if data ~= nil and type(data) ~= "table" then error("data must be a table", 2) end
|
|
||||||
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
|
|
||||||
local session_data = {
|
|
||||||
data = data or {},
|
|
||||||
_created = os.time(),
|
|
||||||
_last_accessed = os.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv.set(store_name, "session:" .. session_id, json.encode(session_data))
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.get(session_id, store_name)
|
|
||||||
if type(session_id) ~= "string" then error("session ID must be a string", 2) end
|
|
||||||
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
|
|
||||||
local json_str = kv.get(store_name, "session:" .. session_id)
|
|
||||||
if not json_str then return nil end
|
|
||||||
|
|
||||||
local session_data = json.decode(json_str)
|
|
||||||
if not session_data then return nil end
|
|
||||||
|
|
||||||
-- Update last accessed
|
|
||||||
session_data._last_accessed = os.time()
|
|
||||||
kv.set(store_name, "session:" .. session_id, json.encode(session_data))
|
|
||||||
|
|
||||||
-- Return flattened data with metadata
|
|
||||||
local result = session_data.data or {}
|
|
||||||
result._created = session_data._created
|
|
||||||
result._last_accessed = session_data._last_accessed
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.update(session_id, data, store_name)
|
|
||||||
if type(session_id) ~= "string" then error("session ID must be a string", 2) end
|
|
||||||
if type(data) ~= "table" then error("data must be a table", 2) end
|
|
||||||
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
|
|
||||||
local json_str = kv.get(store_name, "session:" .. session_id)
|
|
||||||
if not json_str then return false end
|
|
||||||
|
|
||||||
local session_data = json.decode(json_str)
|
|
||||||
if not session_data then return false end
|
|
||||||
|
|
||||||
session_data.data = data
|
|
||||||
session_data._last_accessed = os.time()
|
|
||||||
|
|
||||||
return kv.set(store_name, "session:" .. session_id, json.encode(session_data))
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.delete(session_id, store_name)
|
|
||||||
if type(session_id) ~= "string" then error("session ID must be a string", 2) end
|
|
||||||
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
return kv.delete(store_name, "session:" .. session_id)
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.cleanup(max_age, store_name)
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
|
|
||||||
local keys = kv.keys(store_name)
|
|
||||||
local current_time = os.time()
|
|
||||||
local deleted = 0
|
|
||||||
|
|
||||||
for _, key in ipairs(keys) do
|
|
||||||
if key:match("^session:") then
|
|
||||||
local json_str = kv.get(store_name, key)
|
|
||||||
if json_str then
|
|
||||||
local session_data = json.decode(json_str)
|
|
||||||
if session_data and session_data._last_accessed then
|
|
||||||
if current_time - session_data._last_accessed > max_age then
|
|
||||||
kv.delete(store_name, key)
|
|
||||||
deleted = deleted + 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return deleted
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.close(store_name)
|
|
||||||
local success = kv.close(store_name)
|
|
||||||
stores[store_name] = nil
|
|
||||||
if default_store == store_name then
|
|
||||||
default_store = next(stores)
|
|
||||||
end
|
|
||||||
return success
|
|
||||||
end
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- UTILITIES
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
function sessions.generate_id()
|
|
||||||
return crypto.random_alphanumeric(32)
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.exists(session_id, store_name)
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
return kv.has(store_name, "session:" .. session_id)
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.list(store_name)
|
|
||||||
store_name = store_name or default_store
|
|
||||||
if not store_name then error("No session store initialized", 2) end
|
|
||||||
|
|
||||||
local keys = kv.keys(store_name)
|
|
||||||
local session_ids = {}
|
|
||||||
|
|
||||||
for _, key in ipairs(keys) do
|
|
||||||
local session_id = key:match("^session:(.+)")
|
|
||||||
if session_id then
|
|
||||||
table.insert(session_ids, session_id)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return session_ids
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.count(store_name)
|
|
||||||
return #sessions.list(store_name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function sessions.reset()
|
|
||||||
stores = {}
|
|
||||||
default_store = nil
|
|
||||||
end
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- OOP INTERFACE
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
local SessionStore = {}
|
|
||||||
SessionStore.__index = SessionStore
|
|
||||||
|
|
||||||
function sessions.create_store(store_name, filename)
|
|
||||||
if not sessions.init(store_name, filename) then
|
|
||||||
error("Failed to initialize store '" .. store_name .. "'", 2)
|
|
||||||
end
|
|
||||||
return setmetatable({name = store_name}, SessionStore)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:create(session_id, data)
|
|
||||||
return sessions.create(session_id, data, self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:get(session_id)
|
|
||||||
return sessions.get(session_id, self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:update(session_id, data)
|
|
||||||
return sessions.update(session_id, data, self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:delete(session_id)
|
|
||||||
return sessions.delete(session_id, self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:cleanup(max_age)
|
|
||||||
return sessions.cleanup(max_age, self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:exists(session_id)
|
|
||||||
return sessions.exists(session_id, self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:list()
|
|
||||||
return sessions.list(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:count()
|
|
||||||
return sessions.count(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
function SessionStore:close()
|
|
||||||
return sessions.close(self.name)
|
|
||||||
end
|
|
||||||
|
|
||||||
return sessions
|
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
local str = require("string")
|
||||||
local tbl = require("table")
|
local tbl = require("table")
|
||||||
local sqlite = {}
|
local sqlite = {}
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ function Connection:query(query_str, ...)
|
|||||||
if not self._id then
|
if not self._id then
|
||||||
error("Connection is closed")
|
error("Connection is closed")
|
||||||
end
|
end
|
||||||
query_str = string.normalize_whitespace(query_str)
|
query_str = str.normalize_whitespace(query_str)
|
||||||
return moonshark.sql_query(self._id, query_str, ...)
|
return moonshark.sql_query(self._id, query_str, ...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ function Connection:exec(query_str, ...)
|
|||||||
if not self._id then
|
if not self._id then
|
||||||
error("Connection is closed")
|
error("Connection is closed")
|
||||||
end
|
end
|
||||||
query_str = string.normalize_whitespace(query_str)
|
query_str = str.normalize_whitespace(query_str)
|
||||||
return moonshark.sql_exec(self._id, query_str, ...)
|
return moonshark.sql_exec(self._id, query_str, ...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -114,7 +115,7 @@ end
|
|||||||
|
|
||||||
-- Simplified query builders using table utilities
|
-- Simplified query builders using table utilities
|
||||||
function Connection:insert(table_name, data)
|
function Connection:insert(table_name, data)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -122,7 +123,7 @@ function Connection:insert(table_name, data)
|
|||||||
local values = tbl.values(data)
|
local values = tbl.values(data)
|
||||||
local placeholders = tbl.map(keys, function() return "?" end)
|
local placeholders = tbl.map(keys, function() return "?" end)
|
||||||
|
|
||||||
local query = string.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", ")
|
placeholders = tbl.concat(placeholders, ", ")
|
||||||
@ -132,7 +133,7 @@ function Connection:insert(table_name, data)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:upsert(table_name, data, conflict_columns)
|
function Connection:upsert(table_name, data, conflict_columns)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -140,19 +141,19 @@ function Connection:upsert(table_name, data, conflict_columns)
|
|||||||
local values = tbl.values(data)
|
local values = tbl.values(data)
|
||||||
local placeholders = tbl.map(keys, function() return "?" end)
|
local placeholders = tbl.map(keys, function() return "?" end)
|
||||||
local updates = tbl.map(keys, function(key)
|
local updates = tbl.map(keys, function(key)
|
||||||
return string.template("${key} = excluded.${key}", {key = key})
|
return str.template("${key} = excluded.${key}", {key = key})
|
||||||
end)
|
end)
|
||||||
|
|
||||||
local conflict_clause = ""
|
local conflict_clause = ""
|
||||||
if conflict_columns then
|
if conflict_columns then
|
||||||
if type(conflict_columns) == "string" then
|
if type(conflict_columns) == "string" then
|
||||||
conflict_clause = string.template("(${columns})", {columns = conflict_columns})
|
conflict_clause = str.template("(${columns})", {columns = conflict_columns})
|
||||||
else
|
else
|
||||||
conflict_clause = string.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
|
conflict_clause = str.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
|
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
columns = tbl.concat(keys, ", "),
|
columns = tbl.concat(keys, ", "),
|
||||||
placeholders = tbl.concat(placeholders, ", "),
|
placeholders = tbl.concat(placeholders, ", "),
|
||||||
@ -164,20 +165,20 @@ function Connection:upsert(table_name, data, conflict_columns)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:update(table_name, data, where_clause, ...)
|
function Connection:update(table_name, data, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
if string.is_blank(where_clause) then
|
if str.is_blank(where_clause) then
|
||||||
error("WHERE clause cannot be empty for UPDATE")
|
error("WHERE clause cannot be empty for UPDATE")
|
||||||
end
|
end
|
||||||
|
|
||||||
local keys = tbl.keys(data)
|
local keys = tbl.keys(data)
|
||||||
local values = tbl.values(data)
|
local values = tbl.values(data)
|
||||||
local sets = tbl.map(keys, function(key)
|
local sets = tbl.map(keys, function(key)
|
||||||
return string.template("${key} = ?", {key = key})
|
return str.template("${key} = ?", {key = key})
|
||||||
end)
|
end)
|
||||||
|
|
||||||
local query = string.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
sets = tbl.concat(sets, ", "),
|
sets = tbl.concat(sets, ", "),
|
||||||
where = where_clause
|
where = where_clause
|
||||||
@ -191,14 +192,14 @@ function Connection:update(table_name, data, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:delete(table_name, where_clause, ...)
|
function Connection:delete(table_name, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
if string.is_blank(where_clause) then
|
if str.is_blank(where_clause) then
|
||||||
error("WHERE clause cannot be empty for DELETE")
|
error("WHERE clause cannot be empty for DELETE")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DELETE FROM ${table} WHERE ${where}", {
|
local query = str.template("DELETE FROM ${table} WHERE ${where}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
where = where_clause
|
where = where_clause
|
||||||
})
|
})
|
||||||
@ -206,7 +207,7 @@ function Connection:delete(table_name, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:select(table_name, columns, where_clause, ...)
|
function Connection:select(table_name, columns, where_clause, ...)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -216,15 +217,15 @@ function Connection:select(table_name, columns, where_clause, ...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local query
|
local query
|
||||||
if where_clause and not string.is_blank(where_clause) then
|
if where_clause and not str.is_blank(where_clause) then
|
||||||
query = string.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
query = str.template("SELECT ${columns} FROM ${table} WHERE ${where}", {
|
||||||
columns = columns,
|
columns = columns,
|
||||||
table = table_name,
|
table = table_name,
|
||||||
where = where_clause
|
where = where_clause
|
||||||
})
|
})
|
||||||
return self:query(query, ...)
|
return self:query(query, ...)
|
||||||
else
|
else
|
||||||
query = string.template("SELECT ${columns} FROM ${table}", {
|
query = str.template("SELECT ${columns} FROM ${table}", {
|
||||||
columns = columns,
|
columns = columns,
|
||||||
table = table_name
|
table = table_name
|
||||||
})
|
})
|
||||||
@ -234,73 +235,73 @@ end
|
|||||||
|
|
||||||
-- Schema helpers
|
-- Schema helpers
|
||||||
function Connection:table_exists(table_name)
|
function Connection:table_exists(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
local result = self:query_value(
|
local result = self:query_value(
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||||
string.trim(table_name)
|
str.trim(table_name)
|
||||||
)
|
)
|
||||||
return result ~= nil
|
return result ~= nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:column_exists(table_name, column_name)
|
function Connection:column_exists(table_name, column_name)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_name) then
|
if str.is_blank(table_name) or str.is_blank(column_name) then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
local result = self:query(string.template("PRAGMA table_info(${table})", {table = table_name}))
|
local result = self:query(str.template("PRAGMA table_info(${table})", {table = table_name}))
|
||||||
if result then
|
if result then
|
||||||
return tbl.any(result, function(row)
|
return tbl.any(result, function(row)
|
||||||
return string.iequals(row.name, string.trim(column_name))
|
return str.iequals(row.name, str.trim(column_name))
|
||||||
end)
|
end)
|
||||||
end
|
end
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_table(table_name, schema)
|
function Connection:create_table(table_name, schema)
|
||||||
if string.is_blank(table_name) or string.is_blank(schema) then
|
if str.is_blank(table_name) or str.is_blank(schema) then
|
||||||
error("Table name and schema cannot be empty")
|
error("Table name and schema cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", {
|
local query = str.template("CREATE TABLE IF NOT EXISTS ${table} (${schema})", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
schema = string.trim(schema)
|
schema = str.trim(schema)
|
||||||
})
|
})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_table(table_name)
|
function Connection:drop_table(table_name)
|
||||||
if string.is_blank(table_name) then
|
if str.is_blank(table_name) then
|
||||||
error("Table name cannot be empty")
|
error("Table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DROP TABLE IF EXISTS ${table}", {table = table_name})
|
local query = str.template("DROP TABLE IF EXISTS ${table}", {table = table_name})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:add_column(table_name, column_def)
|
function Connection:add_column(table_name, column_def)
|
||||||
if string.is_blank(table_name) or string.is_blank(column_def) then
|
if str.is_blank(table_name) or str.is_blank(column_def) then
|
||||||
error("Table name and column definition cannot be empty")
|
error("Table name and column definition cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("ALTER TABLE ${table} ADD COLUMN ${column}", {
|
local query = str.template("ALTER TABLE ${table} ADD COLUMN ${column}", {
|
||||||
table = table_name,
|
table = table_name,
|
||||||
column = string.trim(column_def)
|
column = str.trim(column_def)
|
||||||
})
|
})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:create_index(index_name, table_name, columns, unique)
|
function Connection:create_index(index_name, table_name, columns, unique)
|
||||||
if string.is_blank(index_name) or string.is_blank(table_name) then
|
if str.is_blank(index_name) or str.is_blank(table_name) then
|
||||||
error("Index name and table name cannot be empty")
|
error("Index name and table name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local unique_clause = unique and "UNIQUE " or ""
|
local unique_clause = unique and "UNIQUE " or ""
|
||||||
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
|
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
|
||||||
|
|
||||||
local query = string.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", {
|
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", {
|
||||||
unique = unique_clause,
|
unique = unique_clause,
|
||||||
index = index_name,
|
index = index_name,
|
||||||
table = table_name,
|
table = table_name,
|
||||||
@ -310,11 +311,11 @@ function Connection:create_index(index_name, table_name, columns, unique)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function Connection:drop_index(index_name)
|
function Connection:drop_index(index_name)
|
||||||
if string.is_blank(index_name) then
|
if str.is_blank(index_name) then
|
||||||
error("Index name cannot be empty")
|
error("Index name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
local query = string.template("DROP INDEX IF EXISTS ${index}", {index = index_name})
|
local query = str.template("DROP INDEX IF EXISTS ${index}", {index = index_name})
|
||||||
return self:exec(query)
|
return self:exec(query)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -333,29 +334,29 @@ end
|
|||||||
|
|
||||||
function Connection:foreign_keys(enabled)
|
function Connection:foreign_keys(enabled)
|
||||||
local value = enabled and "ON" or "OFF"
|
local value = enabled and "ON" or "OFF"
|
||||||
return self:exec(string.template("PRAGMA foreign_keys = ${value}", {value = value}))
|
return self:exec(str.template("PRAGMA foreign_keys = ${value}", {value = value}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:journal_mode(mode)
|
function Connection:journal_mode(mode)
|
||||||
mode = mode or "WAL"
|
mode = mode or "WAL"
|
||||||
local valid_modes = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}
|
local valid_modes = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}
|
||||||
|
|
||||||
if not tbl.contains(tbl.map(valid_modes, string.upper), string.upper(mode)) then
|
if not tbl.contains(tbl.map(valid_modes, str.upper), str.upper(mode)) then
|
||||||
error("Invalid journal mode: " .. mode)
|
error("Invalid journal mode: " .. mode)
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:query(string.template("PRAGMA journal_mode = ${mode}", {mode = string.upper(mode)}))
|
return self:query(str.template("PRAGMA journal_mode = ${mode}", {mode = str.upper(mode)}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:synchronous(level)
|
function Connection:synchronous(level)
|
||||||
level = level or "NORMAL"
|
level = level or "NORMAL"
|
||||||
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
|
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
|
||||||
|
|
||||||
if not tbl.contains(valid_levels, string.upper(level)) then
|
if not tbl.contains(valid_levels, str.upper(level)) then
|
||||||
error("Invalid synchronous level: " .. level)
|
error("Invalid synchronous level: " .. level)
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:exec(string.template("PRAGMA synchronous = ${level}", {level = string.upper(level)}))
|
return self:exec(str.template("PRAGMA synchronous = ${level}", {level = str.upper(level)}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:cache_size(size)
|
function Connection:cache_size(size)
|
||||||
@ -363,18 +364,18 @@ function Connection:cache_size(size)
|
|||||||
if type(size) ~= "number" then
|
if type(size) ~= "number" then
|
||||||
error("Cache size must be a number")
|
error("Cache size must be a number")
|
||||||
end
|
end
|
||||||
return self:exec(string.template("PRAGMA cache_size = ${size}", {size = tostring(size)}))
|
return self:exec(str.template("PRAGMA cache_size = ${size}", {size = tostring(size)}))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Connection:temp_store(mode)
|
function Connection:temp_store(mode)
|
||||||
mode = mode or "MEMORY"
|
mode = mode or "MEMORY"
|
||||||
local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
|
local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
|
||||||
|
|
||||||
if not tbl.contains(valid_modes, string.upper(mode)) then
|
if not tbl.contains(valid_modes, str.upper(mode)) then
|
||||||
error("Invalid temp_store mode: " .. mode)
|
error("Invalid temp_store mode: " .. mode)
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:exec(string.template("PRAGMA temp_store = ${mode}", {mode = string.upper(mode)}))
|
return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = str.upper(mode)}))
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Connection management
|
-- Connection management
|
||||||
@ -382,8 +383,8 @@ function sqlite.open(database_path)
|
|||||||
database_path = database_path or ":memory:"
|
database_path = database_path or ":memory:"
|
||||||
|
|
||||||
if database_path ~= ":memory:" then
|
if database_path ~= ":memory:" then
|
||||||
database_path = string.trim(database_path)
|
database_path = str.trim(database_path)
|
||||||
if string.is_blank(database_path) then
|
if str.is_blank(database_path) then
|
||||||
database_path = ":memory:"
|
database_path = ":memory:"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -403,7 +404,7 @@ sqlite.connect = sqlite.open
|
|||||||
function sqlite.query(database_path, query_str, ...)
|
function sqlite.query(database_path, query_str, ...)
|
||||||
local conn = sqlite.open(database_path)
|
local conn = sqlite.open(database_path)
|
||||||
if not conn then
|
if not conn then
|
||||||
error(string.template("Failed to open SQLite database: ${path}", {
|
error(str.template("Failed to open SQLite database: ${path}", {
|
||||||
path = database_path or ":memory:"
|
path = database_path or ":memory:"
|
||||||
}))
|
}))
|
||||||
end
|
end
|
||||||
@ -416,7 +417,7 @@ end
|
|||||||
function sqlite.exec(database_path, query_str, ...)
|
function sqlite.exec(database_path, query_str, ...)
|
||||||
local conn = sqlite.open(database_path)
|
local conn = sqlite.open(database_path)
|
||||||
if not conn then
|
if not conn then
|
||||||
error(string.template("Failed to open SQLite database: ${path}", {
|
error(str.template("Failed to open SQLite database: ${path}", {
|
||||||
path = database_path or ":memory:"
|
path = database_path or ":memory:"
|
||||||
}))
|
}))
|
||||||
end
|
end
|
||||||
@ -464,14 +465,14 @@ function sqlite.migrate(database_path, migrations)
|
|||||||
local error_msg = ""
|
local error_msg = ""
|
||||||
|
|
||||||
for _, migration in ipairs(migrations) do
|
for _, migration in ipairs(migrations) do
|
||||||
if not migration.name or string.is_blank(migration.name) then
|
if not migration.name or str.is_blank(migration.name) then
|
||||||
error_msg = "Migration must have a non-empty name"
|
error_msg = "Migration must have a non-empty name"
|
||||||
success = false
|
success = false
|
||||||
break
|
break
|
||||||
end
|
end
|
||||||
|
|
||||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
|
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
|
||||||
string.trim(migration.name))
|
str.trim(migration.name))
|
||||||
if not existing then
|
if not existing then
|
||||||
local ok, err = pcall(function()
|
local ok, err = pcall(function()
|
||||||
if type(migration.up) == "string" then
|
if type(migration.up) == "string" then
|
||||||
@ -484,11 +485,11 @@ function sqlite.migrate(database_path, migrations)
|
|||||||
end)
|
end)
|
||||||
|
|
||||||
if ok then
|
if ok then
|
||||||
conn:exec("INSERT INTO _migrations (name) VALUES (?)", string.trim(migration.name))
|
conn:exec("INSERT INTO _migrations (name) VALUES (?)", str.trim(migration.name))
|
||||||
print(string.template("Applied migration: ${name}", {name = migration.name}))
|
print(str.template("Applied migration: ${name}", {name = migration.name}))
|
||||||
else
|
else
|
||||||
success = false
|
success = false
|
||||||
error_msg = string.template("Migration '${name}' failed: ${error}", {
|
error_msg = str.template("Migration '${name}' failed: ${error}", {
|
||||||
name = migration.name,
|
name = migration.name,
|
||||||
error = err or "unknown error"
|
error = err or "unknown error"
|
||||||
})
|
})
|
||||||
@ -515,7 +516,7 @@ function sqlite.to_array(results, column_name)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(column_name) then
|
if str.is_blank(column_name) then
|
||||||
error("Column name cannot be empty")
|
error("Column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -527,7 +528,7 @@ function sqlite.to_map(results, key_column, value_column)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(key_column) then
|
if str.is_blank(key_column) then
|
||||||
error("Key column name cannot be empty")
|
error("Key column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -544,7 +545,7 @@ function sqlite.group_by(results, column_name)
|
|||||||
return {}
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if string.is_blank(column_name) then
|
if str.is_blank(column_name) then
|
||||||
error("Column name cannot be empty")
|
error("Column name cannot be empty")
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -562,18 +563,18 @@ function sqlite.print_results(results)
|
|||||||
tbl.sort(columns)
|
tbl.sort(columns)
|
||||||
|
|
||||||
-- Calculate column widths
|
-- Calculate column widths
|
||||||
local widths = tbl.map_values(tbl.to_map(columns, function(col) return col end, function(col) return string.length(col) end), function(width) return width end)
|
local widths = tbl.map_values(tbl.to_map(columns, function(col) return col end, function(col) return str.length(col) end), function(width) return width end)
|
||||||
|
|
||||||
for _, row in ipairs(results) do
|
for _, row in ipairs(results) do
|
||||||
for _, col in ipairs(columns) do
|
for _, col in ipairs(columns) do
|
||||||
local value = tostring(row[col] or "")
|
local value = tostring(row[col] or "")
|
||||||
widths[col] = math.max(widths[col], string.length(value))
|
widths[col] = math.max(widths[col], str.length(value))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Print header
|
-- Print header
|
||||||
local header_parts = tbl.map(columns, function(col) return string.pad_right(col, widths[col]) end)
|
local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
|
||||||
local separator_parts = tbl.map(columns, function(col) return string.repeat_("-", widths[col]) end)
|
local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
|
||||||
|
|
||||||
print(tbl.concat(header_parts, " | "))
|
print(tbl.concat(header_parts, " | "))
|
||||||
print(tbl.concat(separator_parts, "-+-"))
|
print(tbl.concat(separator_parts, "-+-"))
|
||||||
@ -582,7 +583,7 @@ function sqlite.print_results(results)
|
|||||||
for _, row in ipairs(results) do
|
for _, row in ipairs(results) do
|
||||||
local value_parts = tbl.map(columns, function(col)
|
local value_parts = tbl.map(columns, function(col)
|
||||||
local value = tostring(row[col] or "")
|
local value = tostring(row[col] or "")
|
||||||
return string.pad_right(value, widths[col])
|
return str.pad_right(value, widths[col])
|
||||||
end)
|
end)
|
||||||
print(tbl.concat(value_parts, " | "))
|
print(tbl.concat(value_parts, " | "))
|
||||||
end
|
end
|
||||||
|
|||||||
@ -1,666 +0,0 @@
|
|||||||
local _orig_find = string.find
|
|
||||||
local _orig_match = string.match
|
|
||||||
local REVERSE_THRESHOLD = 100
|
|
||||||
local LENGTH_THRESHOLD = 1000
|
|
||||||
|
|
||||||
function string.split(s, delimiter)
|
|
||||||
if type(s) ~= "string" then error("string.split: first argument must be a string", 2) end
|
|
||||||
if type(delimiter) ~= "string" then error("string.split: second argument must be a string", 2) end
|
|
||||||
|
|
||||||
if delimiter == "" then
|
|
||||||
local result = {}
|
|
||||||
for i = 1, #s do
|
|
||||||
result[i] = s:sub(i, i)
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
|
|
||||||
local result = {}
|
|
||||||
local start = 1
|
|
||||||
local delimiter_len = #delimiter
|
|
||||||
|
|
||||||
while true do
|
|
||||||
local pos = _orig_find(s, delimiter, start, true) -- Use original find
|
|
||||||
if not pos then
|
|
||||||
table.insert(result, s:sub(start))
|
|
||||||
break
|
|
||||||
end
|
|
||||||
table.insert(result, s:sub(start, pos - 1))
|
|
||||||
start = pos + delimiter_len
|
|
||||||
end
|
|
||||||
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
getmetatable("").__index.split = string.split
|
|
||||||
|
|
||||||
function string.join(arr, separator)
|
|
||||||
if type(arr) ~= "table" then error("string.join: first argument must be a table", 2) end
|
|
||||||
if type(separator) ~= "string" then error("string.join: second argument must be a string", 2) end
|
|
||||||
|
|
||||||
return table.concat(arr, separator)
|
|
||||||
end
|
|
||||||
|
|
||||||
function string.trim(s, cutset)
|
|
||||||
if type(s) ~= "string" then error("string.trim: first argument must be a string", 2) end
|
|
||||||
if cutset then
|
|
||||||
if type(cutset) ~= "string" then error("string.trim: second argument must be a string", 2) end
|
|
||||||
local escaped = cutset:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1")
|
|
||||||
local pattern = "^[" .. escaped .. "]*(.-)[" .. escaped .. "]*$"
|
|
||||||
return s:match(pattern)
|
|
||||||
else
|
|
||||||
return s:match("^%s*(.-)%s*$")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
getmetatable("").__index.trim = string.trim
|
|
||||||
|
|
||||||
function string.trim_left(s, cutset)
|
|
||||||
if type(s) ~= "string" then error("string.trim_left: first argument must be a string", 2) end
|
|
||||||
if cutset then
|
|
||||||
if type(cutset) ~= "string" then error("string.trim_left: second argument must be a string", 2) end
|
|
||||||
local pattern = "^[" .. cutset:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1") .. "]*"
|
|
||||||
return s:gsub(pattern, "")
|
|
||||||
else
|
|
||||||
return s:match("^%s*(.*)")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
getmetatable("").__index.trim_left = string.trim_left
|
|
||||||
|
|
||||||
function string.trim_right(s, cutset)
|
|
||||||
if type(s) ~= "string" then error("string.trim_right: first argument must be a string", 2) end
|
|
||||||
if cutset then
|
|
||||||
if type(cutset) ~= "string" then error("string.trim_right: second argument must be a string", 2) end
|
|
||||||
local escaped = cutset:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1")
|
|
||||||
local pattern = "[" .. escaped .. "]*$"
|
|
||||||
return s:gsub(pattern, "")
|
|
||||||
else
|
|
||||||
return s:match("(.-)%s*$")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function string.title(s)
|
|
||||||
if type(s) ~= "string" then error("string.title: argument must be a string", 2) end
|
|
||||||
return s:gsub("(%w)([%w]*)", function(first, rest)
|
|
||||||
return first:upper() .. rest:lower()
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.title = string.title
|
|
||||||
|
|
||||||
function string.contains(s, substr)
|
|
||||||
if type(s) ~= "string" then error("string.contains: first argument must be a string", 2) end
|
|
||||||
if type(substr) ~= "string" then error("string.contains: second argument must be a string", 2) end
|
|
||||||
return _orig_find(s, substr, 1, true) ~= nil
|
|
||||||
end
|
|
||||||
getmetatable("").__index.contains = string.contains
|
|
||||||
|
|
||||||
function string.starts_with(s, prefix)
|
|
||||||
if type(s) ~= "string" then error("string.starts_with: first argument must be a string", 2) end
|
|
||||||
if type(prefix) ~= "string" then error("string.starts_with: second argument must be a string", 2) end
|
|
||||||
return s:sub(1, #prefix) == prefix
|
|
||||||
end
|
|
||||||
getmetatable("").__index.starts_with = string.starts_with
|
|
||||||
|
|
||||||
function string.ends_with(s, suffix)
|
|
||||||
if type(s) ~= "string" then error("string.ends_with: first argument must be a string", 2) end
|
|
||||||
if type(suffix) ~= "string" then error("string.ends_with: second argument must be a string", 2) end
|
|
||||||
if #suffix == 0 then return true end
|
|
||||||
return s:sub(-#suffix) == suffix
|
|
||||||
end
|
|
||||||
getmetatable("").__index.ends_with = string.ends_with
|
|
||||||
|
|
||||||
function string.replace(s, old, new)
|
|
||||||
if type(s) ~= "string" then error("string.replace: first argument must be a string", 2) end
|
|
||||||
if type(old) ~= "string" then error("string.replace: second argument must be a string", 2) end
|
|
||||||
if type(new) ~= "string" then error("string.replace: third argument must be a string", 2) end
|
|
||||||
if old == "" then error("string.replace: cannot replace empty string", 2) end
|
|
||||||
return s:gsub(old:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1"), new)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.replace = string.replace
|
|
||||||
|
|
||||||
function string.replace_n(s, old, new, n)
|
|
||||||
if type(s) ~= "string" then error("string.replace_n: first argument must be a string", 2) end
|
|
||||||
if type(old) ~= "string" then error("string.replace_n: second argument must be a string", 2) end
|
|
||||||
if type(new) ~= "string" then error("string.replace_n: third argument must be a string", 2) end
|
|
||||||
if type(n) ~= "number" or n < 0 or n ~= math.floor(n) then
|
|
||||||
error("string.replace_n: fourth argument must be a non-negative integer", 2)
|
|
||||||
end
|
|
||||||
if old == "" then error("string.replace_n: cannot replace empty string", 2) end
|
|
||||||
local escaped = old:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1")
|
|
||||||
return (s:gsub(escaped, new, n))
|
|
||||||
end
|
|
||||||
getmetatable("").__index.replace_n = string.replace_n
|
|
||||||
|
|
||||||
function string.index(s, substr)
|
|
||||||
if type(s) ~= "string" then error("string.index: first argument must be a string", 2) end
|
|
||||||
if type(substr) ~= "string" then error("string.index: second argument must be a string", 2) end
|
|
||||||
local pos = _orig_find(s, substr, 1, true)
|
|
||||||
return pos
|
|
||||||
end
|
|
||||||
getmetatable("").__index.index = string.index
|
|
||||||
|
|
||||||
function string.last_index(s, substr)
|
|
||||||
if type(s) ~= "string" then error("string.last_index: first argument must be a string", 2) end
|
|
||||||
if type(substr) ~= "string" then error("string.last_index: second argument must be a string", 2) end
|
|
||||||
local last_pos = nil
|
|
||||||
local pos = 1
|
|
||||||
while true do
|
|
||||||
local found = _orig_find(s, substr, pos, true)
|
|
||||||
if not found then break end
|
|
||||||
last_pos = found
|
|
||||||
pos = found + 1
|
|
||||||
end
|
|
||||||
return last_pos
|
|
||||||
end
|
|
||||||
getmetatable("").__index.last_index = string.last_index
|
|
||||||
|
|
||||||
function string.count(s, substr)
|
|
||||||
if type(s) ~= "string" then error("string.count: first argument must be a string", 2) end
|
|
||||||
if type(substr) ~= "string" then error("string.count: second argument must be a string", 2) end
|
|
||||||
if substr == "" then return #s + 1 end
|
|
||||||
local count = 0
|
|
||||||
local pos = 1
|
|
||||||
while true do
|
|
||||||
local found = _orig_find(s, substr, pos, true)
|
|
||||||
if not found then break end
|
|
||||||
count = count + 1
|
|
||||||
pos = found + #substr
|
|
||||||
end
|
|
||||||
return count
|
|
||||||
end
|
|
||||||
getmetatable("").__index.count = string.count
|
|
||||||
|
|
||||||
function string.repeat_(s, n)
|
|
||||||
if type(s) ~= "string" then error("string.repeat_: first argument must be a string", 2) end
|
|
||||||
if type(n) ~= "number" or n < 0 or n ~= math.floor(n) then
|
|
||||||
error("string.repeat_: second argument must be a non-negative integer", 2)
|
|
||||||
end
|
|
||||||
return string.rep(s, n)
|
|
||||||
end
|
|
||||||
|
|
||||||
function string.reverse(s)
|
|
||||||
if type(s) ~= "string" then error("string.reverse: argument must be a string", 2) end
|
|
||||||
|
|
||||||
if #s > REVERSE_THRESHOLD then
|
|
||||||
local result, err = moonshark.string_reverse(s)
|
|
||||||
if not result then error("string.reverse: " .. err, 2) end
|
|
||||||
return result
|
|
||||||
else
|
|
||||||
local result = {}
|
|
||||||
for i = #s, 1, -1 do
|
|
||||||
result[#result + 1] = s:sub(i, i)
|
|
||||||
end
|
|
||||||
return table.concat(result)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
getmetatable("").__index.reverse = string.reverse
|
|
||||||
|
|
||||||
function string.length(s)
|
|
||||||
if type(s) ~= "string" then error("string.length: argument must be a string", 2) end
|
|
||||||
return moonshark.string_length(s)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.length = string.length
|
|
||||||
|
|
||||||
function string.byte_length(s)
|
|
||||||
if type(s) ~= "string" then error("string.byte_length: argument must be a string", 2) end
|
|
||||||
return moonshark.string_byte_length(s)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.byte_length = string.byte_length
|
|
||||||
|
|
||||||
function string.lines(s)
|
|
||||||
if type(s) ~= "string" then error("string.lines: argument must be a string", 2) end
|
|
||||||
if s == "" then return {""} end
|
|
||||||
|
|
||||||
s = s:gsub("\r\n", "\n"):gsub("\r", "\n")
|
|
||||||
local lines = {}
|
|
||||||
for line in (s .. "\n"):gmatch("([^\n]*)\n") do
|
|
||||||
table.insert(lines, line)
|
|
||||||
end
|
|
||||||
if #lines > 0 and lines[#lines] == "" then
|
|
||||||
table.remove(lines)
|
|
||||||
end
|
|
||||||
return lines
|
|
||||||
end
|
|
||||||
getmetatable("").__index.lines = string.lines
|
|
||||||
|
|
||||||
function string.words(s)
|
|
||||||
if type(s) ~= "string" then error("string.words: argument must be a string", 2) end
|
|
||||||
local words = {}
|
|
||||||
for word in s:gmatch("%S+") do
|
|
||||||
table.insert(words, word)
|
|
||||||
end
|
|
||||||
return words
|
|
||||||
end
|
|
||||||
getmetatable("").__index.words = string.words
|
|
||||||
|
|
||||||
function string.pad_left(s, width, pad_char)
|
|
||||||
if type(s) ~= "string" then error("string.pad_left: first argument must be a string", 2) end
|
|
||||||
if type(width) ~= "number" or width < 0 or width ~= math.floor(width) then
|
|
||||||
error("string.pad_left: second argument must be a non-negative integer", 2)
|
|
||||||
end
|
|
||||||
pad_char = pad_char or " "
|
|
||||||
if type(pad_char) ~= "string" then error("string.pad_left: third argument must be a string", 2) end
|
|
||||||
if #pad_char == 0 then pad_char = " " else pad_char = pad_char:sub(1,1) end
|
|
||||||
local current_len = string.length(s)
|
|
||||||
if current_len >= width then return s end
|
|
||||||
return string.rep(pad_char, width - current_len) .. s
|
|
||||||
end
|
|
||||||
getmetatable("").__index.pad_left = string.pad_left
|
|
||||||
|
|
||||||
function string.pad_right(s, width, pad_char)
|
|
||||||
if type(s) ~= "string" then error("string.pad_right: first argument must be a string", 2) end
|
|
||||||
if type(width) ~= "number" or width < 0 or width ~= math.floor(width) then
|
|
||||||
error("string.pad_right: second argument must be a non-negative integer", 2)
|
|
||||||
end
|
|
||||||
pad_char = pad_char or " "
|
|
||||||
if type(pad_char) ~= "string" then error("string.pad_right: third argument must be a string", 2) end
|
|
||||||
if #pad_char == 0 then pad_char = " " else pad_char = pad_char:sub(1,1) end
|
|
||||||
local current_len = string.length(s)
|
|
||||||
if current_len >= width then return s end
|
|
||||||
return s .. string.rep(pad_char, width - current_len)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.pad_right = string.pad_right
|
|
||||||
|
|
||||||
function string.slice(s, start, end_pos)
|
|
||||||
if type(s) ~= "string" then error("string.slice: first argument must be a string", 2) end
|
|
||||||
if type(start) ~= "number" or start ~= math.floor(start) then
|
|
||||||
error("string.slice: second argument must be an integer", 2)
|
|
||||||
end
|
|
||||||
if end_pos ~= nil and (type(end_pos) ~= "number" or end_pos ~= math.floor(end_pos)) then
|
|
||||||
error("string.slice: third argument must be an integer", 2)
|
|
||||||
end
|
|
||||||
local result, err = moonshark.string_slice(s, start, end_pos)
|
|
||||||
if not result then error("string.slice: " .. err, 2) end
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
getmetatable("").__index.slice = string.slice
|
|
||||||
|
|
||||||
-- Custom find that returns matched substring instead of position
|
|
||||||
function string.find(s, pattern, init, plain)
|
|
||||||
if type(s) ~= "string" then error("string.find: first argument must be a string", 2) end
|
|
||||||
if type(pattern) ~= "string" then error("string.find: second argument must be a string", 2) end
|
|
||||||
local start_pos, end_pos = _orig_find(s, pattern, init, plain)
|
|
||||||
if start_pos then
|
|
||||||
return s:sub(start_pos, end_pos)
|
|
||||||
end
|
|
||||||
return nil
|
|
||||||
end
|
|
||||||
getmetatable("").__index.find = string.find
|
|
||||||
|
|
||||||
function string.find_all(s, pattern)
|
|
||||||
if type(s) ~= "string" then error("string.find_all: first argument must be a string", 2) end
|
|
||||||
if type(pattern) ~= "string" then error("string.find_all: second argument must be a string", 2) end
|
|
||||||
local matches = {}
|
|
||||||
for match in s:gmatch(pattern) do
|
|
||||||
table.insert(matches, match)
|
|
||||||
end
|
|
||||||
return matches
|
|
||||||
end
|
|
||||||
getmetatable("").__index.find_all = string.find_all
|
|
||||||
|
|
||||||
function string.to_number(s)
|
|
||||||
if type(s) ~= "string" then error("string.to_number: argument must be a string", 2) end
|
|
||||||
s = string.trim(s)
|
|
||||||
return tonumber(s)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.to_number = string.to_number
|
|
||||||
|
|
||||||
function string.is_numeric(s)
|
|
||||||
if type(s) ~= "string" then error("string.is_numeric: argument must be a string", 2) end
|
|
||||||
s = string.trim(s)
|
|
||||||
return tonumber(s) ~= nil
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_numeric = string.is_numeric
|
|
||||||
|
|
||||||
function string.is_alpha(s)
|
|
||||||
if type(s) ~= "string" then error("string.is_alpha: argument must be a string", 2) end
|
|
||||||
if #s == 0 then return false end
|
|
||||||
return s:match("^%a+$") ~= nil
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_alpha = string.is_alpha
|
|
||||||
|
|
||||||
function string.is_alphanumeric(s)
|
|
||||||
if type(s) ~= "string" then error("string.is_alphanumeric: argument must be a string", 2) end
|
|
||||||
if #s == 0 then return false end
|
|
||||||
return s:match("^%w+$") ~= nil
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_alphanumeric = string.is_alphanumeric
|
|
||||||
|
|
||||||
function string.is_utf8(s)
|
|
||||||
if type(s) ~= "string" then error("string.is_utf8: argument must be a string", 2) end
|
|
||||||
return moonshark.string_is_valid_utf8(s)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_utf8 = string.is_utf8
|
|
||||||
|
|
||||||
function string.is_empty(s)
|
|
||||||
return s == nil or s == ""
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_empty = string.is_empty
|
|
||||||
|
|
||||||
function string.is_blank(s)
|
|
||||||
return s == nil or s == "" or string.trim(s) == ""
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_blank = string.is_blank
|
|
||||||
|
|
||||||
function string.capitalize(s)
|
|
||||||
if type(s) ~= "string" then error("string.capitalize: argument must be a string", 2) end
|
|
||||||
return s:gsub("(%a)([%w_']*)", function(first, rest)
|
|
||||||
return first:upper() .. rest:lower()
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.capitalize = string.capitalize
|
|
||||||
|
|
||||||
function string.camel_case(s)
|
|
||||||
if type(s) ~= "string" then error("string.camel_case: argument must be a string", 2) end
|
|
||||||
local words = string.words(s)
|
|
||||||
if #words == 0 then return s end
|
|
||||||
local result = words[1]:lower()
|
|
||||||
for i = 2, #words do
|
|
||||||
result = result .. words[i]:sub(1,1):upper() .. words[i]:sub(2):lower()
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
getmetatable("").__index.camel_case = string.camel_case
|
|
||||||
|
|
||||||
function string.pascal_case(s)
|
|
||||||
if type(s) ~= "string" then error("string.pascal_case: argument must be a string", 2) end
|
|
||||||
local words = string.words(s)
|
|
||||||
local result = ""
|
|
||||||
for _, word in ipairs(words) do
|
|
||||||
result = result .. word:sub(1,1):upper() .. word:sub(2):lower()
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
getmetatable("").__index.pascal_case = string.pascal_case
|
|
||||||
|
|
||||||
function string.snake_case(s)
|
|
||||||
if type(s) ~= "string" then error("string.snake_case: argument must be a string", 2) end
|
|
||||||
local words = string.words(s)
|
|
||||||
local result = {}
|
|
||||||
for _, word in ipairs(words) do
|
|
||||||
table.insert(result, word:lower())
|
|
||||||
end
|
|
||||||
return table.concat(result, "_")
|
|
||||||
end
|
|
||||||
getmetatable("").__index.snake_case = string.snake_case
|
|
||||||
|
|
||||||
function string.kebab_case(s)
|
|
||||||
if type(s) ~= "string" then error("string.kebab_case: argument must be a string", 2) end
|
|
||||||
local words = string.words(s)
|
|
||||||
local result = {}
|
|
||||||
for _, word in ipairs(words) do
|
|
||||||
table.insert(result, word:lower())
|
|
||||||
end
|
|
||||||
return table.concat(result, "-")
|
|
||||||
end
|
|
||||||
getmetatable("").__index.kebab_case = string.kebab_case
|
|
||||||
|
|
||||||
function string.screaming_snake_case(s)
|
|
||||||
if type(s) ~= "string" then error("string.screaming_snake_case: argument must be a string", 2) end
|
|
||||||
return string.snake_case(s):upper()
|
|
||||||
end
|
|
||||||
getmetatable("").__index.screaming_snake_case = string.screaming_snake_case
|
|
||||||
|
|
||||||
function string.center(s, width, fill_char)
|
|
||||||
if type(s) ~= "string" then error("string.center: first argument must be a string", 2) end
|
|
||||||
if type(width) ~= "number" or width < 0 or width ~= math.floor(width) then
|
|
||||||
error("string.center: second argument must be a non-negative integer", 2)
|
|
||||||
end
|
|
||||||
fill_char = fill_char or " "
|
|
||||||
if type(fill_char) ~= "string" or #fill_char == 0 then
|
|
||||||
error("string.center: fill character must be a non-empty string", 2)
|
|
||||||
end
|
|
||||||
fill_char = fill_char:sub(1,1)
|
|
||||||
|
|
||||||
local len = string.length(s)
|
|
||||||
if len >= width then return s end
|
|
||||||
|
|
||||||
local pad_total = width - len
|
|
||||||
local pad_left = math.floor(pad_total / 2)
|
|
||||||
local pad_right = pad_total - pad_left
|
|
||||||
|
|
||||||
return string.rep(fill_char, pad_left) .. s .. string.rep(fill_char, pad_right)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.center = string.center
|
|
||||||
|
|
||||||
function string.truncate(s, max_length, suffix)
|
|
||||||
if type(s) ~= "string" then error("string.truncate: first argument must be a string", 2) end
|
|
||||||
if type(max_length) ~= "number" or max_length < 0 or max_length ~= math.floor(max_length) then
|
|
||||||
error("string.truncate: second argument must be a non-negative integer", 2)
|
|
||||||
end
|
|
||||||
suffix = suffix or "..."
|
|
||||||
if type(suffix) ~= "string" then error("string.truncate: third argument must be a string", 2) end
|
|
||||||
|
|
||||||
local len = string.length(s)
|
|
||||||
if len <= max_length then return s end
|
|
||||||
|
|
||||||
local suffix_len = string.length(suffix)
|
|
||||||
if max_length <= suffix_len then
|
|
||||||
return string.slice(suffix, 1, max_length)
|
|
||||||
end
|
|
||||||
|
|
||||||
local main_part = string.slice(s, 1, max_length - suffix_len)
|
|
||||||
main_part = string.trim_right(main_part)
|
|
||||||
return main_part .. suffix
|
|
||||||
end
|
|
||||||
getmetatable("").__index.truncate = string.truncate
|
|
||||||
|
|
||||||
function string.wrap(s, width)
|
|
||||||
if type(s) ~= "string" then error("string.wrap: first argument must be a string", 2) end
|
|
||||||
if type(width) ~= "number" or width <= 0 or width ~= math.floor(width) then
|
|
||||||
error("string.wrap: second argument must be a positive integer", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
if s == "" then return {""} end
|
|
||||||
|
|
||||||
local words = string.words(s)
|
|
||||||
if #words == 0 then return {""} end
|
|
||||||
|
|
||||||
local lines = {}
|
|
||||||
local current_line = ""
|
|
||||||
|
|
||||||
for _, word in ipairs(words) do
|
|
||||||
if string.length(word) > width then
|
|
||||||
if current_line ~= "" then
|
|
||||||
table.insert(lines, current_line)
|
|
||||||
current_line = ""
|
|
||||||
end
|
|
||||||
table.insert(lines, word)
|
|
||||||
elseif current_line == "" then
|
|
||||||
current_line = word
|
|
||||||
elseif string.length(current_line) + 1 + string.length(word) <= width then
|
|
||||||
current_line = current_line .. " " .. word
|
|
||||||
else
|
|
||||||
table.insert(lines, current_line)
|
|
||||||
current_line = word
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if current_line ~= "" then
|
|
||||||
table.insert(lines, current_line)
|
|
||||||
end
|
|
||||||
|
|
||||||
return lines
|
|
||||||
end
|
|
||||||
getmetatable("").__index.wrap = string.wrap
|
|
||||||
|
|
||||||
function string.dedent(s)
|
|
||||||
if type(s) ~= "string" then error("string.dedent: argument must be a string", 2) end
|
|
||||||
|
|
||||||
local lines = string.lines(s)
|
|
||||||
if #lines == 0 then return s end
|
|
||||||
|
|
||||||
local min_indent = math.huge
|
|
||||||
for _, line in ipairs(lines) do
|
|
||||||
if string.trim(line) ~= "" then
|
|
||||||
local indent = 0
|
|
||||||
for i = 1, #line do
|
|
||||||
if line:sub(i,i) == " " then
|
|
||||||
indent = indent + 1
|
|
||||||
else
|
|
||||||
break
|
|
||||||
end
|
|
||||||
end
|
|
||||||
min_indent = math.min(min_indent, indent)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if min_indent == math.huge then return s end
|
|
||||||
|
|
||||||
local result = {}
|
|
||||||
for _, line in ipairs(lines) do
|
|
||||||
if string.trim(line) == "" then
|
|
||||||
table.insert(result, "")
|
|
||||||
else
|
|
||||||
table.insert(result, line:sub(min_indent + 1))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return table.concat(result, "\n")
|
|
||||||
end
|
|
||||||
getmetatable("").__index.dedent = string.dedent
|
|
||||||
|
|
||||||
function string.escape(s)
|
|
||||||
if type(s) ~= "string" then error("string.escape: argument must be a string", 2) end
|
|
||||||
return (s:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1"))
|
|
||||||
end
|
|
||||||
getmetatable("").__index.escape = string.escape
|
|
||||||
|
|
||||||
function string.shell_quote(s)
|
|
||||||
if type(s) ~= "string" then error("string.shell_quote: argument must be a string", 2) end
|
|
||||||
if s:match("^[%w%.%-_/]+$") then
|
|
||||||
return s
|
|
||||||
end
|
|
||||||
return "'" .. s:gsub("'", "'\"'\"'") .. "'"
|
|
||||||
end
|
|
||||||
getmetatable("").__index.shell_quote = string.shell_quote
|
|
||||||
|
|
||||||
function string.url_encode(s)
|
|
||||||
if type(s) ~= "string" then error("string.url_encode: argument must be a string", 2) end
|
|
||||||
return s:gsub("([^%w%-%.%_%~])", function(c)
|
|
||||||
return string.format("%%%02X", string.byte(c))
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.url_encode = string.url_encode
|
|
||||||
|
|
||||||
function string.url_decode(s)
|
|
||||||
if type(s) ~= "string" then error("string.url_decode: argument must be a string", 2) end
|
|
||||||
s = s:gsub("+", " ")
|
|
||||||
return s:gsub("%%(%x%x)", function(hex)
|
|
||||||
return string.char(tonumber(hex, 16))
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.url_decode = string.url_decode
|
|
||||||
|
|
||||||
function string.slug(s)
|
|
||||||
if type(s) ~= "string" then error("string.slug: argument must be a string", 2) end
|
|
||||||
if s == "" then return "" end
|
|
||||||
|
|
||||||
local result = s:lower()
|
|
||||||
-- Remove accents first
|
|
||||||
result = string.remove_accents(result)
|
|
||||||
-- Keep only alphanumeric, spaces, and hyphens
|
|
||||||
result = result:gsub("[^%w%s%-]", "")
|
|
||||||
-- Replace spaces with hyphens
|
|
||||||
result = result:gsub("%s+", "-")
|
|
||||||
-- Remove duplicate hyphens
|
|
||||||
result = result:gsub("%-+", "-")
|
|
||||||
-- Remove leading/trailing hyphens
|
|
||||||
result = result:gsub("^%-", "")
|
|
||||||
result = result:gsub("%-$", "")
|
|
||||||
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
getmetatable("").__index.slug = string.slug
|
|
||||||
|
|
||||||
function string.iequals(a, b)
|
|
||||||
if type(a) ~= "string" then error("string.iequals: first argument must be a string", 2) end
|
|
||||||
if type(b) ~= "string" then error("string.iequals: second argument must be a string", 2) end
|
|
||||||
return string.lower(a) == string.lower(b)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.iequals = string.iequals
|
|
||||||
|
|
||||||
function string.is_whitespace(s)
|
|
||||||
if type(s) ~= "string" then error("string.is_whitespace: argument must be a string", 2) end
|
|
||||||
return s:match("^%s*$") ~= nil
|
|
||||||
end
|
|
||||||
getmetatable("").__index.is_whitespace = string.is_whitespace
|
|
||||||
|
|
||||||
function string.strip_whitespace(s)
|
|
||||||
if type(s) ~= "string" then error("string.strip_whitespace: argument must be a string", 2) end
|
|
||||||
return s:gsub("%s", "")
|
|
||||||
end
|
|
||||||
getmetatable("").__index.strip_whitespace = string.strip_whitespace
|
|
||||||
|
|
||||||
function string.normalize_whitespace(s)
|
|
||||||
if type(s) ~= "string" then error("string.normalize_whitespace: argument must be a string", 2) end
|
|
||||||
return string.trim((s:gsub("%s+", " ")))
|
|
||||||
end
|
|
||||||
getmetatable("").__index.normalize_whitespace = string.normalize_whitespace
|
|
||||||
|
|
||||||
function string.extract_numbers(s)
|
|
||||||
if type(s) ~= "string" then error("string.extract_numbers: argument must be a string", 2) end
|
|
||||||
local numbers = {}
|
|
||||||
for num in s:gmatch("%-?%d+%.?%d*") do
|
|
||||||
local n = tonumber(num)
|
|
||||||
if n then table.insert(numbers, n) end
|
|
||||||
end
|
|
||||||
return numbers
|
|
||||||
end
|
|
||||||
getmetatable("").__index.extract_numbers = string.extract_numbers
|
|
||||||
|
|
||||||
function string.remove_accents(s)
|
|
||||||
if type(s) ~= "string" then error("string.remove_accents: argument must be a string", 2) end
|
|
||||||
local accents = {
|
|
||||||
["á"] = "a", ["à"] = "a", ["ä"] = "a", ["â"] = "a", ["ã"] = "a", ["å"] = "a",
|
|
||||||
["Á"] = "A", ["À"] = "A", ["Ä"] = "A", ["Â"] = "A", ["Ã"] = "A", ["Å"] = "A",
|
|
||||||
["é"] = "e", ["è"] = "e", ["ë"] = "e", ["ê"] = "e",
|
|
||||||
["É"] = "E", ["È"] = "E", ["Ë"] = "E", ["Ê"] = "E",
|
|
||||||
["í"] = "i", ["ì"] = "i", ["ï"] = "i", ["î"] = "i",
|
|
||||||
["Í"] = "I", ["Ì"] = "I", ["Ï"] = "I", ["Î"] = "I",
|
|
||||||
["ó"] = "o", ["ò"] = "o", ["ö"] = "o", ["ô"] = "o", ["õ"] = "o",
|
|
||||||
["Ó"] = "O", ["Ò"] = "O", ["Ö"] = "O", ["Ô"] = "O", ["Õ"] = "O",
|
|
||||||
["ú"] = "u", ["ù"] = "u", ["ü"] = "u", ["û"] = "u",
|
|
||||||
["Ú"] = "U", ["Ù"] = "U", ["Ü"] = "U", ["Û"] = "U",
|
|
||||||
["ñ"] = "n", ["Ñ"] = "N",
|
|
||||||
["ç"] = "c", ["Ç"] = "C"
|
|
||||||
}
|
|
||||||
|
|
||||||
local result = s
|
|
||||||
for accented, plain in pairs(accents) do
|
|
||||||
result = result:gsub(accented, plain)
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
getmetatable("").__index.remove_accents = string.remove_accents
|
|
||||||
|
|
||||||
function string.template(template_str, vars)
|
|
||||||
if type(template_str) ~= "string" then error("string.template: first argument must be a string", 2) end
|
|
||||||
if type(vars) ~= "table" then error("string.template: second argument must be a table", 2) end
|
|
||||||
|
|
||||||
return template_str:gsub("%${([%w_%.]+)}", function(path)
|
|
||||||
local value = vars
|
|
||||||
|
|
||||||
-- Handle simple variables (no dots)
|
|
||||||
if not path:match("%.") then
|
|
||||||
return tostring(value[path] or "")
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Handle nested properties
|
|
||||||
for key in path:gmatch("[^%.]+") do
|
|
||||||
if type(value) == "table" and value[key] ~= nil then
|
|
||||||
value = value[key]
|
|
||||||
else
|
|
||||||
return ""
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return tostring(value)
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
getmetatable("").__index.template = string.template
|
|
||||||
|
|
||||||
function string.random(length, charset)
|
|
||||||
local result, err = moonshark.random_string(length, charset)
|
|
||||||
if not result then
|
|
||||||
error(err)
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
715
modules/string/string.lua
Normal file
715
modules/string/string.lua
Normal file
@ -0,0 +1,715 @@
|
|||||||
|
local str = {}
|
||||||
|
|
||||||
|
-- Performance thresholds based on benchmark results
|
||||||
|
local REVERSE_THRESHOLD = 100 -- Use Go for strings longer than this
|
||||||
|
local LENGTH_THRESHOLD = 1000 -- Use Go for ASCII strings longer than this
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- BASIC STRING OPERATIONS (Optimized Lua/Go hybrid)
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function str.split(s, delimiter)
|
||||||
|
if type(s) ~= "string" then error("str.split: first argument must be a string", 2) end
|
||||||
|
if type(delimiter) ~= "string" then error("str.split: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
if delimiter == "" then
|
||||||
|
local result = {}
|
||||||
|
for i = 1, #s do
|
||||||
|
result[i] = s:sub(i, i)
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
local result = {}
|
||||||
|
local start = 1
|
||||||
|
local delimiter_len = #delimiter
|
||||||
|
|
||||||
|
while true do
|
||||||
|
local pos = s:find(delimiter, start, true)
|
||||||
|
if not pos then
|
||||||
|
table.insert(result, s:sub(start))
|
||||||
|
break
|
||||||
|
end
|
||||||
|
table.insert(result, s:sub(start, pos - 1))
|
||||||
|
start = pos + delimiter_len
|
||||||
|
end
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.join(arr, separator)
|
||||||
|
if type(arr) ~= "table" then error("str.join: first argument must be a table", 2) end
|
||||||
|
if type(separator) ~= "string" then error("str.join: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
return table.concat(arr, separator)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.trim(s)
|
||||||
|
if type(s) ~= "string" then error("str.trim: argument must be a string", 2) end
|
||||||
|
return s:match("^%s*(.-)%s*$")
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.trim_left(s, cutset)
|
||||||
|
if type(s) ~= "string" then error("str.trim_left: first argument must be a string", 2) end
|
||||||
|
if cutset then
|
||||||
|
if type(cutset) ~= "string" then error("str.trim_left: second argument must be a string", 2) end
|
||||||
|
local pattern = "^[" .. cutset:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1") .. "]*"
|
||||||
|
return s:gsub(pattern, "")
|
||||||
|
else
|
||||||
|
return s:match("^%s*(.*)")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.trim_right(s, cutset)
|
||||||
|
if type(s) ~= "string" then error("str.trim_right: first argument must be a string", 2) end
|
||||||
|
if cutset then
|
||||||
|
if type(cutset) ~= "string" then error("str.trim_right: second argument must be a string", 2) end
|
||||||
|
local pattern = "[" .. cutset:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1") .. "]*$"
|
||||||
|
return s:gsub(pattern, "")
|
||||||
|
else
|
||||||
|
return s:match("(.-)%s*$")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.upper(s)
|
||||||
|
if type(s) ~= "string" then error("str.upper: argument must be a string", 2) end
|
||||||
|
return s:upper()
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.lower(s)
|
||||||
|
if type(s) ~= "string" then error("str.lower: argument must be a string", 2) end
|
||||||
|
return s:lower()
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.title(s)
|
||||||
|
if type(s) ~= "string" then error("str.title: argument must be a string", 2) end
|
||||||
|
return s:gsub("(%a)([%w_']*)", function(first, rest)
|
||||||
|
return first:upper() .. rest:lower()
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.contains(s, substr)
|
||||||
|
if type(s) ~= "string" then error("str.contains: first argument must be a string", 2) end
|
||||||
|
if type(substr) ~= "string" then error("str.contains: second argument must be a string", 2) end
|
||||||
|
return s:find(substr, 1, true) ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.starts_with(s, prefix)
|
||||||
|
if type(s) ~= "string" then error("str.starts_with: first argument must be a string", 2) end
|
||||||
|
if type(prefix) ~= "string" then error("str.starts_with: second argument must be a string", 2) end
|
||||||
|
return s:sub(1, #prefix) == prefix
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.ends_with(s, suffix)
|
||||||
|
if type(s) ~= "string" then error("str.ends_with: first argument must be a string", 2) end
|
||||||
|
if type(suffix) ~= "string" then error("str.ends_with: second argument must be a string", 2) end
|
||||||
|
return s:sub(-#suffix) == suffix
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.replace(s, old, new)
|
||||||
|
if type(s) ~= "string" then error("str.replace: first argument must be a string", 2) end
|
||||||
|
if type(old) ~= "string" then error("str.replace: second argument must be a string", 2) end
|
||||||
|
if type(new) ~= "string" then error("str.replace: third argument must be a string", 2) end
|
||||||
|
if old == "" then error("str.replace: cannot replace empty string", 2) end
|
||||||
|
return s:gsub(old:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1"), new)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.replace_n(s, old, new, n)
|
||||||
|
if type(s) ~= "string" then error("str.replace_n: first argument must be a string", 2) end
|
||||||
|
if type(old) ~= "string" then error("str.replace_n: second argument must be a string", 2) end
|
||||||
|
if type(new) ~= "string" then error("str.replace_n: third argument must be a string", 2) end
|
||||||
|
if type(n) ~= "number" or n < 0 or n ~= math.floor(n) then
|
||||||
|
error("str.replace_n: fourth argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
if old == "" then error("str.replace_n: cannot replace empty string", 2) end
|
||||||
|
local escaped = old:gsub("([%^%$%(%)%%%.%[%]%*%+%-%?])", "%%%1")
|
||||||
|
return (s:gsub(escaped, new, n))
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.index(s, substr)
|
||||||
|
if type(s) ~= "string" then error("str.index: first argument must be a string", 2) end
|
||||||
|
if type(substr) ~= "string" then error("str.index: second argument must be a string", 2) end
|
||||||
|
local pos = s:find(substr, 1, true)
|
||||||
|
return pos
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.last_index(s, substr)
|
||||||
|
if type(s) ~= "string" then error("str.last_index: first argument must be a string", 2) end
|
||||||
|
if type(substr) ~= "string" then error("str.last_index: second argument must be a string", 2) end
|
||||||
|
local last_pos = nil
|
||||||
|
local pos = 1
|
||||||
|
while true do
|
||||||
|
local found = s:find(substr, pos, true)
|
||||||
|
if not found then break end
|
||||||
|
last_pos = found
|
||||||
|
pos = found + 1
|
||||||
|
end
|
||||||
|
return last_pos
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.count(s, substr)
|
||||||
|
if type(s) ~= "string" then error("str.count: first argument must be a string", 2) end
|
||||||
|
if type(substr) ~= "string" then error("str.count: second argument must be a string", 2) end
|
||||||
|
if substr == "" then return #s + 1 end
|
||||||
|
local count = 0
|
||||||
|
local pos = 1
|
||||||
|
while true do
|
||||||
|
local found = s:find(substr, pos, true)
|
||||||
|
if not found then break end
|
||||||
|
count = count + 1
|
||||||
|
pos = found + #substr
|
||||||
|
end
|
||||||
|
return count
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.repeat_(s, n)
|
||||||
|
if type(s) ~= "string" then error("str.repeat_: first argument must be a string", 2) end
|
||||||
|
if type(n) ~= "number" or n < 0 or n ~= math.floor(n) then
|
||||||
|
error("str.repeat_: second argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
return string.rep(s, n)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.reverse(s)
|
||||||
|
if type(s) ~= "string" then error("str.reverse: argument must be a string", 2) end
|
||||||
|
|
||||||
|
if #s > REVERSE_THRESHOLD then
|
||||||
|
local result, err = moonshark.string_reverse(s)
|
||||||
|
if not result then error("str.reverse: " .. err, 2) end
|
||||||
|
return result
|
||||||
|
else
|
||||||
|
local result = {}
|
||||||
|
for i = #s, 1, -1 do
|
||||||
|
result[#result + 1] = s:sub(i, i)
|
||||||
|
end
|
||||||
|
return table.concat(result)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.length(s)
|
||||||
|
if type(s) ~= "string" then error("str.length: argument must be a string", 2) end
|
||||||
|
|
||||||
|
-- For long ASCII strings, Go is faster. For unicode or short strings, use Go consistently
|
||||||
|
-- since UTF-8 handling is more reliable in Go
|
||||||
|
return moonshark.string_length(s)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.byte_length(s)
|
||||||
|
if type(s) ~= "string" then error("str.byte_length: argument must be a string", 2) end
|
||||||
|
return #s
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.lines(s)
|
||||||
|
if type(s) ~= "string" then error("str.lines: argument must be a string", 2) end
|
||||||
|
if s == "" then return {""} end
|
||||||
|
|
||||||
|
s = s:gsub("\r\n", "\n"):gsub("\r", "\n")
|
||||||
|
local lines = {}
|
||||||
|
for line in (s .. "\n"):gmatch("([^\n]*)\n") do
|
||||||
|
table.insert(lines, line)
|
||||||
|
end
|
||||||
|
if #lines > 0 and lines[#lines] == "" then
|
||||||
|
table.remove(lines)
|
||||||
|
end
|
||||||
|
return lines
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.words(s)
|
||||||
|
if type(s) ~= "string" then error("str.words: argument must be a string", 2) end
|
||||||
|
local words = {}
|
||||||
|
for word in s:gmatch("%S+") do
|
||||||
|
table.insert(words, word)
|
||||||
|
end
|
||||||
|
return words
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.pad_left(s, width, pad_char)
|
||||||
|
if type(s) ~= "string" then error("str.pad_left: first argument must be a string", 2) end
|
||||||
|
if type(width) ~= "number" or width < 0 or width ~= math.floor(width) then
|
||||||
|
error("str.pad_left: second argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
pad_char = pad_char or " "
|
||||||
|
if type(pad_char) ~= "string" then error("str.pad_left: third argument must be a string", 2) end
|
||||||
|
if #pad_char == 0 then pad_char = " " else pad_char = pad_char:sub(1,1) end
|
||||||
|
local current_len = str.length(s)
|
||||||
|
if current_len >= width then return s end
|
||||||
|
return string.rep(pad_char, width - current_len) .. s
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.pad_right(s, width, pad_char)
|
||||||
|
if type(s) ~= "string" then error("str.pad_right: first argument must be a string", 2) end
|
||||||
|
if type(width) ~= "number" or width < 0 or width ~= math.floor(width) then
|
||||||
|
error("str.pad_right: second argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
pad_char = pad_char or " "
|
||||||
|
if type(pad_char) ~= "string" then error("str.pad_right: third argument must be a string", 2) end
|
||||||
|
if #pad_char == 0 then pad_char = " " else pad_char = pad_char:sub(1,1) end
|
||||||
|
local current_len = str.length(s)
|
||||||
|
if current_len >= width then return s end
|
||||||
|
return s .. string.rep(pad_char, width - current_len)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.slice(s, start, end_pos)
|
||||||
|
if type(s) ~= "string" then error("str.slice: first argument must be a string", 2) end
|
||||||
|
if type(start) ~= "number" or start ~= math.floor(start) then
|
||||||
|
error("str.slice: second argument must be an integer", 2)
|
||||||
|
end
|
||||||
|
if end_pos ~= nil and (type(end_pos) ~= "number" or end_pos ~= math.floor(end_pos)) then
|
||||||
|
error("str.slice: third argument must be an integer", 2)
|
||||||
|
end
|
||||||
|
local result, err = moonshark.string_slice(s, start, end_pos)
|
||||||
|
if not result then error("str.slice: " .. err, 2) end
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- REGULAR EXPRESSIONS (Optimized Lua patterns)
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function str.match(pattern, s)
|
||||||
|
if type(pattern) ~= "string" then error("str.match: first argument must be a string", 2) end
|
||||||
|
if type(s) ~= "string" then error("str.match: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
local lua_pattern = pattern:gsub("\\d", "%%d"):gsub("\\w", "%%w"):gsub("\\s", "%%s")
|
||||||
|
return s:match(lua_pattern) ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.find(pattern, s)
|
||||||
|
if type(pattern) ~= "string" then error("str.find: first argument must be a string", 2) end
|
||||||
|
if type(s) ~= "string" then error("str.find: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
local lua_pattern = pattern:gsub("\\d", "%%d"):gsub("\\w", "%%w"):gsub("\\s", "%%s")
|
||||||
|
return s:match(lua_pattern)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.find_all(pattern, s)
|
||||||
|
if type(pattern) ~= "string" then error("str.find_all: first argument must be a string", 2) end
|
||||||
|
if type(s) ~= "string" then error("str.find_all: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
local lua_pattern = pattern:gsub("\\d", "%%d"):gsub("\\w", "%%w"):gsub("\\s", "%%s")
|
||||||
|
local matches = {}
|
||||||
|
for match in s:gmatch(lua_pattern) do
|
||||||
|
table.insert(matches, match)
|
||||||
|
end
|
||||||
|
return matches
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.gsub(pattern, s, replacement)
|
||||||
|
if type(pattern) ~= "string" then error("str.gsub: first argument must be a string", 2) end
|
||||||
|
if type(s) ~= "string" then error("str.gsub: second argument must be a string", 2) end
|
||||||
|
if type(replacement) ~= "string" then error("str.gsub: third argument must be a string", 2) end
|
||||||
|
|
||||||
|
-- Use Go for complex regex, Lua for simple patterns
|
||||||
|
if pattern:match("[%[%]%(%)%{%}%|%\\%^%$]") then
|
||||||
|
-- Complex pattern, use Go
|
||||||
|
return moonshark.regex_replace(pattern, s, replacement)
|
||||||
|
else
|
||||||
|
-- Simple pattern, use Lua
|
||||||
|
local lua_pattern = pattern:gsub("\\d", "%%d"):gsub("\\w", "%%w"):gsub("\\s", "%%s")
|
||||||
|
return s:gsub(lua_pattern, replacement)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- TYPE CONVERSION & VALIDATION
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function str.to_number(s)
|
||||||
|
if type(s) ~= "string" then error("str.to_number: argument must be a string", 2) end
|
||||||
|
s = str.trim(s)
|
||||||
|
return tonumber(s)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_numeric(s)
|
||||||
|
if type(s) ~= "string" then error("str.is_numeric: argument must be a string", 2) end
|
||||||
|
s = str.trim(s)
|
||||||
|
return tonumber(s) ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_alpha(s)
|
||||||
|
if type(s) ~= "string" then error("str.is_alpha: argument must be a string", 2) end
|
||||||
|
if #s == 0 then return false end
|
||||||
|
return s:match("^%a+$") ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_alphanumeric(s)
|
||||||
|
if type(s) ~= "string" then error("str.is_alphanumeric: argument must be a string", 2) end
|
||||||
|
if #s == 0 then return false end
|
||||||
|
return s:match("^%w+$") ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_empty(s)
|
||||||
|
return s == nil or s == ""
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_blank(s)
|
||||||
|
return str.is_empty(s) or str.trim(s) == ""
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_utf8(s)
|
||||||
|
if type(s) ~= "string" then error("str.is_utf8: argument must be a string", 2) end
|
||||||
|
return moonshark.string_is_valid_utf8(s)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- ADVANCED STRING OPERATIONS (Pure Lua)
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function str.capitalize(s)
|
||||||
|
if type(s) ~= "string" then error("str.capitalize: argument must be a string", 2) end
|
||||||
|
return s:gsub("(%a)([%w_']*)", function(first, rest)
|
||||||
|
return first:upper() .. rest:lower()
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.camel_case(s)
|
||||||
|
if type(s) ~= "string" then error("str.camel_case: argument must be a string", 2) end
|
||||||
|
local words = str.words(s)
|
||||||
|
if #words == 0 then return s end
|
||||||
|
local result = words[1]:lower()
|
||||||
|
for i = 2, #words do
|
||||||
|
result = result .. words[i]:sub(1,1):upper() .. words[i]:sub(2):lower()
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.pascal_case(s)
|
||||||
|
if type(s) ~= "string" then error("str.pascal_case: argument must be a string", 2) end
|
||||||
|
local words = str.words(s)
|
||||||
|
local result = ""
|
||||||
|
for _, word in ipairs(words) do
|
||||||
|
result = result .. word:sub(1,1):upper() .. word:sub(2):lower()
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.snake_case(s)
|
||||||
|
if type(s) ~= "string" then error("str.snake_case: argument must be a string", 2) end
|
||||||
|
local words = str.words(s)
|
||||||
|
local result = {}
|
||||||
|
for _, word in ipairs(words) do
|
||||||
|
table.insert(result, word:lower())
|
||||||
|
end
|
||||||
|
return table.concat(result, "_")
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.kebab_case(s)
|
||||||
|
if type(s) ~= "string" then error("str.kebab_case: argument must be a string", 2) end
|
||||||
|
local words = str.words(s)
|
||||||
|
local result = {}
|
||||||
|
for _, word in ipairs(words) do
|
||||||
|
table.insert(result, word:lower())
|
||||||
|
end
|
||||||
|
return table.concat(result, "-")
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.center(s, width, fill_char)
|
||||||
|
if type(s) ~= "string" then error("str.center: first argument must be a string", 2) end
|
||||||
|
if type(width) ~= "number" or width < 0 or width ~= math.floor(width) then
|
||||||
|
error("str.center: second argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
fill_char = fill_char or " "
|
||||||
|
if type(fill_char) ~= "string" or #fill_char == 0 then
|
||||||
|
error("str.center: fill character must be a non-empty string", 2)
|
||||||
|
end
|
||||||
|
fill_char = fill_char:sub(1,1)
|
||||||
|
|
||||||
|
local len = str.length(s)
|
||||||
|
if len >= width then return s end
|
||||||
|
|
||||||
|
local pad_total = width - len
|
||||||
|
local pad_left = math.floor(pad_total / 2)
|
||||||
|
local pad_right = pad_total - pad_left
|
||||||
|
|
||||||
|
return string.rep(fill_char, pad_left) .. s .. string.rep(fill_char, pad_right)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.truncate(s, max_length, suffix)
|
||||||
|
if type(s) ~= "string" then error("str.truncate: first argument must be a string", 2) end
|
||||||
|
if type(max_length) ~= "number" or max_length < 0 or max_length ~= math.floor(max_length) then
|
||||||
|
error("str.truncate: second argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
suffix = suffix or "..."
|
||||||
|
if type(suffix) ~= "string" then error("str.truncate: third argument must be a string", 2) end
|
||||||
|
|
||||||
|
local len = str.length(s)
|
||||||
|
if len <= max_length then return s end
|
||||||
|
|
||||||
|
local suffix_len = str.length(suffix)
|
||||||
|
if max_length <= suffix_len then
|
||||||
|
return str.slice(suffix, 1, max_length)
|
||||||
|
end
|
||||||
|
|
||||||
|
local main_part = str.slice(s, 1, max_length - suffix_len)
|
||||||
|
main_part = str.trim_right(main_part)
|
||||||
|
return main_part .. suffix
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.escape_regex(s)
|
||||||
|
if type(s) ~= "string" then error("str.escape_regex: argument must be a string", 2) end
|
||||||
|
return s:gsub("([%.%+%*%?%[%]%^%$%(%)%{%}%|%\\])", "\\%1")
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.url_encode(s)
|
||||||
|
if type(s) ~= "string" then error("str.url_encode: argument must be a string", 2) end
|
||||||
|
return s:gsub("([^%w%-%.%_%~])", function(c)
|
||||||
|
return string.format("%%%02X", string.byte(c))
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.url_decode(s)
|
||||||
|
if type(s) ~= "string" then error("str.url_decode: argument must be a string", 2) end
|
||||||
|
local result = s:gsub("%%(%x%x)", function(hex)
|
||||||
|
local byte = tonumber(hex, 16)
|
||||||
|
return byte and string.char(byte) or ("%" .. hex)
|
||||||
|
end):gsub("+", " ")
|
||||||
|
|
||||||
|
if not str.is_utf8(result) then
|
||||||
|
error("str.url_decode: result is not valid UTF-8", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.distance(a, b)
|
||||||
|
if type(a) ~= "string" then error("str.distance: first argument must be a string", 2) end
|
||||||
|
if type(b) ~= "string" then error("str.distance: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
local len_a, len_b = str.length(a), str.length(b)
|
||||||
|
|
||||||
|
if len_a == 0 then return len_b end
|
||||||
|
if len_b == 0 then return len_a end
|
||||||
|
|
||||||
|
if len_a > 1000 or len_b > 1000 then
|
||||||
|
error("str.distance: strings too long for distance calculation", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local matrix = {}
|
||||||
|
|
||||||
|
for i = 0, len_a do
|
||||||
|
matrix[i] = {[0] = i}
|
||||||
|
end
|
||||||
|
for j = 0, len_b do
|
||||||
|
matrix[0][j] = j
|
||||||
|
end
|
||||||
|
|
||||||
|
for i = 1, len_a do
|
||||||
|
for j = 1, len_b do
|
||||||
|
local cost = (str.slice(a, i, i) == str.slice(b, j, j)) and 0 or 1
|
||||||
|
matrix[i][j] = math.min(
|
||||||
|
matrix[i-1][j] + 1,
|
||||||
|
matrix[i][j-1] + 1,
|
||||||
|
matrix[i-1][j-1] + cost
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return matrix[len_a][len_b]
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.similarity(a, b)
|
||||||
|
if type(a) ~= "string" then error("str.similarity: first argument must be a string", 2) end
|
||||||
|
if type(b) ~= "string" then error("str.similarity: second argument must be a string", 2) end
|
||||||
|
|
||||||
|
local max_len = math.max(str.length(a), str.length(b))
|
||||||
|
if max_len == 0 then return 1.0 end
|
||||||
|
|
||||||
|
local dist = str.distance(a, b)
|
||||||
|
return 1.0 - (dist / max_len)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.template(template, vars)
|
||||||
|
if type(template) ~= "string" then error("str.template: first argument must be a string", 2) end
|
||||||
|
vars = vars or {}
|
||||||
|
if type(vars) ~= "table" then error("str.template: second argument must be a table", 2) end
|
||||||
|
|
||||||
|
return template:gsub("%${([%w_]+)}", function(var)
|
||||||
|
local value = vars[var]
|
||||||
|
return value ~= nil and tostring(value) or ""
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.random(length, charset)
|
||||||
|
if type(length) ~= "number" or length < 0 or length ~= math.floor(length) then
|
||||||
|
error("str.random: first argument must be a non-negative integer", 2)
|
||||||
|
end
|
||||||
|
if charset ~= nil and type(charset) ~= "string" then
|
||||||
|
error("str.random: second argument must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
charset = charset or "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
local result = {}
|
||||||
|
|
||||||
|
math.randomseed(os.time() + os.clock() * 1000000)
|
||||||
|
|
||||||
|
for i = 1, length do
|
||||||
|
local rand_index = math.random(1, #charset)
|
||||||
|
result[i] = charset:sub(rand_index, rand_index)
|
||||||
|
end
|
||||||
|
|
||||||
|
return table.concat(result)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.slug(s)
|
||||||
|
if type(s) ~= "string" then error("str.slug: argument must be a string", 2) end
|
||||||
|
|
||||||
|
local result = str.remove_accents(s):lower()
|
||||||
|
result = result:gsub("[^%w%s]", "")
|
||||||
|
result = result:gsub("%s+", "-")
|
||||||
|
result = result:gsub("^%-+", ""):gsub("%-+$", "")
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Add these functions to the end of string.lua, before the return statement
|
||||||
|
|
||||||
|
function str.screaming_snake_case(s)
|
||||||
|
if type(s) ~= "string" then error("str.screaming_snake_case: argument must be a string", 2) end
|
||||||
|
return str.snake_case(s):upper()
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.wrap(s, width)
|
||||||
|
if type(s) ~= "string" then error("str.wrap: first argument must be a string", 2) end
|
||||||
|
if type(width) ~= "number" or width <= 0 then error("str.wrap: width must be positive number", 2) end
|
||||||
|
|
||||||
|
local words = str.words(s)
|
||||||
|
local lines = {}
|
||||||
|
local current_line = ""
|
||||||
|
|
||||||
|
for _, word in ipairs(words) do
|
||||||
|
if current_line == "" then
|
||||||
|
current_line = word
|
||||||
|
elseif str.length(current_line .. " " .. word) <= width then
|
||||||
|
current_line = current_line .. " " .. word
|
||||||
|
else
|
||||||
|
table.insert(lines, current_line)
|
||||||
|
current_line = word
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if current_line ~= "" then
|
||||||
|
table.insert(lines, current_line)
|
||||||
|
end
|
||||||
|
|
||||||
|
return lines
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.dedent(s)
|
||||||
|
if type(s) ~= "string" then error("str.dedent: argument must be a string", 2) end
|
||||||
|
|
||||||
|
local lines = str.lines(s)
|
||||||
|
if #lines == 0 then return "" end
|
||||||
|
|
||||||
|
-- Find minimum indentation
|
||||||
|
local min_indent = math.huge
|
||||||
|
for _, line in ipairs(lines) do
|
||||||
|
if line:match("%S") then -- Non-empty line
|
||||||
|
local indent = line:match("^(%s*)")
|
||||||
|
min_indent = math.min(min_indent, #indent)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if min_indent == math.huge then min_indent = 0 end
|
||||||
|
|
||||||
|
-- Remove common indentation
|
||||||
|
local result = {}
|
||||||
|
for _, line in ipairs(lines) do
|
||||||
|
table.insert(result, line:sub(min_indent + 1))
|
||||||
|
end
|
||||||
|
|
||||||
|
return table.concat(result, "\n")
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.shell_quote(s)
|
||||||
|
if type(s) ~= "string" then error("str.shell_quote: argument must be a string", 2) end
|
||||||
|
|
||||||
|
if s:match("^[%w%-%./]+$") then
|
||||||
|
return s -- No quoting needed
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Replace single quotes with '"'"'
|
||||||
|
local quoted = s:gsub("'", "'\"'\"'")
|
||||||
|
return "'" .. quoted .. "'"
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.iequals(a, b)
|
||||||
|
if type(a) ~= "string" then error("str.iequals: first argument must be a string", 2) end
|
||||||
|
if type(b) ~= "string" then error("str.iequals: second argument must be a string", 2) end
|
||||||
|
return str.lower(a) == str.lower(b)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.template_advanced(template, context)
|
||||||
|
if type(template) ~= "string" then error("str.template_advanced: first argument must be a string", 2) end
|
||||||
|
context = context or {}
|
||||||
|
if type(context) ~= "table" then error("str.template_advanced: second argument must be a table", 2) end
|
||||||
|
|
||||||
|
return template:gsub("%${([%w_.]+)}", function(path)
|
||||||
|
local keys = str.split(path, ".")
|
||||||
|
local value = context
|
||||||
|
|
||||||
|
for _, key in ipairs(keys) do
|
||||||
|
if type(value) == "table" and value[key] ~= nil then
|
||||||
|
value = value[key]
|
||||||
|
else
|
||||||
|
return ""
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return tostring(value)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.is_whitespace(s)
|
||||||
|
if type(s) ~= "string" then error("str.is_whitespace: argument must be a string", 2) end
|
||||||
|
return s:match("^%s*$") ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.strip_whitespace(s)
|
||||||
|
if type(s) ~= "string" then error("str.strip_whitespace: argument must be a string", 2) end
|
||||||
|
return s:gsub("%s", "")
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.normalize_whitespace(s)
|
||||||
|
if type(s) ~= "string" then error("str.normalize_whitespace: argument must be a string", 2) end
|
||||||
|
return str.trim(s:gsub("%s+", " "))
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.extract_numbers(s)
|
||||||
|
if type(s) ~= "string" then error("str.extract_numbers: argument must be a string", 2) end
|
||||||
|
|
||||||
|
local numbers = {}
|
||||||
|
for match in s:gmatch("%-?%d+%.?%d*") do
|
||||||
|
local num = tonumber(match)
|
||||||
|
if num then
|
||||||
|
table.insert(numbers, num)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return numbers
|
||||||
|
end
|
||||||
|
|
||||||
|
function str.remove_accents(s)
|
||||||
|
if type(s) ~= "string" then error("str.remove_accents: argument must be a string", 2) end
|
||||||
|
|
||||||
|
local accents = {
|
||||||
|
["à"] = "a", ["á"] = "a", ["â"] = "a", ["ã"] = "a", ["ä"] = "a", ["å"] = "a",
|
||||||
|
["è"] = "e", ["é"] = "e", ["ê"] = "e", ["ë"] = "e",
|
||||||
|
["ì"] = "i", ["í"] = "i", ["î"] = "i", ["ï"] = "i",
|
||||||
|
["ò"] = "o", ["ó"] = "o", ["ô"] = "o", ["õ"] = "o", ["ö"] = "o",
|
||||||
|
["ù"] = "u", ["ú"] = "u", ["û"] = "u", ["ü"] = "u",
|
||||||
|
["ñ"] = "n", ["ç"] = "c", ["ÿ"] = "y",
|
||||||
|
["À"] = "A", ["Á"] = "A", ["Â"] = "A", ["Ã"] = "A", ["Ä"] = "A", ["Å"] = "A",
|
||||||
|
["È"] = "E", ["É"] = "E", ["Ê"] = "E", ["Ë"] = "E",
|
||||||
|
["Ì"] = "I", ["Í"] = "I", ["Î"] = "I", ["Ï"] = "I",
|
||||||
|
["Ò"] = "O", ["Ó"] = "O", ["Ô"] = "O", ["Õ"] = "O", ["Ö"] = "O",
|
||||||
|
["Ù"] = "U", ["Ú"] = "U", ["Û"] = "U", ["Ü"] = "U",
|
||||||
|
["Ñ"] = "N", ["Ç"] = "C", ["Ÿ"] = "Y"
|
||||||
|
}
|
||||||
|
|
||||||
|
local result = s
|
||||||
|
for accented, plain in pairs(accents) do
|
||||||
|
result = result:gsub(accented, plain)
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
return str
|
||||||
@ -1,56 +1,61 @@
|
|||||||
local orig_insert = table.insert
|
local tbl = {}
|
||||||
local orig_remove = table.remove
|
|
||||||
local orig_concat = table.concat
|
|
||||||
local orig_sort = table.sort
|
|
||||||
|
|
||||||
function table.insert(t, pos, value)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.insert: first argument must be a table", 2) end
|
-- BUILT-IN TABLE FUNCTIONS (Lua 5.1 wrappers for consistency)
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.insert(t, pos, value)
|
||||||
|
if type(t) ~= "table" then error("tbl.insert: first argument must be a table", 2) end
|
||||||
|
|
||||||
if value == nil then
|
if value == nil then
|
||||||
-- table.insert(t, value) form
|
-- table.insert(t, value) form
|
||||||
orig_insert(t, pos)
|
table.insert(t, pos)
|
||||||
else
|
else
|
||||||
-- table.insert(t, pos, value) form
|
-- table.insert(t, pos, value) form
|
||||||
if type(pos) ~= "number" or pos ~= math.floor(pos) then
|
if type(pos) ~= "number" or pos ~= math.floor(pos) then
|
||||||
error("table.insert: position must be an integer", 2)
|
error("tbl.insert: position must be an integer", 2)
|
||||||
end
|
end
|
||||||
orig_insert(t, pos, value)
|
table.insert(t, pos, value)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.remove(t, pos)
|
function tbl.remove(t, pos)
|
||||||
if type(t) ~= "table" then error("table.remove: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.remove: first argument must be a table", 2) end
|
||||||
if pos ~= nil and (type(pos) ~= "number" or pos ~= math.floor(pos)) then
|
if pos ~= nil and (type(pos) ~= "number" or pos ~= math.floor(pos)) then
|
||||||
error("table.remove: position must be an integer", 2)
|
error("tbl.remove: position must be an integer", 2)
|
||||||
end
|
end
|
||||||
return orig_remove(t, pos)
|
return table.remove(t, pos)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.concat(t, sep, start_idx, end_idx)
|
function tbl.concat(t, sep, start_idx, end_idx)
|
||||||
if type(t) ~= "table" then error("table.concat: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.concat: first argument must be a table", 2) end
|
||||||
if sep ~= nil and type(sep) ~= "string" then error("table.concat: separator must be a string", 2) end
|
if sep ~= nil and type(sep) ~= "string" then error("tbl.concat: separator must be a string", 2) end
|
||||||
if start_idx ~= nil and (type(start_idx) ~= "number" or start_idx ~= math.floor(start_idx)) then
|
if start_idx ~= nil and (type(start_idx) ~= "number" or start_idx ~= math.floor(start_idx)) then
|
||||||
error("table.concat: start index must be an integer", 2)
|
error("tbl.concat: start index must be an integer", 2)
|
||||||
end
|
end
|
||||||
if end_idx ~= nil and (type(end_idx) ~= "number" or end_idx ~= math.floor(end_idx)) then
|
if end_idx ~= nil and (type(end_idx) ~= "number" or end_idx ~= math.floor(end_idx)) then
|
||||||
error("table.concat: end index must be an integer", 2)
|
error("tbl.concat: end index must be an integer", 2)
|
||||||
end
|
end
|
||||||
return orig_concat(t, sep, start_idx, end_idx)
|
return table.concat(t, sep, start_idx, end_idx)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.sort(t, comp)
|
function tbl.sort(t, comp)
|
||||||
if type(t) ~= "table" then error("table.sort: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.sort: first argument must be a table", 2) end
|
||||||
if comp ~= nil and type(comp) ~= "function" then error("table.sort: comparator must be a function", 2) end
|
if comp ~= nil and type(comp) ~= "function" then error("tbl.sort: comparator must be a function", 2) end
|
||||||
orig_sort(t, comp)
|
table.sort(t, comp)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.length(t)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.length: argument must be a table", 2) end
|
-- BASIC TABLE OPERATIONS
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.length(t)
|
||||||
|
if type(t) ~= "table" then error("tbl.length: argument must be a table", 2) end
|
||||||
return #t
|
return #t
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.size(t)
|
function tbl.size(t)
|
||||||
if type(t) ~= "table" then error("table.size: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.size: argument must be a table", 2) end
|
||||||
local count = 0
|
local count = 0
|
||||||
for _ in pairs(t) do
|
for _ in pairs(t) do
|
||||||
count = count + 1
|
count = count + 1
|
||||||
@ -58,14 +63,14 @@ function table.size(t)
|
|||||||
return count
|
return count
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.is_empty(t)
|
function tbl.is_empty(t)
|
||||||
if type(t) ~= "table" then error("table.is_empty: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.is_empty: argument must be a table", 2) end
|
||||||
return next(t) == nil
|
return next(t) == nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.is_array(t)
|
function tbl.is_array(t)
|
||||||
if type(t) ~= "table" then error("table.is_array: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.is_array: argument must be a table", 2) end
|
||||||
if table.is_empty(t) then return true end
|
if tbl.is_empty(t) then return true end
|
||||||
|
|
||||||
local max_index = 0
|
local max_index = 0
|
||||||
local count = 0
|
local count = 0
|
||||||
@ -79,15 +84,15 @@ function table.is_array(t)
|
|||||||
return max_index == count
|
return max_index == count
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.clear(t)
|
function tbl.clear(t)
|
||||||
if type(t) ~= "table" then error("table.clear: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.clear: argument must be a table", 2) end
|
||||||
for k in pairs(t) do
|
for k in pairs(t) do
|
||||||
t[k] = nil
|
t[k] = nil
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.clone(t)
|
function tbl.clone(t)
|
||||||
if type(t) ~= "table" then error("table.clone: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.clone: argument must be a table", 2) end
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
result[k] = v
|
result[k] = v
|
||||||
@ -95,8 +100,8 @@ function table.clone(t)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.deep_copy(t)
|
function tbl.deep_copy(t)
|
||||||
if type(t) ~= "table" then error("table.deep_copy: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.deep_copy: argument must be a table", 2) end
|
||||||
|
|
||||||
local function copy_recursive(obj, seen)
|
local function copy_recursive(obj, seen)
|
||||||
if type(obj) ~= "table" then return obj end
|
if type(obj) ~= "table" then return obj end
|
||||||
@ -115,25 +120,29 @@ function table.deep_copy(t)
|
|||||||
return copy_recursive(t, {})
|
return copy_recursive(t, {})
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.contains(t, value)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.contains: first argument must be a table", 2) end
|
-- SEARCHING AND FINDING
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.contains(t, value)
|
||||||
|
if type(t) ~= "table" then error("tbl.contains: first argument must be a table", 2) end
|
||||||
for _, v in pairs(t) do
|
for _, v in pairs(t) do
|
||||||
if v == value then return true end
|
if v == value then return true end
|
||||||
end
|
end
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.index_of(t, value)
|
function tbl.index_of(t, value)
|
||||||
if type(t) ~= "table" then error("table.index_of: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.index_of: first argument must be a table", 2) end
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
if v == value then return k end
|
if v == value then return k end
|
||||||
end
|
end
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.find(t, predicate)
|
function tbl.find(t, predicate)
|
||||||
if type(t) ~= "table" then error("table.find: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.find: first argument must be a table", 2) end
|
||||||
if type(predicate) ~= "function" then error("table.find: second argument must be a function", 2) end
|
if type(predicate) ~= "function" then error("tbl.find: second argument must be a function", 2) end
|
||||||
|
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
if predicate(v, k, t) then return v, k end
|
if predicate(v, k, t) then return v, k end
|
||||||
@ -141,9 +150,9 @@ function table.find(t, predicate)
|
|||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.find_index(t, predicate)
|
function tbl.find_index(t, predicate)
|
||||||
if type(t) ~= "table" then error("table.find_index: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.find_index: first argument must be a table", 2) end
|
||||||
if type(predicate) ~= "function" then error("table.find_index: second argument must be a function", 2) end
|
if type(predicate) ~= "function" then error("tbl.find_index: second argument must be a function", 2) end
|
||||||
|
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
if predicate(v, k, t) then return k end
|
if predicate(v, k, t) then return k end
|
||||||
@ -151,8 +160,8 @@ function table.find_index(t, predicate)
|
|||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.count(t, value_or_predicate)
|
function tbl.count(t, value_or_predicate)
|
||||||
if type(t) ~= "table" then error("table.count: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.count: first argument must be a table", 2) end
|
||||||
|
|
||||||
local count = 0
|
local count = 0
|
||||||
if type(value_or_predicate) == "function" then
|
if type(value_or_predicate) == "function" then
|
||||||
@ -167,12 +176,16 @@ function table.count(t, value_or_predicate)
|
|||||||
return count
|
return count
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.filter(t, predicate)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.filter: first argument must be a table", 2) end
|
-- FILTERING AND MAPPING
|
||||||
if type(predicate) ~= "function" then error("table.filter: second argument must be a function", 2) end
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.filter(t, predicate)
|
||||||
|
if type(t) ~= "table" then error("tbl.filter: first argument must be a table", 2) end
|
||||||
|
if type(predicate) ~= "function" then error("tbl.filter: second argument must be a function", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
if table.is_array(t) then
|
if tbl.is_array(t) then
|
||||||
local max_index = 0
|
local max_index = 0
|
||||||
for k in pairs(t) do
|
for k in pairs(t) do
|
||||||
if type(k) == "number" and k > max_index then
|
if type(k) == "number" and k > max_index then
|
||||||
@ -182,7 +195,7 @@ function table.filter(t, predicate)
|
|||||||
for i = 1, max_index do
|
for i = 1, max_index do
|
||||||
local v = t[i]
|
local v = t[i]
|
||||||
if v ~= nil and predicate(v, i, t) then
|
if v ~= nil and predicate(v, i, t) then
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
@ -195,16 +208,16 @@ function table.filter(t, predicate)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.reject(t, predicate)
|
function tbl.reject(t, predicate)
|
||||||
if type(t) ~= "table" then error("table.reject: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.reject: first argument must be a table", 2) end
|
||||||
if type(predicate) ~= "function" then error("table.reject: second argument must be a function", 2) end
|
if type(predicate) ~= "function" then error("tbl.reject: second argument must be a function", 2) end
|
||||||
|
|
||||||
return table.filter(t, function(v, k, tbl) return not predicate(v, k, tbl) end)
|
return tbl.filter(t, function(v, k, tbl) return not predicate(v, k, tbl) end)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.map(t, transformer)
|
function tbl.map(t, transformer)
|
||||||
if type(t) ~= "table" then error("table.map: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.map: first argument must be a table", 2) end
|
||||||
if type(transformer) ~= "function" then error("table.map: second argument must be a function", 2) end
|
if type(transformer) ~= "function" then error("tbl.map: second argument must be a function", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
@ -213,9 +226,9 @@ function table.map(t, transformer)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.map_values(t, transformer)
|
function tbl.map_values(t, transformer)
|
||||||
if type(t) ~= "table" then error("table.map_values: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.map_values: first argument must be a table", 2) end
|
||||||
if type(transformer) ~= "function" then error("table.map_values: second argument must be a function", 2) end
|
if type(transformer) ~= "function" then error("tbl.map_values: second argument must be a function", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
@ -224,9 +237,9 @@ function table.map_values(t, transformer)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.map_keys(t, transformer)
|
function tbl.map_keys(t, transformer)
|
||||||
if type(t) ~= "table" then error("table.map_keys: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.map_keys: first argument must be a table", 2) end
|
||||||
if type(transformer) ~= "function" then error("table.map_keys: second argument must be a function", 2) end
|
if type(transformer) ~= "function" then error("tbl.map_keys: second argument must be a function", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
@ -236,9 +249,13 @@ function table.map_keys(t, transformer)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.reduce(t, reducer, initial)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.reduce: first argument must be a table", 2) end
|
-- REDUCING AND AGGREGATING
|
||||||
if type(reducer) ~= "function" then error("table.reduce: second argument must be a function", 2) end
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.reduce(t, reducer, initial)
|
||||||
|
if type(t) ~= "table" then error("tbl.reduce: first argument must be a table", 2) end
|
||||||
|
if type(reducer) ~= "function" then error("tbl.reduce: second argument must be a function", 2) end
|
||||||
|
|
||||||
local accumulator = initial
|
local accumulator = initial
|
||||||
local started = initial ~= nil
|
local started = initial ~= nil
|
||||||
@ -253,39 +270,39 @@ function table.reduce(t, reducer, initial)
|
|||||||
end
|
end
|
||||||
|
|
||||||
if not started then
|
if not started then
|
||||||
error("table.reduce: empty table with no initial value", 2)
|
error("tbl.reduce: empty table with no initial value", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
return accumulator
|
return accumulator
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.sum(t)
|
function tbl.sum(t)
|
||||||
if type(t) ~= "table" then error("table.sum: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.sum: argument must be a table", 2) end
|
||||||
local total = 0
|
local total = 0
|
||||||
for _, v in pairs(t) do
|
for _, v in pairs(t) do
|
||||||
if type(v) ~= "number" then error("table.sum: all values must be numbers", 2) end
|
if type(v) ~= "number" then error("tbl.sum: all values must be numbers", 2) end
|
||||||
total = total + v
|
total = total + v
|
||||||
end
|
end
|
||||||
return total
|
return total
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.product(t)
|
function tbl.product(t)
|
||||||
if type(t) ~= "table" then error("table.product: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.product: argument must be a table", 2) end
|
||||||
local result = 1
|
local result = 1
|
||||||
for _, v in pairs(t) do
|
for _, v in pairs(t) do
|
||||||
if type(v) ~= "number" then error("table.product: all values must be numbers", 2) end
|
if type(v) ~= "number" then error("tbl.product: all values must be numbers", 2) end
|
||||||
result = result * v
|
result = result * v
|
||||||
end
|
end
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.min(t)
|
function tbl.min(t)
|
||||||
if type(t) ~= "table" then error("table.min: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.min: argument must be a table", 2) end
|
||||||
if table.is_empty(t) then error("table.min: table is empty", 2) end
|
if tbl.is_empty(t) then error("tbl.min: table is empty", 2) end
|
||||||
|
|
||||||
local min_val = nil
|
local min_val = nil
|
||||||
for _, v in pairs(t) do
|
for _, v in pairs(t) do
|
||||||
if type(v) ~= "number" then error("table.min: all values must be numbers", 2) end
|
if type(v) ~= "number" then error("tbl.min: all values must be numbers", 2) end
|
||||||
if min_val == nil or v < min_val then
|
if min_val == nil or v < min_val then
|
||||||
min_val = v
|
min_val = v
|
||||||
end
|
end
|
||||||
@ -293,13 +310,13 @@ function table.min(t)
|
|||||||
return min_val
|
return min_val
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.max(t)
|
function tbl.max(t)
|
||||||
if type(t) ~= "table" then error("table.max: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.max: argument must be a table", 2) end
|
||||||
if table.is_empty(t) then error("table.max: table is empty", 2) end
|
if tbl.is_empty(t) then error("tbl.max: table is empty", 2) end
|
||||||
|
|
||||||
local max_val = nil
|
local max_val = nil
|
||||||
for _, v in pairs(t) do
|
for _, v in pairs(t) do
|
||||||
if type(v) ~= "number" then error("table.max: all values must be numbers", 2) end
|
if type(v) ~= "number" then error("tbl.max: all values must be numbers", 2) end
|
||||||
if max_val == nil or v > max_val then
|
if max_val == nil or v > max_val then
|
||||||
max_val = v
|
max_val = v
|
||||||
end
|
end
|
||||||
@ -307,17 +324,21 @@ function table.max(t)
|
|||||||
return max_val
|
return max_val
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.average(t)
|
function tbl.average(t)
|
||||||
if type(t) ~= "table" then error("table.average: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.average: argument must be a table", 2) end
|
||||||
if table.is_empty(t) then error("table.average: table is empty", 2) end
|
if tbl.is_empty(t) then error("tbl.average: table is empty", 2) end
|
||||||
return table.sum(t) / table.size(t)
|
return tbl.sum(t) / tbl.size(t)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.all(t, predicate)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.all: first argument must be a table", 2) end
|
-- BOOLEAN OPERATIONS
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.all(t, predicate)
|
||||||
|
if type(t) ~= "table" then error("tbl.all: first argument must be a table", 2) end
|
||||||
|
|
||||||
if predicate then
|
if predicate then
|
||||||
if type(predicate) ~= "function" then error("table.all: second argument must be a function", 2) end
|
if type(predicate) ~= "function" then error("tbl.all: second argument must be a function", 2) end
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
if not predicate(v, k, t) then return false end
|
if not predicate(v, k, t) then return false end
|
||||||
end
|
end
|
||||||
@ -329,11 +350,11 @@ function table.all(t, predicate)
|
|||||||
return true
|
return true
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.any(t, predicate)
|
function tbl.any(t, predicate)
|
||||||
if type(t) ~= "table" then error("table.any: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.any: first argument must be a table", 2) end
|
||||||
|
|
||||||
if predicate then
|
if predicate then
|
||||||
if type(predicate) ~= "function" then error("table.any: second argument must be a function", 2) end
|
if type(predicate) ~= "function" then error("tbl.any: second argument must be a function", 2) end
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
if predicate(v, k, t) then return true end
|
if predicate(v, k, t) then return true end
|
||||||
end
|
end
|
||||||
@ -345,22 +366,26 @@ function table.any(t, predicate)
|
|||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.none(t, predicate)
|
function tbl.none(t, predicate)
|
||||||
if type(t) ~= "table" then error("table.none: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.none: first argument must be a table", 2) end
|
||||||
return not table.any(t, predicate)
|
return not tbl.any(t, predicate)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.unique(t)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.unique: argument must be a table", 2) end
|
-- SET OPERATIONS
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.unique(t)
|
||||||
|
if type(t) ~= "table" then error("tbl.unique: argument must be a table", 2) end
|
||||||
|
|
||||||
local seen = {}
|
local seen = {}
|
||||||
local result = {}
|
local result = {}
|
||||||
|
|
||||||
if table.is_array(t) then
|
if tbl.is_array(t) then
|
||||||
for _, v in ipairs(t) do
|
for _, v in ipairs(t) do
|
||||||
if not seen[v] then
|
if not seen[v] then
|
||||||
seen[v] = true
|
seen[v] = true
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
@ -375,9 +400,9 @@ function table.unique(t)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.intersection(t1, t2)
|
function tbl.intersection(t1, t2)
|
||||||
if type(t1) ~= "table" then error("table.intersection: first argument must be a table", 2) end
|
if type(t1) ~= "table" then error("tbl.intersection: first argument must be a table", 2) end
|
||||||
if type(t2) ~= "table" then error("table.intersection: second argument must be a table", 2) end
|
if type(t2) ~= "table" then error("tbl.intersection: second argument must be a table", 2) end
|
||||||
|
|
||||||
local set2 = {}
|
local set2 = {}
|
||||||
for _, v in pairs(t2) do
|
for _, v in pairs(t2) do
|
||||||
@ -385,10 +410,10 @@ function table.intersection(t1, t2)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
if table.is_array(t1) then
|
if tbl.is_array(t1) then
|
||||||
for _, v in ipairs(t1) do
|
for _, v in ipairs(t1) do
|
||||||
if set2[v] then
|
if set2[v] then
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
@ -402,16 +427,16 @@ function table.intersection(t1, t2)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.union(t1, t2)
|
function tbl.union(t1, t2)
|
||||||
if type(t1) ~= "table" then error("table.union: first argument must be a table", 2) end
|
if type(t1) ~= "table" then error("tbl.union: first argument must be a table", 2) end
|
||||||
if type(t2) ~= "table" then error("table.union: second argument must be a table", 2) end
|
if type(t2) ~= "table" then error("tbl.union: second argument must be a table", 2) end
|
||||||
|
|
||||||
local result = table.clone(t1)
|
local result = tbl.clone(t1)
|
||||||
|
|
||||||
if table.is_array(t1) and table.is_array(t2) then
|
if tbl.is_array(t1) and tbl.is_array(t2) then
|
||||||
for _, v in ipairs(t2) do
|
for _, v in ipairs(t2) do
|
||||||
if not table.contains(result, v) then
|
if not tbl.contains(result, v) then
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
@ -425,21 +450,25 @@ function table.union(t1, t2)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.difference(t1, t2)
|
function tbl.difference(t1, t2)
|
||||||
if type(t1) ~= "table" then error("table.difference: first argument must be a table", 2) end
|
if type(t1) ~= "table" then error("tbl.difference: first argument must be a table", 2) end
|
||||||
if type(t2) ~= "table" then error("table.difference: second argument must be a table", 2) end
|
if type(t2) ~= "table" then error("tbl.difference: second argument must be a table", 2) end
|
||||||
|
|
||||||
local set2 = {}
|
local set2 = {}
|
||||||
for _, v in pairs(t2) do
|
for _, v in pairs(t2) do
|
||||||
set2[v] = true
|
set2[v] = true
|
||||||
end
|
end
|
||||||
|
|
||||||
return table.filter(t1, function(v) return not set2[v] end)
|
return tbl.filter(t1, function(v) return not set2[v] end)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.reverse(t)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.reverse: argument must be a table", 2) end
|
-- ARRAY OPERATIONS
|
||||||
if not table.is_array(t) then error("table.reverse: argument must be an array", 2) end
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.reverse(t)
|
||||||
|
if type(t) ~= "table" then error("tbl.reverse: argument must be a table", 2) end
|
||||||
|
if not tbl.is_array(t) then error("tbl.reverse: argument must be an array", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
local len = #t
|
local len = #t
|
||||||
@ -449,11 +478,11 @@ function table.reverse(t)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.shuffle(t)
|
function tbl.shuffle(t)
|
||||||
if type(t) ~= "table" then error("table.shuffle: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.shuffle: argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.shuffle: argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.shuffle: argument must be an array", 2) end
|
||||||
|
|
||||||
local result = table.clone(t)
|
local result = tbl.clone(t)
|
||||||
local len = #result
|
local len = #result
|
||||||
|
|
||||||
math.randomseed(os.time() + os.clock() * 1000000)
|
math.randomseed(os.time() + os.clock() * 1000000)
|
||||||
@ -466,18 +495,18 @@ function table.shuffle(t)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.rotate(t, positions)
|
function tbl.rotate(t, positions)
|
||||||
if type(t) ~= "table" then error("table.rotate: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.rotate: first argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.rotate: first argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.rotate: first argument must be an array", 2) end
|
||||||
if type(positions) ~= "number" or positions ~= math.floor(positions) then
|
if type(positions) ~= "number" or positions ~= math.floor(positions) then
|
||||||
error("table.rotate: second argument must be an integer", 2)
|
error("tbl.rotate: second argument must be an integer", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local len = #t
|
local len = #t
|
||||||
if len == 0 then return {} end
|
if len == 0 then return {} end
|
||||||
|
|
||||||
positions = positions % len
|
positions = positions % len
|
||||||
if positions == 0 then return table.clone(t) end
|
if positions == 0 then return tbl.clone(t) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for i = 1, len do
|
for i = 1, len do
|
||||||
@ -488,14 +517,14 @@ function table.rotate(t, positions)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.slice(t, start_idx, end_idx)
|
function tbl.slice(t, start_idx, end_idx)
|
||||||
if type(t) ~= "table" then error("table.slice: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.slice: first argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.slice: first argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.slice: first argument must be an array", 2) end
|
||||||
if type(start_idx) ~= "number" or start_idx ~= math.floor(start_idx) then
|
if type(start_idx) ~= "number" or start_idx ~= math.floor(start_idx) then
|
||||||
error("table.slice: start index must be an integer", 2)
|
error("tbl.slice: start index must be an integer", 2)
|
||||||
end
|
end
|
||||||
if end_idx ~= nil and (type(end_idx) ~= "number" or end_idx ~= math.floor(end_idx)) then
|
if end_idx ~= nil and (type(end_idx) ~= "number" or end_idx ~= math.floor(end_idx)) then
|
||||||
error("table.slice: end index must be an integer", 2)
|
error("tbl.slice: end index must be an integer", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local len = #t
|
local len = #t
|
||||||
@ -507,20 +536,20 @@ function table.slice(t, start_idx, end_idx)
|
|||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for i = start_idx, end_idx do
|
for i = start_idx, end_idx do
|
||||||
orig_insert(result, t[i])
|
table.insert(result, t[i])
|
||||||
end
|
end
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.splice(t, start_idx, delete_count, ...)
|
function tbl.splice(t, start_idx, delete_count, ...)
|
||||||
if type(t) ~= "table" then error("table.splice: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.splice: first argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.splice: first argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.splice: first argument must be an array", 2) end
|
||||||
if type(start_idx) ~= "number" or start_idx ~= math.floor(start_idx) then
|
if type(start_idx) ~= "number" or start_idx ~= math.floor(start_idx) then
|
||||||
error("table.splice: start index must be an integer", 2)
|
error("tbl.splice: start index must be an integer", 2)
|
||||||
end
|
end
|
||||||
if delete_count ~= nil and (type(delete_count) ~= "number" or delete_count ~= math.floor(delete_count) or delete_count < 0) then
|
if delete_count ~= nil and (type(delete_count) ~= "number" or delete_count ~= math.floor(delete_count) or delete_count < 0) then
|
||||||
error("table.splice: delete count must be a non-negative integer", 2)
|
error("tbl.splice: delete count must be a non-negative integer", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local len = #t
|
local len = #t
|
||||||
@ -564,22 +593,26 @@ function table.splice(t, start_idx, delete_count, ...)
|
|||||||
return deleted
|
return deleted
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.sort_by(t, key_func)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.sort_by: first argument must be a table", 2) end
|
-- SORTING HELPERS
|
||||||
if not table.is_array(t) then error("table.sort_by: first argument must be an array", 2) end
|
-- ======================================================================
|
||||||
if type(key_func) ~= "function" then error("table.sort_by: second argument must be a function", 2) end
|
|
||||||
|
|
||||||
local result = table.clone(t)
|
function tbl.sort_by(t, key_func)
|
||||||
orig_sort(result, function(a, b)
|
if type(t) ~= "table" then error("tbl.sort_by: first argument must be a table", 2) end
|
||||||
|
if not tbl.is_array(t) then error("tbl.sort_by: first argument must be an array", 2) end
|
||||||
|
if type(key_func) ~= "function" then error("tbl.sort_by: second argument must be a function", 2) end
|
||||||
|
|
||||||
|
local result = tbl.clone(t)
|
||||||
|
table.sort(result, function(a, b)
|
||||||
return key_func(a) < key_func(b)
|
return key_func(a) < key_func(b)
|
||||||
end)
|
end)
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.is_sorted(t, comp)
|
function tbl.is_sorted(t, comp)
|
||||||
if type(t) ~= "table" then error("table.is_sorted: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.is_sorted: first argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.is_sorted: first argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.is_sorted: first argument must be an array", 2) end
|
||||||
if comp ~= nil and type(comp) ~= "function" then error("table.is_sorted: comparator must be a function", 2) end
|
if comp ~= nil and type(comp) ~= "function" then error("tbl.is_sorted: comparator must be a function", 2) end
|
||||||
|
|
||||||
comp = comp or function(a, b) return a < b end
|
comp = comp or function(a, b) return a < b end
|
||||||
|
|
||||||
@ -591,43 +624,47 @@ function table.is_sorted(t, comp)
|
|||||||
return true
|
return true
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.keys(t)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.keys: argument must be a table", 2) end
|
-- UTILITY FUNCTIONS
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.keys(t)
|
||||||
|
if type(t) ~= "table" then error("tbl.keys: argument must be a table", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, _ in pairs(t) do
|
for k, _ in pairs(t) do
|
||||||
orig_insert(result, k)
|
table.insert(result, k)
|
||||||
end
|
end
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.values(t)
|
function tbl.values(t)
|
||||||
if type(t) ~= "table" then error("table.values: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.values: argument must be a table", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for _, v in pairs(t) do
|
for _, v in pairs(t) do
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.pairs(t)
|
function tbl.pairs(t)
|
||||||
if type(t) ~= "table" then error("table.pairs: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.pairs: argument must be a table", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
orig_insert(result, {k, v})
|
table.insert(result, {k, v})
|
||||||
end
|
end
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.merge(...)
|
function tbl.merge(...)
|
||||||
local tables = {...}
|
local tables = {...}
|
||||||
if #tables == 0 then return {} end
|
if #tables == 0 then return {} end
|
||||||
|
|
||||||
for i, t in ipairs(tables) do
|
for i, t in ipairs(tables) do
|
||||||
if type(t) ~= "table" then
|
if type(t) ~= "table" then
|
||||||
error("table.merge: argument " .. i .. " must be a table", 2)
|
error("tbl.merge: argument " .. i .. " must be a table", 2)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -640,13 +677,13 @@ function table.merge(...)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.extend(t1, ...)
|
function tbl.extend(t1, ...)
|
||||||
if type(t1) ~= "table" then error("table.extend: first argument must be a table", 2) end
|
if type(t1) ~= "table" then error("tbl.extend: first argument must be a table", 2) end
|
||||||
|
|
||||||
local tables = {...}
|
local tables = {...}
|
||||||
for i, t in ipairs(tables) do
|
for i, t in ipairs(tables) do
|
||||||
if type(t) ~= "table" then
|
if type(t) ~= "table" then
|
||||||
error("table.extend: argument " .. (i + 1) .. " must be a table", 2)
|
error("tbl.extend: argument " .. (i + 1) .. " must be a table", 2)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -658,8 +695,8 @@ function table.extend(t1, ...)
|
|||||||
return t1
|
return t1
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.invert(t)
|
function tbl.invert(t)
|
||||||
if type(t) ~= "table" then error("table.invert: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.invert: argument must be a table", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
@ -668,8 +705,8 @@ function table.invert(t)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.pick(t, ...)
|
function tbl.pick(t, ...)
|
||||||
if type(t) ~= "table" then error("table.pick: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.pick: first argument must be a table", 2) end
|
||||||
|
|
||||||
local keys = {...}
|
local keys = {...}
|
||||||
local result = {}
|
local result = {}
|
||||||
@ -683,8 +720,8 @@ function table.pick(t, ...)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.omit(t, ...)
|
function tbl.omit(t, ...)
|
||||||
if type(t) ~= "table" then error("table.omit: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.omit: first argument must be a table", 2) end
|
||||||
|
|
||||||
local omit_keys = {}
|
local omit_keys = {}
|
||||||
for _, key in ipairs({...}) do
|
for _, key in ipairs({...}) do
|
||||||
@ -701,9 +738,13 @@ function table.omit(t, ...)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.deep_equals(t1, t2)
|
-- ======================================================================
|
||||||
if type(t1) ~= "table" then error("table.deep_equals: first argument must be a table", 2) end
|
-- DEEP OPERATIONS
|
||||||
if type(t2) ~= "table" then error("table.deep_equals: second argument must be a table", 2) end
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.deep_equals(t1, t2)
|
||||||
|
if type(t1) ~= "table" then error("tbl.deep_equals: first argument must be a table", 2) end
|
||||||
|
if type(t2) ~= "table" then error("tbl.deep_equals: second argument must be a table", 2) end
|
||||||
|
|
||||||
local function equals_recursive(a, b, seen)
|
local function equals_recursive(a, b, seen)
|
||||||
if a == b then return true end
|
if a == b then return true end
|
||||||
@ -739,11 +780,11 @@ function table.deep_equals(t1, t2)
|
|||||||
return equals_recursive(t1, t2, {})
|
return equals_recursive(t1, t2, {})
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.flatten(t, depth)
|
function tbl.flatten(t, depth)
|
||||||
if type(t) ~= "table" then error("table.flatten: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.flatten: first argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.flatten: first argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.flatten: first argument must be an array", 2) end
|
||||||
if depth ~= nil and (type(depth) ~= "number" or depth ~= math.floor(depth) or depth < 1) then
|
if depth ~= nil and (type(depth) ~= "number" or depth ~= math.floor(depth) or depth < 1) then
|
||||||
error("table.flatten: depth must be a positive integer", 2)
|
error("tbl.flatten: depth must be a positive integer", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
depth = depth or 1
|
depth = depth or 1
|
||||||
@ -751,13 +792,13 @@ function table.flatten(t, depth)
|
|||||||
local function flatten_recursive(arr, current_depth)
|
local function flatten_recursive(arr, current_depth)
|
||||||
local result = {}
|
local result = {}
|
||||||
for _, v in ipairs(arr) do
|
for _, v in ipairs(arr) do
|
||||||
if type(v) == "table" and table.is_array(v) and current_depth > 0 then
|
if type(v) == "table" and tbl.is_array(v) and current_depth > 0 then
|
||||||
local flattened = flatten_recursive(v, current_depth - 1)
|
local flattened = flatten_recursive(v, current_depth - 1)
|
||||||
for _, item in ipairs(flattened) do
|
for _, item in ipairs(flattened) do
|
||||||
orig_insert(result, item)
|
table.insert(result, item)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return result
|
return result
|
||||||
@ -766,13 +807,13 @@ function table.flatten(t, depth)
|
|||||||
return flatten_recursive(t, depth)
|
return flatten_recursive(t, depth)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.deep_merge(...)
|
function tbl.deep_merge(...)
|
||||||
local tables = {...}
|
local tables = {...}
|
||||||
if #tables == 0 then return {} end
|
if #tables == 0 then return {} end
|
||||||
|
|
||||||
for i, t in ipairs(tables) do
|
for i, t in ipairs(tables) do
|
||||||
if type(t) ~= "table" then
|
if type(t) ~= "table" then
|
||||||
error("table.deep_merge: argument " .. i .. " must be a table", 2)
|
error("tbl.deep_merge: argument " .. i .. " must be a table", 2)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -781,13 +822,13 @@ function table.deep_merge(...)
|
|||||||
if type(v) == "table" and type(target[k]) == "table" then
|
if type(v) == "table" and type(target[k]) == "table" then
|
||||||
target[k] = merge_recursive(target[k], v)
|
target[k] = merge_recursive(target[k], v)
|
||||||
else
|
else
|
||||||
target[k] = type(v) == "table" and table.deep_copy(v) or v
|
target[k] = type(v) == "table" and tbl.deep_copy(v) or v
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return target
|
return target
|
||||||
end
|
end
|
||||||
|
|
||||||
local result = table.deep_copy(tables[1])
|
local result = tbl.deep_copy(tables[1])
|
||||||
for i = 2, #tables do
|
for i = 2, #tables do
|
||||||
result = merge_recursive(result, tables[i])
|
result = merge_recursive(result, tables[i])
|
||||||
end
|
end
|
||||||
@ -795,11 +836,15 @@ function table.deep_merge(...)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.chunk(t, size)
|
-- ======================================================================
|
||||||
if type(t) ~= "table" then error("table.chunk: first argument must be a table", 2) end
|
-- ADVANCED OPERATIONS
|
||||||
if not table.is_array(t) then error("table.chunk: first argument must be an array", 2) end
|
-- ======================================================================
|
||||||
|
|
||||||
|
function tbl.chunk(t, size)
|
||||||
|
if type(t) ~= "table" then error("tbl.chunk: first argument must be a table", 2) end
|
||||||
|
if not tbl.is_array(t) then error("tbl.chunk: first argument must be an array", 2) end
|
||||||
if type(size) ~= "number" or size ~= math.floor(size) or size <= 0 then
|
if type(size) ~= "number" or size ~= math.floor(size) or size <= 0 then
|
||||||
error("table.chunk: size must be a positive integer", 2)
|
error("tbl.chunk: size must be a positive integer", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
@ -808,26 +853,26 @@ function table.chunk(t, size)
|
|||||||
for i = 1, len, size do
|
for i = 1, len, size do
|
||||||
local chunk = {}
|
local chunk = {}
|
||||||
for j = i, math.min(i + size - 1, len) do
|
for j = i, math.min(i + size - 1, len) do
|
||||||
orig_insert(chunk, t[j])
|
table.insert(chunk, t[j])
|
||||||
end
|
end
|
||||||
orig_insert(result, chunk)
|
table.insert(result, chunk)
|
||||||
end
|
end
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.partition(t, predicate)
|
function tbl.partition(t, predicate)
|
||||||
if type(t) ~= "table" then error("table.partition: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.partition: first argument must be a table", 2) end
|
||||||
if type(predicate) ~= "function" then error("table.partition: second argument must be a function", 2) end
|
if type(predicate) ~= "function" then error("tbl.partition: second argument must be a function", 2) end
|
||||||
|
|
||||||
local truthy, falsy = {}, {}
|
local truthy, falsy = {}, {}
|
||||||
|
|
||||||
if table.is_array(t) then
|
if tbl.is_array(t) then
|
||||||
for i, v in ipairs(t) do
|
for i, v in ipairs(t) do
|
||||||
if predicate(v, i, t) then
|
if predicate(v, i, t) then
|
||||||
orig_insert(truthy, v)
|
table.insert(truthy, v)
|
||||||
else
|
else
|
||||||
orig_insert(falsy, v)
|
table.insert(falsy, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
@ -843,9 +888,9 @@ function table.partition(t, predicate)
|
|||||||
return truthy, falsy
|
return truthy, falsy
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.group_by(t, key_func)
|
function tbl.group_by(t, key_func)
|
||||||
if type(t) ~= "table" then error("table.group_by: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.group_by: first argument must be a table", 2) end
|
||||||
if type(key_func) ~= "function" then error("table.group_by: second argument must be a function", 2) end
|
if type(key_func) ~= "function" then error("tbl.group_by: second argument must be a function", 2) end
|
||||||
|
|
||||||
local result = {}
|
local result = {}
|
||||||
|
|
||||||
@ -855,8 +900,8 @@ function table.group_by(t, key_func)
|
|||||||
result[group_key] = {}
|
result[group_key] = {}
|
||||||
end
|
end
|
||||||
|
|
||||||
if table.is_array(t) then
|
if tbl.is_array(t) then
|
||||||
orig_insert(result[group_key], v)
|
table.insert(result[group_key], v)
|
||||||
else
|
else
|
||||||
result[group_key][k] = v
|
result[group_key][k] = v
|
||||||
end
|
end
|
||||||
@ -865,16 +910,16 @@ function table.group_by(t, key_func)
|
|||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.zip(...)
|
function tbl.zip(...)
|
||||||
local arrays = {...}
|
local arrays = {...}
|
||||||
if #arrays == 0 then error("table.zip: at least one argument required", 2) end
|
if #arrays == 0 then error("tbl.zip: at least one argument required", 2) end
|
||||||
|
|
||||||
for i, arr in ipairs(arrays) do
|
for i, arr in ipairs(arrays) do
|
||||||
if type(arr) ~= "table" then
|
if type(arr) ~= "table" then
|
||||||
error("table.zip: argument " .. i .. " must be a table", 2)
|
error("tbl.zip: argument " .. i .. " must be a table", 2)
|
||||||
end
|
end
|
||||||
if not table.is_array(arr) then
|
if not tbl.is_array(arr) then
|
||||||
error("table.zip: argument " .. i .. " must be an array", 2)
|
error("tbl.zip: argument " .. i .. " must be an array", 2)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -887,16 +932,16 @@ function table.zip(...)
|
|||||||
for i = 1, min_length do
|
for i = 1, min_length do
|
||||||
local tuple = {}
|
local tuple = {}
|
||||||
for j = 1, #arrays do
|
for j = 1, #arrays do
|
||||||
orig_insert(tuple, arrays[j][i])
|
table.insert(tuple, arrays[j][i])
|
||||||
end
|
end
|
||||||
orig_insert(result, tuple)
|
table.insert(result, tuple)
|
||||||
end
|
end
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.compact(t)
|
function tbl.compact(t)
|
||||||
if type(t) ~= "table" then error("table.compact: argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.compact: argument must be a table", 2) end
|
||||||
|
|
||||||
-- Check if table has only integer keys (array-like)
|
-- Check if table has only integer keys (array-like)
|
||||||
local has_only_int_keys = true
|
local has_only_int_keys = true
|
||||||
@ -915,34 +960,34 @@ function table.compact(t)
|
|||||||
for i = 1, max_key do
|
for i = 1, max_key do
|
||||||
local v = t[i]
|
local v = t[i]
|
||||||
if v ~= nil and v ~= false then
|
if v ~= nil and v ~= false then
|
||||||
orig_insert(result, v)
|
table.insert(result, v)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return result
|
return result
|
||||||
else
|
else
|
||||||
-- Regular table filtering
|
-- Regular table filtering
|
||||||
return table.filter(t, function(v) return v ~= nil and v ~= false end)
|
return tbl.filter(t, function(v) return v ~= nil and v ~= false end)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.sample(t, n)
|
function tbl.sample(t, n)
|
||||||
if type(t) ~= "table" then error("table.sample: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.sample: first argument must be a table", 2) end
|
||||||
if not table.is_array(t) then error("table.sample: first argument must be an array", 2) end
|
if not tbl.is_array(t) then error("tbl.sample: first argument must be an array", 2) end
|
||||||
if n ~= nil and (type(n) ~= "number" or n ~= math.floor(n) or n < 0) then
|
if n ~= nil and (type(n) ~= "number" or n ~= math.floor(n) or n < 0) then
|
||||||
error("table.sample: sample size must be a non-negative integer", 2)
|
error("tbl.sample: sample size must be a non-negative integer", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
n = n or 1
|
n = n or 1
|
||||||
local len = #t
|
local len = #t
|
||||||
if n >= len then return table.clone(t) end
|
if n >= len then return tbl.clone(t) end
|
||||||
|
|
||||||
local shuffled = table.shuffle(t)
|
local shuffled = tbl.shuffle(t)
|
||||||
return table.slice(shuffled, 1, n)
|
return tbl.slice(shuffled, 1, n)
|
||||||
end
|
end
|
||||||
|
|
||||||
function table.fold(t, folder, initial)
|
function tbl.fold(t, folder, initial)
|
||||||
if type(t) ~= "table" then error("table.fold: first argument must be a table", 2) end
|
if type(t) ~= "table" then error("tbl.fold: first argument must be a table", 2) end
|
||||||
if type(folder) ~= "function" then error("table.fold: second argument must be a function", 2) end
|
if type(folder) ~= "function" then error("tbl.fold: second argument must be a function", 2) end
|
||||||
|
|
||||||
local accumulator = initial
|
local accumulator = initial
|
||||||
for k, v in pairs(t) do
|
for k, v in pairs(t) do
|
||||||
@ -950,3 +995,5 @@ function table.fold(t, folder, initial)
|
|||||||
end
|
end
|
||||||
return accumulator
|
return accumulator
|
||||||
end
|
end
|
||||||
|
|
||||||
|
return tbl
|
||||||
|
|||||||
362
tests/kv.lua
362
tests/kv.lua
@ -1,362 +0,0 @@
|
|||||||
require("tests")
|
|
||||||
local kv = require("kv")
|
|
||||||
|
|
||||||
-- Clean up any existing test files
|
|
||||||
os.remove("test_store.json")
|
|
||||||
os.remove("test_store.txt")
|
|
||||||
os.remove("test_oop.json")
|
|
||||||
os.remove("test_temp.json")
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- BASIC OPERATIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Store creation and opening", function()
|
|
||||||
assert(kv.open("test", "test_store.json"))
|
|
||||||
assert(kv.open("memory_only", ""))
|
|
||||||
assert(not kv.open("", "test.json")) -- Empty name should fail
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Set and get operations", function()
|
|
||||||
assert(kv.set("test", "key1", "value1"))
|
|
||||||
assert_equal("value1", kv.get("test", "key1"))
|
|
||||||
assert_equal("default", kv.get("test", "nonexistent", "default"))
|
|
||||||
assert_equal(nil, kv.get("test", "nonexistent"))
|
|
||||||
|
|
||||||
-- Test with special characters
|
|
||||||
assert(kv.set("test", "special:key", "value with spaces & symbols!"))
|
|
||||||
assert_equal("value with spaces & symbols!", kv.get("test", "special:key"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Key existence and deletion", function()
|
|
||||||
kv.set("test", "temp_key", "temp_value")
|
|
||||||
assert(kv.has("test", "temp_key"))
|
|
||||||
assert(not kv.has("test", "missing_key"))
|
|
||||||
|
|
||||||
assert(kv.delete("test", "temp_key"))
|
|
||||||
assert(not kv.has("test", "temp_key"))
|
|
||||||
assert(not kv.delete("test", "missing_key"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Store size tracking", function()
|
|
||||||
kv.clear("test")
|
|
||||||
assert_equal(0, kv.size("test"))
|
|
||||||
|
|
||||||
kv.set("test", "k1", "v1")
|
|
||||||
kv.set("test", "k2", "v2")
|
|
||||||
kv.set("test", "k3", "v3")
|
|
||||||
assert_equal(3, kv.size("test"))
|
|
||||||
|
|
||||||
kv.delete("test", "k2")
|
|
||||||
assert_equal(2, kv.size("test"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Keys and values retrieval", function()
|
|
||||||
kv.clear("test")
|
|
||||||
kv.set("test", "a", "apple")
|
|
||||||
kv.set("test", "b", "banana")
|
|
||||||
kv.set("test", "c", "cherry")
|
|
||||||
|
|
||||||
local keys = kv.keys("test")
|
|
||||||
local values = kv.values("test")
|
|
||||||
|
|
||||||
assert_equal(3, #keys)
|
|
||||||
assert_equal(3, #values)
|
|
||||||
|
|
||||||
-- Check all keys exist
|
|
||||||
local key_set = {}
|
|
||||||
for _, k in ipairs(keys) do
|
|
||||||
key_set[k] = true
|
|
||||||
end
|
|
||||||
assert(key_set["a"] and key_set["b"] and key_set["c"])
|
|
||||||
|
|
||||||
-- Check all values exist
|
|
||||||
local value_set = {}
|
|
||||||
for _, v in ipairs(values) do
|
|
||||||
value_set[v] = true
|
|
||||||
end
|
|
||||||
assert(value_set["apple"] and value_set["banana"] and value_set["cherry"])
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Clear store", function()
|
|
||||||
kv.set("test", "temp1", "value1")
|
|
||||||
kv.set("test", "temp2", "value2")
|
|
||||||
assert(kv.size("test") > 0)
|
|
||||||
|
|
||||||
assert(kv.clear("test"))
|
|
||||||
assert_equal(0, kv.size("test"))
|
|
||||||
assert(not kv.has("test", "temp1"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Save and close operations", function()
|
|
||||||
-- Ensure store is properly opened with filename
|
|
||||||
kv.close("test") -- Close if already open
|
|
||||||
assert(kv.open("test", "test_store.json"))
|
|
||||||
|
|
||||||
assert(kv.set("test", "persistent", "data"))
|
|
||||||
assert(kv.save("test"))
|
|
||||||
assert(kv.close("test"))
|
|
||||||
|
|
||||||
-- Reopen and verify data persists
|
|
||||||
assert(kv.open("test", "test_store.json"))
|
|
||||||
assert_equal("data", kv.get("test", "persistent"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Invalid store operations", function()
|
|
||||||
assert(not kv.set("nonexistent", "key", "value"))
|
|
||||||
assert_equal(nil, kv.get("nonexistent", "key"))
|
|
||||||
assert(not kv.has("nonexistent", "key"))
|
|
||||||
assert_equal(0, kv.size("nonexistent"))
|
|
||||||
assert(not kv.delete("nonexistent", "key"))
|
|
||||||
assert(not kv.clear("nonexistent"))
|
|
||||||
assert(not kv.save("nonexistent"))
|
|
||||||
assert(not kv.close("nonexistent"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- OBJECT-ORIENTED INTERFACE
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("OOP store creation", function()
|
|
||||||
local store = kv.create("oop_test", "test_oop.json")
|
|
||||||
assert_equal("oop_test", store.name)
|
|
||||||
|
|
||||||
local memory_store = kv.create("memory_oop")
|
|
||||||
assert_equal("memory_oop", memory_store.name)
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("OOP basic operations", function()
|
|
||||||
local store = kv.create("oop_basic")
|
|
||||||
|
|
||||||
assert(store:set("foo", "bar"))
|
|
||||||
assert_equal("bar", store:get("foo"))
|
|
||||||
assert_equal("default", store:get("missing", "default"))
|
|
||||||
assert(store:has("foo"))
|
|
||||||
assert(not store:has("missing"))
|
|
||||||
assert_equal(1, store:size())
|
|
||||||
|
|
||||||
assert(store:delete("foo"))
|
|
||||||
assert(not store:has("foo"))
|
|
||||||
assert_equal(0, store:size())
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("OOP collections", function()
|
|
||||||
local store = kv.create("oop_collections")
|
|
||||||
|
|
||||||
store:set("a", "apple")
|
|
||||||
store:set("b", "banana")
|
|
||||||
store:set("c", "cherry")
|
|
||||||
|
|
||||||
local keys = store:keys()
|
|
||||||
local values = store:values()
|
|
||||||
|
|
||||||
assert_equal(3, #keys)
|
|
||||||
assert_equal(3, #values)
|
|
||||||
|
|
||||||
store:clear()
|
|
||||||
assert_equal(0, store:size())
|
|
||||||
|
|
||||||
store:close()
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- UTILITY FUNCTIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Increment operations", function()
|
|
||||||
kv.open("util_test")
|
|
||||||
|
|
||||||
-- Increment non-existent key
|
|
||||||
assert_equal(1, kv.increment("util_test", "counter"))
|
|
||||||
assert_equal("1", kv.get("util_test", "counter"))
|
|
||||||
|
|
||||||
-- Increment existing key
|
|
||||||
assert_equal(6, kv.increment("util_test", "counter", 5))
|
|
||||||
assert_equal("6", kv.get("util_test", "counter"))
|
|
||||||
|
|
||||||
-- Decrement
|
|
||||||
assert_equal(4, kv.increment("util_test", "counter", -2))
|
|
||||||
|
|
||||||
kv.close("util_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Append operations", function()
|
|
||||||
kv.open("append_test")
|
|
||||||
|
|
||||||
-- Append to non-existent key
|
|
||||||
assert(kv.append("append_test", "list", "first"))
|
|
||||||
assert_equal("first", kv.get("append_test", "list"))
|
|
||||||
|
|
||||||
-- Append with separator
|
|
||||||
assert(kv.append("append_test", "list", "second", ","))
|
|
||||||
assert_equal("first,second", kv.get("append_test", "list"))
|
|
||||||
|
|
||||||
assert(kv.append("append_test", "list", "third", ","))
|
|
||||||
assert_equal("first,second,third", kv.get("append_test", "list"))
|
|
||||||
|
|
||||||
kv.close("append_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("TTL and expiration", function()
|
|
||||||
kv.open("ttl_test", "test_temp.json")
|
|
||||||
|
|
||||||
kv.set("ttl_test", "temp_key", "temp_value")
|
|
||||||
assert(kv.expire("ttl_test", "temp_key", 1)) -- 1 second TTL
|
|
||||||
|
|
||||||
-- Key should still exist immediately
|
|
||||||
assert(kv.has("ttl_test", "temp_key"))
|
|
||||||
|
|
||||||
-- Wait for expiration
|
|
||||||
os.execute("sleep 2")
|
|
||||||
local expired = kv.cleanup_expired("ttl_test")
|
|
||||||
assert(expired >= 1)
|
|
||||||
|
|
||||||
-- Key should be gone
|
|
||||||
assert(not kv.has("ttl_test", "temp_key"))
|
|
||||||
|
|
||||||
kv.close("ttl_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- FILE FORMAT TESTS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("JSON file format", function()
|
|
||||||
kv.open("json_test", "test.json")
|
|
||||||
kv.set("json_test", "key1", "value1")
|
|
||||||
kv.set("json_test", "key2", "value2")
|
|
||||||
kv.save("json_test")
|
|
||||||
kv.close("json_test")
|
|
||||||
|
|
||||||
-- Verify file exists and reload
|
|
||||||
assert(file_exists("test.json"))
|
|
||||||
kv.open("json_test", "test.json")
|
|
||||||
assert_equal("value1", kv.get("json_test", "key1"))
|
|
||||||
assert_equal("value2", kv.get("json_test", "key2"))
|
|
||||||
kv.close("json_test")
|
|
||||||
|
|
||||||
os.remove("test.json")
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Text file format", function()
|
|
||||||
kv.open("txt_test", "test.txt")
|
|
||||||
kv.set("txt_test", "setting1", "value1")
|
|
||||||
kv.set("txt_test", "setting2", "value2")
|
|
||||||
kv.save("txt_test")
|
|
||||||
kv.close("txt_test")
|
|
||||||
|
|
||||||
-- Verify file exists and reload
|
|
||||||
assert(file_exists("test.txt"))
|
|
||||||
kv.open("txt_test", "test.txt")
|
|
||||||
assert_equal("value1", kv.get("txt_test", "setting1"))
|
|
||||||
assert_equal("value2", kv.get("txt_test", "setting2"))
|
|
||||||
kv.close("txt_test")
|
|
||||||
|
|
||||||
os.remove("test.txt")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- EDGE CASES AND ERROR HANDLING
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Empty values and keys", function()
|
|
||||||
kv.open("edge_test")
|
|
||||||
|
|
||||||
-- Empty value
|
|
||||||
assert(kv.set("edge_test", "empty", ""))
|
|
||||||
assert_equal("", kv.get("edge_test", "empty"))
|
|
||||||
|
|
||||||
-- Unicode keys and values
|
|
||||||
assert(kv.set("edge_test", "ключ", "значение"))
|
|
||||||
assert_equal("значение", kv.get("edge_test", "ключ"))
|
|
||||||
|
|
||||||
kv.close("edge_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Special characters in data", function()
|
|
||||||
kv.open("special_test")
|
|
||||||
|
|
||||||
local special_value = 'Special chars: "quotes", \'apostrophes\', \n newlines, \t tabs, \\ backslashes'
|
|
||||||
assert(kv.set("special_test", "special", special_value))
|
|
||||||
assert_equal(special_value, kv.get("special_test", "special"))
|
|
||||||
|
|
||||||
kv.close("special_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Large data handling", function()
|
|
||||||
kv.open("large_test")
|
|
||||||
|
|
||||||
-- Large value
|
|
||||||
local large_value = string.rep("x", 10000)
|
|
||||||
assert(kv.set("large_test", "large", large_value))
|
|
||||||
assert_equal(large_value, kv.get("large_test", "large"))
|
|
||||||
|
|
||||||
-- Many keys
|
|
||||||
for i = 1, 100 do
|
|
||||||
kv.set("large_test", "key" .. i, "value" .. i)
|
|
||||||
end
|
|
||||||
assert_equal(101, kv.size("large_test")) -- 100 + 1 large value
|
|
||||||
|
|
||||||
kv.close("large_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- PERFORMANCE TESTS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Performance test", function()
|
|
||||||
kv.open("perf_test")
|
|
||||||
|
|
||||||
local start = os.clock()
|
|
||||||
|
|
||||||
-- Bulk insert
|
|
||||||
for i = 1, 1000 do
|
|
||||||
kv.set("perf_test", "key" .. i, "value" .. i)
|
|
||||||
end
|
|
||||||
local insert_time = os.clock() - start
|
|
||||||
|
|
||||||
-- Bulk read
|
|
||||||
start = os.clock()
|
|
||||||
for i = 1, 1000 do
|
|
||||||
local value = kv.get("perf_test", "key" .. i)
|
|
||||||
assert_equal("value" .. i, value)
|
|
||||||
end
|
|
||||||
local read_time = os.clock() - start
|
|
||||||
|
|
||||||
-- Check final size
|
|
||||||
assert_equal(1000, kv.size("perf_test"))
|
|
||||||
|
|
||||||
print(string.format(" Insert 1000 items: %.3fs", insert_time))
|
|
||||||
print(string.format(" Read 1000 items: %.3fs", read_time))
|
|
||||||
|
|
||||||
kv.close("perf_test")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- INTEGRATION TESTS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Multiple store integration", function()
|
|
||||||
local users = kv.create("users_int")
|
|
||||||
local cache = kv.create("cache_int")
|
|
||||||
|
|
||||||
-- Simulate user data
|
|
||||||
users:set("user:123", "john_doe")
|
|
||||||
cache:set("user:123:last_seen", tostring(os.time()))
|
|
||||||
|
|
||||||
-- Verify data in stores
|
|
||||||
assert_equal("john_doe", users:get("user:123"))
|
|
||||||
assert(cache:has("user:123:last_seen"))
|
|
||||||
|
|
||||||
-- Clean up
|
|
||||||
users:close()
|
|
||||||
cache:close()
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- Clean up test files
|
|
||||||
os.remove("test_store.json")
|
|
||||||
os.remove("test_oop.json")
|
|
||||||
os.remove("test_temp.json")
|
|
||||||
|
|
||||||
summary()
|
|
||||||
test_exit()
|
|
||||||
@ -1,397 +0,0 @@
|
|||||||
require("tests")
|
|
||||||
local sessions = require("sessions")
|
|
||||||
|
|
||||||
-- Clean up test files
|
|
||||||
os.remove("test_sessions.json")
|
|
||||||
os.remove("test_sessions2.json")
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- SESSION STORE INITIALIZATION
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session store initialization", function()
|
|
||||||
assert(sessions.init("test_sessions", "test_sessions.json"))
|
|
||||||
assert(sessions.init("memory_sessions"))
|
|
||||||
|
|
||||||
-- Test with explicit store name
|
|
||||||
assert(sessions.init("named_sessions", "test_sessions2.json"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Session ID generation", function()
|
|
||||||
local id1 = sessions.generate_id()
|
|
||||||
local id2 = sessions.generate_id()
|
|
||||||
|
|
||||||
assert_equal(32, #id1)
|
|
||||||
assert_equal(32, #id2)
|
|
||||||
assert(id1 ~= id2, "Session IDs should be unique")
|
|
||||||
|
|
||||||
-- Check character set (alphanumeric)
|
|
||||||
assert(id1:match("^[a-zA-Z0-9]+$"), "Session ID should be alphanumeric")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- BASIC SESSION OPERATIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session creation and retrieval", function()
|
|
||||||
sessions.init("basic_test")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
local session_data = {
|
|
||||||
user_id = 123,
|
|
||||||
username = "testuser",
|
|
||||||
role = "admin",
|
|
||||||
permissions = {"read", "write"}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sessions.create(session_id, session_data))
|
|
||||||
|
|
||||||
local retrieved = sessions.get(session_id)
|
|
||||||
assert_equal(123, retrieved.user_id)
|
|
||||||
assert_equal("testuser", retrieved.username)
|
|
||||||
assert_equal("admin", retrieved.role)
|
|
||||||
assert_table_equal({"read", "write"}, retrieved.permissions)
|
|
||||||
assert(retrieved._created ~= nil)
|
|
||||||
assert(retrieved._last_accessed ~= nil)
|
|
||||||
|
|
||||||
-- Test non-existent session
|
|
||||||
assert_equal(nil, sessions.get("nonexistent"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Session updates", function()
|
|
||||||
sessions.init("update_test")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
sessions.create(session_id, {count = 1})
|
|
||||||
|
|
||||||
local session = sessions.get(session_id)
|
|
||||||
local new_data = {
|
|
||||||
count = 2,
|
|
||||||
new_field = "added"
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sessions.update(session_id, new_data))
|
|
||||||
|
|
||||||
local updated = sessions.get(session_id)
|
|
||||||
assert_equal(2, updated.count)
|
|
||||||
assert_equal("added", updated.new_field)
|
|
||||||
assert(updated._created ~= nil)
|
|
||||||
assert(updated._last_accessed ~= nil)
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Session deletion", function()
|
|
||||||
sessions.init("delete_test")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
sessions.create(session_id, {temp = true})
|
|
||||||
|
|
||||||
assert(sessions.get(session_id) ~= nil)
|
|
||||||
assert(sessions.delete(session_id))
|
|
||||||
assert_equal(nil, sessions.get(session_id))
|
|
||||||
|
|
||||||
-- Delete non-existent session
|
|
||||||
assert(not sessions.delete("nonexistent"))
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Session existence check", function()
|
|
||||||
sessions.init("exists_test")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
assert(not sessions.exists(session_id))
|
|
||||||
|
|
||||||
sessions.create(session_id, {test = true})
|
|
||||||
assert(sessions.exists(session_id))
|
|
||||||
|
|
||||||
sessions.delete(session_id)
|
|
||||||
assert(not sessions.exists(session_id))
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- MULTI-STORE OPERATIONS
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Multiple session stores", function()
|
|
||||||
assert(sessions.init("store1", "store1.json"))
|
|
||||||
assert(sessions.init("store2", "store2.json"))
|
|
||||||
|
|
||||||
local id1 = sessions.generate_id()
|
|
||||||
local id2 = sessions.generate_id()
|
|
||||||
|
|
||||||
-- Create sessions in different stores
|
|
||||||
assert(sessions.create(id1, {store = "store1"}, "store1"))
|
|
||||||
assert(sessions.create(id2, {store = "store2"}, "store2"))
|
|
||||||
|
|
||||||
-- Verify isolation
|
|
||||||
assert_equal(nil, sessions.get(id1, "store2"))
|
|
||||||
assert_equal(nil, sessions.get(id2, "store1"))
|
|
||||||
|
|
||||||
-- Verify correct retrieval
|
|
||||||
local s1 = sessions.get(id1, "store1")
|
|
||||||
local s2 = sessions.get(id2, "store2")
|
|
||||||
assert_equal("store1", s1.store)
|
|
||||||
assert_equal("store2", s2.store)
|
|
||||||
|
|
||||||
os.remove("store1.json")
|
|
||||||
os.remove("store2.json")
|
|
||||||
end)
|
|
||||||
|
|
||||||
test("Default store behavior", function()
|
|
||||||
sessions.reset()
|
|
||||||
sessions.init("default_store")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
sessions.create(session_id, {test = "default"})
|
|
||||||
|
|
||||||
-- Should work without specifying store
|
|
||||||
local retrieved = sessions.get(session_id)
|
|
||||||
assert_equal("default", retrieved.test)
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- SESSION CLEANUP
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session cleanup", function()
|
|
||||||
sessions.init("cleanup_test")
|
|
||||||
|
|
||||||
local old_session = sessions.generate_id()
|
|
||||||
local new_session = sessions.generate_id()
|
|
||||||
|
|
||||||
sessions.create(old_session, {test = "old"}, "cleanup_test")
|
|
||||||
sessions.create(new_session, {test = "new"}, "cleanup_test")
|
|
||||||
|
|
||||||
-- Wait to create age difference
|
|
||||||
os.execute("sleep 2")
|
|
||||||
|
|
||||||
-- Access new session to update timestamp
|
|
||||||
sessions.get(new_session, "cleanup_test")
|
|
||||||
|
|
||||||
-- Clean up sessions older than 1 second
|
|
||||||
local deleted = sessions.cleanup(1, "cleanup_test")
|
|
||||||
assert(deleted >= 1)
|
|
||||||
|
|
||||||
-- New session should remain (recently accessed)
|
|
||||||
assert(sessions.get(new_session, "cleanup_test") ~= nil)
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- SESSION LISTING AND COUNTING
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session listing and counting", function()
|
|
||||||
sessions.init("list_test")
|
|
||||||
|
|
||||||
local ids = {}
|
|
||||||
for i = 1, 5 do
|
|
||||||
local id = sessions.generate_id()
|
|
||||||
sessions.create(id, {index = i}, "list_test")
|
|
||||||
table.insert(ids, id)
|
|
||||||
end
|
|
||||||
|
|
||||||
local session_list = sessions.list("list_test")
|
|
||||||
assert_equal(5, #session_list)
|
|
||||||
|
|
||||||
local count = sessions.count("list_test")
|
|
||||||
assert_equal(5, count)
|
|
||||||
|
|
||||||
-- Verify all IDs are in the list
|
|
||||||
local id_set = {}
|
|
||||||
for _, id in ipairs(session_list) do
|
|
||||||
id_set[id] = true
|
|
||||||
end
|
|
||||||
|
|
||||||
for _, id in ipairs(ids) do
|
|
||||||
assert(id_set[id], "Session ID should be in list")
|
|
||||||
end
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- OBJECT-ORIENTED INTERFACE
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("OOP session store", function()
|
|
||||||
local store = sessions.create_store("oop_sessions", "oop_test.json")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
local data = {user = "oop_test", role = "admin"}
|
|
||||||
|
|
||||||
assert(store:create(session_id, data))
|
|
||||||
|
|
||||||
local retrieved = store:get(session_id)
|
|
||||||
assert_equal("oop_test", retrieved.user)
|
|
||||||
assert_equal("admin", retrieved.role)
|
|
||||||
|
|
||||||
retrieved.last_action = "login"
|
|
||||||
assert(store:update(session_id, retrieved))
|
|
||||||
|
|
||||||
local updated = store:get(session_id)
|
|
||||||
assert_equal("login", updated.last_action)
|
|
||||||
|
|
||||||
assert(store:exists(session_id))
|
|
||||||
assert_equal(1, store:count())
|
|
||||||
|
|
||||||
local session_list = store:list()
|
|
||||||
assert_equal(1, #session_list)
|
|
||||||
assert_equal(session_id, session_list[1])
|
|
||||||
|
|
||||||
assert(store:delete(session_id))
|
|
||||||
assert(not store:exists(session_id))
|
|
||||||
|
|
||||||
store:close()
|
|
||||||
os.remove("oop_test.json")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- ERROR HANDLING
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session error handling", function()
|
|
||||||
sessions.reset()
|
|
||||||
|
|
||||||
-- Try operations without initialization
|
|
||||||
local success1, err1 = pcall(sessions.create, "test_id", {})
|
|
||||||
assert(not success1)
|
|
||||||
|
|
||||||
local success2, err2 = pcall(sessions.get, "test_id")
|
|
||||||
assert(not success2)
|
|
||||||
|
|
||||||
-- Initialize and test invalid inputs
|
|
||||||
sessions.init("error_test")
|
|
||||||
|
|
||||||
-- Invalid session ID type
|
|
||||||
local success3, err3 = pcall(sessions.create, 123, {})
|
|
||||||
assert(not success3)
|
|
||||||
|
|
||||||
-- Invalid data type
|
|
||||||
local success4, err4 = pcall(sessions.create, "test", "not_a_table")
|
|
||||||
assert(not success4)
|
|
||||||
|
|
||||||
-- Invalid store name
|
|
||||||
local success5, err5 = pcall(sessions.get, "test", 123)
|
|
||||||
assert(not success5)
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- DATA PERSISTENCE
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session persistence", function()
|
|
||||||
sessions.init("persist_test", "persist_sessions.json")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
local data = {
|
|
||||||
user_id = 789,
|
|
||||||
settings = {theme = "dark", lang = "en"},
|
|
||||||
cart = {items = {"item1", "item2"}, total = 25.99}
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions.create(session_id, data, "persist_test")
|
|
||||||
sessions.close("persist_test")
|
|
||||||
|
|
||||||
-- Reinitialize and verify data persists
|
|
||||||
sessions.init("persist_test", "persist_sessions.json")
|
|
||||||
local retrieved = sessions.get(session_id, "persist_test")
|
|
||||||
|
|
||||||
assert_equal(789, retrieved.user_id)
|
|
||||||
assert_equal("dark", retrieved.settings.theme)
|
|
||||||
assert_equal("en", retrieved.settings.lang)
|
|
||||||
assert_equal(2, #retrieved.cart.items)
|
|
||||||
assert_equal(25.99, retrieved.cart.total)
|
|
||||||
|
|
||||||
sessions.close("persist_test")
|
|
||||||
os.remove("persist_sessions.json")
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- COMPLEX DATA STRUCTURES
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Complex session data", function()
|
|
||||||
sessions.init("complex_test")
|
|
||||||
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
local complex_data = {
|
|
||||||
user = {
|
|
||||||
id = 456,
|
|
||||||
profile = {
|
|
||||||
name = "Jane Doe",
|
|
||||||
email = "jane@example.com",
|
|
||||||
preferences = {
|
|
||||||
notifications = true,
|
|
||||||
privacy = "friends_only"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
activity = {
|
|
||||||
pages_visited = {"home", "profile", "settings"},
|
|
||||||
actions = {
|
|
||||||
{type = "login", time = os.time()},
|
|
||||||
{type = "view_page", page = "profile", time = os.time()}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
metadata = {
|
|
||||||
ip = "192.168.1.1",
|
|
||||||
user_agent = "Test Browser 1.0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions.create(session_id, complex_data)
|
|
||||||
local retrieved = sessions.get(session_id)
|
|
||||||
|
|
||||||
assert_equal(456, retrieved.user.id)
|
|
||||||
assert_equal("Jane Doe", retrieved.user.profile.name)
|
|
||||||
assert_equal(true, retrieved.user.profile.preferences.notifications)
|
|
||||||
assert_equal(3, #retrieved.activity.pages_visited)
|
|
||||||
assert_equal("login", retrieved.activity.actions[1].type)
|
|
||||||
assert_equal("192.168.1.1", retrieved.metadata.ip)
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- WORKFLOW INTEGRATION
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
test("Session workflow integration", function()
|
|
||||||
sessions.init("workflow_test")
|
|
||||||
|
|
||||||
-- Simulate user workflow
|
|
||||||
local session_id = sessions.generate_id()
|
|
||||||
|
|
||||||
-- User login
|
|
||||||
sessions.create(session_id, {
|
|
||||||
user_id = 999,
|
|
||||||
username = "workflow_user",
|
|
||||||
status = "logged_in"
|
|
||||||
})
|
|
||||||
|
|
||||||
-- User adds items to cart
|
|
||||||
local session = sessions.get(session_id)
|
|
||||||
session.cart = {"item1", "item2"}
|
|
||||||
session.cart_total = 19.99
|
|
||||||
sessions.update(session_id, session)
|
|
||||||
|
|
||||||
-- User proceeds to checkout
|
|
||||||
session = sessions.get(session_id)
|
|
||||||
session.checkout_step = "payment"
|
|
||||||
session.payment_method = "credit_card"
|
|
||||||
sessions.update(session_id, session)
|
|
||||||
|
|
||||||
-- Verify final state
|
|
||||||
local final_session = sessions.get(session_id)
|
|
||||||
assert_equal(999, final_session.user_id)
|
|
||||||
assert_equal("logged_in", final_session.status)
|
|
||||||
assert_equal(2, #final_session.cart)
|
|
||||||
assert_equal(19.99, final_session.cart_total)
|
|
||||||
assert_equal("payment", final_session.checkout_step)
|
|
||||||
assert_equal("credit_card", final_session.payment_method)
|
|
||||||
|
|
||||||
-- User completes order and logs out
|
|
||||||
sessions.delete(session_id)
|
|
||||||
assert_equal(nil, sessions.get(session_id))
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- Clean up test files
|
|
||||||
os.remove("test_sessions.json")
|
|
||||||
os.remove("test_sessions2.json")
|
|
||||||
|
|
||||||
summary()
|
|
||||||
test_exit()
|
|
||||||
954
tests/string.lua
954
tests/string.lua
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user