Compare commits

..

95 Commits

Author SHA1 Message Date
55e6b54789 clean up sql connections at server shutdown 2025-08-01 15:35:31 -05:00
041b4a517d improve string library usage in the database libraries 2025-08-01 15:21:43 -05:00
aa340b7323 Improve watcher detection and rebuild error handling 2025-08-01 12:26:48 -05:00
64dd0b0c56 various string and http improvements 2025-08-01 12:22:22 -05:00
da5281ba0a string fixes, static server improvements 2025-07-31 19:07:27 -05:00
4d0d5b6757 merge request and response into context 2025-07-25 19:01:26 -05:00
81426c61b7 move csrf input helper to request obj, add post-commit hook for commit hash in metadata 2025-07-25 18:13:59 -05:00
d3633dfdbc fix redundant middleware, clean up pipeline 2025-07-25 14:43:51 -05:00
1fbcfaf9df add csrf, auth and flash to http module 2025-07-25 13:00:38 -05:00
7397c3ebbc add password utilities to crypto module 2025-07-25 11:05:32 -05:00
78b38ee544 remove json require from crypto lib 2025-07-25 10:31:43 -05:00
52fe63df39 fix table module 2025-07-25 09:43:57 -05:00
e05369431c fix kv store persistence on server shutdown 2025-07-24 22:44:52 -05:00
f09c9f345a merge sessions into http module directly 2025-07-24 22:30:23 -05:00
d2f0a75d50 minor string fixes 2025-07-24 21:49:44 -05:00
8b4a3b27c0 migrate json module to global 2025-07-24 17:52:58 -05:00
3239d6ac95 rewrite math module to global, improve test suite 2025-07-24 17:24:54 -05:00
ae4af71822 fix table as global mod 2025-07-24 16:52:28 -05:00
2d43c457e1 remove string requires, fix calls 2025-07-24 16:48:29 -05:00
41cba2f049 rewrite of the string module 2025-07-24 16:45:12 -05:00
8a53fea511 first pass on kv store 2025-07-24 11:49:16 -05:00
e45d63cf24 add build dir to ignore, update deps 2025-07-24 10:32:08 -05:00
71633b4b4b enhance database modules with table utils 2025-07-24 10:21:07 -05:00
5551f16bc1 implement table tests, fix crypto module json fail 2025-07-24 10:12:54 -05:00
1d05ac8bb2 add comprehensive table module 2025-07-24 09:47:13 -05:00
cf203d7899 initial sql database support - sqlite, postgres, mysql 2025-07-24 09:39:24 -05:00
09646394a5 hide old test dirs 2025-07-24 07:48:02 -05:00
d3dcf95e0c use string utils to simplify some http utils 2025-07-23 21:54:13 -05:00
6f20540720 fix http json usage, use native implementation 2025-07-18 10:19:04 -05:00
1e19ba7700 minor code modernization 2025-07-18 10:13:08 -05:00
cc2cb0c682 add watch mode 2025-07-17 23:00:47 -05:00
88c9bd90af fix static file serving 2025-07-17 22:39:14 -05:00
4ceca8d289 optimize string module 2025-07-17 22:12:57 -05:00
25a44660a4 fs module improvements 2025-07-17 21:42:11 -05:00
e5df8a5b8a refactor fs module, use io and os as backend 2025-07-17 21:34:47 -05:00
74faa76dbd replace go json with pure lua, much speed 2025-07-17 20:17:53 -05:00
12ba756b95 use fasthttp static file serving 2025-07-17 19:11:01 -05:00
e95f0f3370 adjust static file handling 2025-07-17 19:07:25 -05:00
1753007090 http module success 1 2025-07-17 18:56:48 -05:00
a110b93f5c introduce bytecode caching, update state management to use it 2025-07-17 13:48:56 -05:00
e86cb55aa6 introduce state management 2025-07-17 13:46:18 -05:00
09bfd24c8d simplify module registry 2025-07-17 13:40:41 -05:00
abf3aaba35 fix module registry, test failures 2025-07-17 12:51:25 -05:00
0012a7089d update json, math and string modules 2025-07-17 12:34:36 -05:00
898b29b86a update fs module 2025-07-17 12:22:04 -05:00
503f76d127 refactor modules, update crypto module 2025-07-17 12:20:26 -05:00
4ff04e141d interface{} to any 2025-07-16 20:54:15 -05:00
03d1b93f35 move modules.go to /modules 2025-07-16 19:57:49 -05:00
dadc6f13f7 store global bytecode 2025-07-16 17:01:16 -05:00
5a5d18ca0e Move http module go code to its own package 2025-07-16 13:25:18 -05:00
e8ad16ccdc first pass on HTTP module 2025-07-15 20:24:12 -05:00
a99eed9485 remove benchmark 2025-07-15 16:07:37 -05:00
edc8e9e607 enhance string library 2025-07-15 16:06:32 -05:00
743fd0e835 crypto, fs, string libs 2025-07-14 23:53:53 -05:00
acb8670177 update test fwk, fix package path 2025-07-14 21:51:02 -05:00
da602278c5 go functions first pass 2025-07-14 21:34:02 -05:00
e5388c4c23 runtime 2, math lib 2025-07-14 20:45:26 -05:00
f75ba90f74 restart 2025-07-14 19:11:43 -05:00
843e318e01 next pass 2025-07-14 17:36:59 -05:00
bb06e2431d update deps 2025-07-14 16:07:51 -05:00
97e3ec6547 first attempt 2025-07-14 16:03:02 -05:00
c53c54a5d9 refactor DirectoryWatcher to Watcher 2025-07-01 21:14:16 -05:00
8f74566e96 simplify file watchers 2025-07-01 21:13:33 -05:00
d86167a86e re-add watcher setup, drop unnecessary files 2025-07-01 15:29:26 -05:00
2c731b9cbf revert to string router 2025-06-06 22:25:19 -05:00
d44a9b5b28 major rewrite work 2 2025-06-06 18:57:47 -05:00
43b9dd7320 bring in ext color package, use new fin loading func 2025-06-06 13:56:46 -05:00
a2d9b0ad9f major rewrite work 1 2025-06-05 22:18:21 -05:00
ddb1d7c9d7 move logger to top level 2025-06-05 20:26:09 -05:00
60dd7ba82f move config to top level 2025-06-05 20:25:38 -05:00
0c87d0d704 license and readme 2025-06-05 20:23:50 -05:00
db8416778f move metadata to top level 2025-06-05 20:23:07 -05:00
bca4eef166 move color to top level, add windows support 2025-06-05 20:21:34 -05:00
4077ac03f1 work towards reusable sqlite connections 2025-06-05 19:04:05 -05:00
0c4ddd7e3d revert sqlite state tracking, add state index as global 2025-06-05 14:52:50 -05:00
14fcd7894b move sqlite in runner to its own package 2025-06-05 12:58:05 -05:00
cf38b947e1 add timestamp features, fix sqlite positional parameters 2025-06-05 11:34:19 -05:00
bf8ce59b73 optimize moduleLoader, re-add to runner 2025-06-04 21:50:56 -05:00
e3ee503c31 rewrite router, server and runner 2025-06-04 19:06:47 -05:00
d6e24f0185 refactor lua code into more granular modules 2025-06-04 11:22:15 -05:00
c2be77bf6a fix automatic json response handling 2025-06-04 11:11:21 -05:00
769a8dd2ce add flash session support, fix template rendering control flow 2025-06-04 09:54:13 -05:00
2c0067dfcf fix env, add additional type support, use coroutines for endpoint exec 2025-06-04 09:25:49 -05:00
ff01a1f0b1 fix error logs, add timestamp to http request logs, add back password funcs to sandbox 2025-06-03 20:24:05 -05:00
22c340648b massive optimizations and fixes 2025-06-03 18:34:22 -05:00
98be4aef25 simplify embed 2025-06-03 12:00:09 -05:00
b596ce9072 simplify sandbox and execution flow 2025-06-03 11:51:48 -05:00
86a122a50c Move Go lua libs to lualibs 2025-06-03 11:38:30 -05:00
61f66d6594 fix lots of luajit api regressions 2025-06-02 22:18:54 -05:00
1ad3059ff0 fix fenv error 2025-06-02 12:29:45 -05:00
e2b1b932ff migrate to new LJTG API 2025-06-02 11:15:56 -05:00
cc6a7675d8 update config, new LJTG version 2025-06-01 21:49:30 -05:00
c0c6100c17 rewrite logger and main 2025-05-31 15:26:06 -05:00
ac646bbad8 config rewrite 2025-05-31 13:02:16 -05:00
3f86c31c9f rewrite color utils 2025-05-31 09:11:21 -05:00
89 changed files with 16585 additions and 12428 deletions

14
.gitignore vendored
View File

@ -22,8 +22,14 @@ luajit/.git
# Go workspace file
go.work
# Claude workspace files
.claude
CLAUDE.md
# Test directories and files
/config.lua
test/
/init.lua
/moonshark
/*.lua
test_fs_dir
public
test
test.db
build

19
LICENSE
View File

@ -1,19 +1,2 @@
## Sharkk Open License
### Version 1.0, March 2025
Copyright (c) Sharkk, Skylear Johnson
Hey there, code surfer! You're free to ride this wave—use, modify, and share this software however you like, as long as you stick to these chill but important rules:
1. **Share Your Changes**: If you tweak, remix, or build on this software, youve gotta share your work with the world under the same license. That means making your modified source code available in a reasonable way—like linking to a public repo. Keep the stoke alive!
2. **Keep This License**: Whenever you pass this software along (whether youve changed it or not), you need to include this license in full. No sneaky restrictions that limit the freedom to ride the digital waves.
3. **Give Credit Where Its Due**: Show some love to the original author(s) by keeping the copyright notice and, if possible, linking back to the original source. Good vibes and respect go a long way.
4. **Make It Your Own**: If you add your own original code or features, youre totally free to monetize those additions. Sell it, license it, or turn it into the next big thing—just keep the original parts open for everyone.
5. **No Guarantees**: This software comes "as is." No promises, no warranties—just pure, unfiltered code. If things go sideways, youre riding that wave at your own risk. The authors arent responsible for any wipeouts.
By using, modifying, or sharing this software, youre agreeing to these terms. Keep it open, keep it flowing, and most of all—have fun!
DO NOT USE THIS SOFTWARE

View File

@ -1,8 +1,5 @@
# Moonshark
```bash
git submodule update --init --recursive
git submodule update --remote --recursive
go build -trimpath -ldflags="-s -w" -o moonshark .
go build -trimpath -ldflags="-s -w" -o build/moonshark .
```

4
build.sh Executable file
View File

@ -0,0 +1,4 @@
#!/bin/bash
mkdir -p build
go build -trimpath -ldflags="-s -w" -o build/moonshark .

View File

@ -1,21 +0,0 @@
server = {
port = 3117,
debug = false,
log_level = "info",
http_logging = false
}
runner = {
pool_size = 0 -- 0 defaults to GOMAXPROCS
}
dirs = {
routes = "routes",
static = "public",
fs = "fs",
data = "data",
override = "override",
libs = {
"libs"
}
}

41
go.mod
View File

@ -2,34 +2,39 @@ module Moonshark
go 1.24.1
require git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6
require (
git.sharkk.net/Go/LRU v1.0.0
git.sharkk.net/Sky/LuaJIT-to-Go v0.4.1
github.com/VictoriaMetrics/fastcache v1.12.4
github.com/alexedwards/argon2id v1.0.0
github.com/deneonet/benc v1.1.8
github.com/go-sql-driver/mysql v1.9.3
github.com/goccy/go-json v0.10.5
github.com/golang/snappy v1.0.0
github.com/matoous/go-nanoid/v2 v2.1.0
github.com/valyala/bytebufferpool v1.0.0
github.com/valyala/fasthttp v1.62.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.5
golang.org/x/crypto v0.40.0
zombiezen.com/go/sqlite v1.4.2
)
require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/crypto v0.38.0 // indirect
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 // indirect
golang.org/x/sys v0.33.0 // indirect
modernc.org/libc v1.65.8 // indirect
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/text v0.27.0 // indirect
modernc.org/libc v1.66.3 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.37.1 // indirect
modernc.org/sqlite v1.38.0 // indirect
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.64.0
)

134
go.sum
View File

@ -1,35 +1,32 @@
git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc=
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
git.sharkk.net/Sky/LuaJIT-to-Go v0.4.1 h1:CAYt+C6Vgo4JxK876j0ApQ2GDFFvy9FKO0OoZBVD18k=
git.sharkk.net/Sky/LuaJIT-to-Go v0.4.1/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=
github.com/alexedwards/argon2id v1.0.0/go.mod h1:tYKkqIjzXvZdzPvADMWOEZ+l6+BD6CtBXMj5fnJppiw=
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 h1:XytP9R2fWykv0MXIzxggPx5S/PmTkjyZVvUX2sn4EaU=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/deneonet/benc v1.1.8 h1:Qk9diyH0UcnduvCrZ62mBrwUeSZzte4kQxMbclVdhW4=
github.com/deneonet/benc v1.1.8/go.mod h1:UCfkM5Od0B2huwv/ZItvtUb7QnALFt9YXtX8NXX4Lts=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
@ -38,79 +35,48 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0=
github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg=
github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og=
github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.65.8 h1:7PXRJai0TXZ8uNA3srsmYzmTyrLoHImV5QxHeni108Q=
modernc.org/libc v1.65.8/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@ -119,8 +85,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@ -1,344 +0,0 @@
package http
import (
"bytes"
"context"
"sync"
"time"
"Moonshark/router"
"Moonshark/runner"
"Moonshark/sessions"
"Moonshark/utils"
"Moonshark/utils/color"
"Moonshark/utils/config"
"Moonshark/utils/logger"
"Moonshark/utils/metadata"
"github.com/valyala/fasthttp"
)
var (
//methodGET = []byte("GET")
methodPOST = []byte("POST")
methodPUT = []byte("PUT")
methodPATCH = []byte("PATCH")
debugPath = []byte("/debug/stats")
)
type Server struct {
luaRouter *router.LuaRouter
staticHandler fasthttp.RequestHandler
staticFS *fasthttp.FS
luaRunner *runner.Runner
fasthttpServer *fasthttp.Server
loggingEnabled bool
debugMode bool
config *config.Config
sessionManager *sessions.SessionManager
errorConfig utils.ErrorPageConfig
ctxPool sync.Pool
paramsPool sync.Pool
staticDir string
staticPrefix string
staticPrefixBytes []byte
// Cached error pages
cached404 []byte
cached500 []byte
errorCacheMu sync.RWMutex
}
func New(luaRouter *router.LuaRouter, staticDir string,
runner *runner.Runner, loggingEnabled bool, debugMode bool,
overrideDir string, config *config.Config) *Server {
staticPrefix := config.Server.StaticPrefix
if staticPrefix == "" {
staticPrefix = "/static/"
}
if staticPrefix[0] != '/' {
staticPrefix = "/" + staticPrefix
}
if staticPrefix[len(staticPrefix)-1] != '/' {
staticPrefix = staticPrefix + "/"
}
s := &Server{
luaRouter: luaRouter,
luaRunner: runner,
loggingEnabled: loggingEnabled,
debugMode: debugMode,
config: config,
sessionManager: sessions.GlobalSessionManager,
staticDir: staticDir,
staticPrefix: staticPrefix,
staticPrefixBytes: []byte(staticPrefix),
errorConfig: utils.ErrorPageConfig{
OverrideDir: overrideDir,
DebugMode: debugMode,
},
ctxPool: sync.Pool{
New: func() any {
return make(map[string]any, 6)
},
},
paramsPool: sync.Pool{
New: func() any {
return make(map[string]any, 4)
},
},
}
// Pre-cache error pages
s.cached404 = []byte(utils.NotFoundPage(s.errorConfig, ""))
s.cached500 = []byte(utils.InternalErrorPage(s.errorConfig, "", "Internal Server Error"))
// Setup static file serving
if staticDir != "" {
s.staticFS = &fasthttp.FS{
Root: staticDir,
IndexNames: []string{"index.html"},
GenerateIndexPages: false,
AcceptByteRange: true,
Compress: true,
CompressedFileSuffix: ".gz",
CompressBrotli: true,
CompressZstd: true,
PathRewrite: fasthttp.NewPathPrefixStripper(len(staticPrefix) - 1),
}
s.staticHandler = s.staticFS.NewRequestHandler()
}
s.fasthttpServer = &fasthttp.Server{
Handler: s.handleRequest,
Name: "Moonshark/" + metadata.Version,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
MaxRequestBodySize: 16 << 20,
TCPKeepalive: true,
TCPKeepalivePeriod: 60 * time.Second,
ReduceMemoryUsage: true,
DisablePreParseMultipartForm: true,
DisableHeaderNamesNormalizing: true,
NoDefaultServerHeader: true,
StreamRequestBody: true,
}
return s
}
func (s *Server) ListenAndServe(addr string) error {
logger.Info("Catch the swell at %s", color.Apply("http://localhost"+addr, color.Cyan))
return s.fasthttpServer.ListenAndServe(addr)
}
func (s *Server) Shutdown(ctx context.Context) error {
return s.fasthttpServer.ShutdownWithContext(ctx)
}
func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
start := time.Now()
methodBytes := ctx.Method()
pathBytes := ctx.Path()
// Fast path for debug stats
if s.debugMode && bytes.Equal(pathBytes, debugPath) {
s.handleDebugStats(ctx)
if s.loggingEnabled {
logger.LogRequest(ctx.Response.StatusCode(), string(methodBytes), string(pathBytes), time.Since(start))
}
return
}
// Fast path for static files
if s.staticHandler != nil && bytes.HasPrefix(pathBytes, s.staticPrefixBytes) {
s.staticHandler(ctx)
if s.loggingEnabled {
logger.LogRequest(ctx.Response.StatusCode(), string(methodBytes), string(pathBytes), time.Since(start))
}
return
}
// Lua route lookup - only allocate params if found
bytecode, scriptPath, routeErr, params, found := s.luaRouter.GetRouteInfo(methodBytes, pathBytes)
if found {
if len(bytecode) == 0 || routeErr != nil {
s.sendError(ctx, fasthttp.StatusInternalServerError, pathBytes, routeErr)
} else {
s.handleLuaRoute(ctx, bytecode, scriptPath, params, methodBytes, pathBytes)
}
} else {
s.send404(ctx, pathBytes)
}
if s.loggingEnabled {
logger.LogRequest(ctx.Response.StatusCode(), string(methodBytes), string(pathBytes), time.Since(start))
}
}
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string,
params *router.Params, methodBytes, pathBytes []byte) {
luaCtx := runner.NewHTTPContext(ctx)
defer luaCtx.Release()
if runner.GetGlobalEnvManager() != nil {
luaCtx.Set("env", runner.GetGlobalEnvManager().GetAll())
}
sessionMap := s.ctxPool.Get().(map[string]any)
defer func() {
for k := range sessionMap {
delete(sessionMap, k)
}
s.ctxPool.Put(sessionMap)
}()
session := s.sessionManager.GetSessionFromRequest(ctx)
sessionMap["id"] = session.ID
// Only get session data if not empty
if !session.IsEmpty() {
sessionMap["data"] = session.GetAll()
} else {
sessionMap["data"] = emptyMap
}
// Set basic context
luaCtx.Set("method", string(methodBytes))
luaCtx.Set("path", string(pathBytes))
luaCtx.Set("host", string(ctx.Host()))
luaCtx.Set("session", sessionMap)
// Add headers to context
headers := make(map[string]any)
ctx.Request.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
luaCtx.Set("headers", headers)
// Handle params
if params != nil && params.Count > 0 {
paramMap := s.paramsPool.Get().(map[string]any)
for i := range params.Count {
paramMap[params.Keys[i]] = params.Values[i]
}
luaCtx.Set("params", paramMap)
defer func() {
for k := range paramMap {
delete(paramMap, k)
}
s.paramsPool.Put(paramMap)
}()
} else {
luaCtx.Set("params", emptyMap)
}
// Parse form data for POST/PUT/PATCH
if bytes.Equal(methodBytes, methodPOST) ||
bytes.Equal(methodBytes, methodPUT) ||
bytes.Equal(methodBytes, methodPATCH) {
if formData, err := ParseForm(ctx); err == nil {
luaCtx.Set("form", formData)
} else {
if s.debugMode {
logger.Warning("Error parsing form: %v", err)
}
luaCtx.Set("form", emptyMap)
}
} else {
luaCtx.Set("form", emptyMap)
}
response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
if err != nil {
logger.Error("Lua execution error: %v", err)
s.sendError(ctx, fasthttp.StatusInternalServerError, pathBytes, err)
return
}
// Handle session updates
if len(response.SessionData) > 0 {
if _, clearAll := response.SessionData["__clear_all"]; clearAll {
session.Clear()
delete(response.SessionData, "__clear_all")
}
for k, v := range response.SessionData {
if v == "__SESSION_DELETE_MARKER__" {
session.Delete(k)
} else {
session.Set(k, v)
}
}
}
s.sessionManager.ApplySessionCookie(ctx, session)
runner.ApplyResponse(response, ctx)
runner.ReleaseResponse(response)
}
func (s *Server) send404(ctx *fasthttp.RequestCtx, pathBytes []byte) {
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusNotFound)
// Use cached 404 for common case
if len(pathBytes) == 1 && pathBytes[0] == '/' {
s.errorCacheMu.RLock()
ctx.SetBody(s.cached404)
s.errorCacheMu.RUnlock()
} else {
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, string(pathBytes))))
}
}
func (s *Server) sendError(ctx *fasthttp.RequestCtx, status int, pathBytes []byte, err error) {
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(status)
if err == nil {
s.errorCacheMu.RLock()
ctx.SetBody(s.cached500)
s.errorCacheMu.RUnlock()
} else {
ctx.SetBody([]byte(utils.InternalErrorPage(s.errorConfig, string(pathBytes), err.Error())))
}
}
func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
stats := utils.CollectSystemStats(s.config)
routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
stats.Components = utils.ComponentStats{
RouteCount: routeCount,
BytecodeBytes: bytecodeBytes,
SessionStats: sessions.GlobalSessionManager.GetCacheStats(),
}
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBody([]byte(utils.DebugStatsPage(stats)))
}
// SetStaticCaching enables/disables static file caching
func (s *Server) SetStaticCaching(duration time.Duration) {
if s.staticFS != nil {
s.staticFS.CacheDuration = duration
s.staticHandler = s.staticFS.NewRequestHandler()
}
}
// GetStaticPrefix returns the URL prefix for static files
func (s *Server) GetStaticPrefix() string {
return s.staticPrefix
}
// UpdateErrorCache refreshes cached error pages
func (s *Server) UpdateErrorCache() {
s.errorCacheMu.Lock()
s.cached404 = []byte(utils.NotFoundPage(s.errorConfig, ""))
s.cached500 = []byte(utils.InternalErrorPage(s.errorConfig, "", "Internal Server Error"))
s.errorCacheMu.Unlock()
}

View File

@ -1,147 +0,0 @@
package http
import (
"crypto/rand"
"encoding/base64"
"mime/multipart"
"strings"
"sync"
"github.com/valyala/fasthttp"
)
var (
emptyMap = make(map[string]any)
formDataPool = sync.Pool{
New: func() any {
return make(map[string]any, 16)
},
}
)
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
args := ctx.QueryArgs()
if args.Len() == 0 {
return emptyMap
}
queryMap := make(map[string]any, args.Len())
args.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
appendValue(queryMap, k, v)
})
return queryMap
}
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") {
return parseMultipartForm(ctx)
}
args := ctx.PostArgs()
if args.Len() == 0 {
return emptyMap, nil
}
formData := formDataPool.Get().(map[string]any)
for k := range formData {
delete(formData, k)
}
args.VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
appendValue(formData, k, v)
})
return formData, nil
}
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
form, err := ctx.MultipartForm()
if err != nil {
return nil, err
}
formData := formDataPool.Get().(map[string]any)
for k := range formData {
delete(formData, k)
}
for key, values := range form.Value {
if len(values) == 1 {
formData[key] = values[0]
} else if len(values) > 1 {
formData[key] = values
}
}
if len(form.File) > 0 {
files := make(map[string]any, len(form.File))
for fieldName, fileHeaders := range form.File {
if len(fileHeaders) == 1 {
files[fieldName] = fileInfoToMap(fileHeaders[0])
} else {
fileInfos := make([]map[string]any, len(fileHeaders))
for i, fh := range fileHeaders {
fileInfos[i] = fileInfoToMap(fh)
}
files[fieldName] = fileInfos
}
}
formData["_files"] = files
}
return formData, nil
}
func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
ct := fh.Header.Get("Content-Type")
if ct == "" {
ct = getMimeType(fh.Filename)
}
return map[string]any{
"filename": fh.Filename,
"size": fh.Size,
"mimetype": ct,
}
}
func getMimeType(filename string) string {
if i := strings.LastIndex(filename, "."); i >= 0 {
switch filename[i:] {
case ".pdf":
return "application/pdf"
case ".png":
return "image/png"
case ".jpg", ".jpeg":
return "image/jpeg"
case ".gif":
return "image/gif"
case ".svg":
return "image/svg+xml"
}
}
return "application/octet-stream"
}
func appendValue(m map[string]any, k, v string) {
if existing, exists := m[k]; exists {
switch typed := existing.(type) {
case []string:
m[k] = append(typed, v)
case string:
m[k] = []string{typed, v}
}
} else {
m[k] = v
}
}
func GenerateSecureToken(length int) (string, error) {
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b)[:length], nil
}

393
main.go
View File

@ -1,393 +0,0 @@
package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
"os/signal"
"path/filepath"
"strconv"
"syscall"
"time"
"Moonshark/http"
"Moonshark/router"
"Moonshark/runner"
"Moonshark/sessions"
"Moonshark/utils/color"
"Moonshark/utils/config"
"Moonshark/utils/logger"
"Moonshark/utils/metadata"
"Moonshark/watchers"
)
// Moonshark represents the server and all its dependencies
type Moonshark struct {
Config *config.Config
LuaRouter *router.LuaRouter
LuaRunner *runner.Runner
HTTPServer *http.Server
cleanupFuncs []func() error
scriptMode bool
}
func main() {
configPath := flag.String("config", "config.lua", "Path to configuration file")
debugFlag := flag.Bool("debug", false, "Enable debug mode")
scriptPath := flag.String("script", "", "Path to Lua script to execute once")
flag.Parse()
scriptMode := *scriptPath != ""
banner()
mode := ""
if scriptMode {
mode = "[Script Mode]"
} else {
mode = "[Server Mode]"
}
fmt.Printf("%s %s\n\n", color.Apply(mode, color.Gray), color.Apply("v"+metadata.Version, color.Blue))
// Initialize logger
logger.InitGlobalLogger(true, false)
// Load config
cfg, err := config.Load(*configPath)
if err != nil {
logger.Warning("Config load failed: %v, using defaults", color.Apply(err.Error(), color.Red))
cfg = config.New()
}
// Setup logging with debug mode
if *debugFlag || cfg.Server.Debug {
logger.EnableDebug()
logger.Debug("Debug logging enabled")
}
var moonshark *Moonshark
if scriptMode {
moonshark, err = initScriptMode(cfg)
} else {
moonshark, err = initServerMode(cfg, *debugFlag)
}
if err != nil {
logger.Fatal("Initialization failed: %v", err)
}
defer func() {
if err := moonshark.Shutdown(); err != nil {
logger.Error("Error during shutdown: %v", err)
os.Exit(1)
}
}()
if scriptMode {
// Run the script and exit
if err := moonshark.RunScript(*scriptPath); err != nil {
logger.Fatal("Script execution failed: %v", err)
}
return
}
// Start the server
if err := moonshark.Start(); err != nil {
logger.Fatal("Failed to start server: %v", err)
}
// Wait for shutdown signal
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
fmt.Print("\n")
logger.Info("Shutdown signal received")
}
// initScriptMode initializes minimal components needed for script execution
func initScriptMode(cfg *config.Config) (*Moonshark, error) {
moonshark := &Moonshark{
Config: cfg,
scriptMode: true,
}
// Only initialize the Lua runner with required paths
runnerOpts := []runner.RunnerOption{
runner.WithPoolSize(1), // Only need one state for script mode
runner.WithLibDirs(cfg.Dirs.Libs...),
runner.WithFsDir(cfg.Dirs.FS),
runner.WithDataDir(cfg.Dirs.Data),
}
var err error
moonshark.LuaRunner, err = runner.NewRunner(runnerOpts...)
if err != nil {
return nil, fmt.Errorf("failed to initialize Lua runner: %v", err)
}
logger.Debug("Script mode initialized with minimized components")
return moonshark, nil
}
// initServerMode initializes all components needed for server operation
func initServerMode(cfg *config.Config, debug bool) (*Moonshark, error) {
moonshark := &Moonshark{
Config: cfg,
scriptMode: false,
}
if debug {
cfg.Server.Debug = true
}
if err := initLuaRouter(moonshark); err != nil {
return nil, err
}
if err := initRunner(moonshark); err != nil {
return nil, err
}
if err := setupWatchers(moonshark); err != nil {
logger.Warning("Watcher setup failed: %v", err)
}
// Get static directory - empty string if it doesn't exist
staticDir := ""
if dirExists(cfg.Dirs.Static) {
staticDir = cfg.Dirs.Static
logger.Info("Static files enabled: %s", color.Apply(staticDir, color.Yellow))
} else {
logger.Warning("Static directory not found: %s", color.Apply(cfg.Dirs.Static, color.Yellow))
}
moonshark.HTTPServer = http.New(
moonshark.LuaRouter,
staticDir,
moonshark.LuaRunner,
cfg.Server.HTTPLogging,
cfg.Server.Debug,
cfg.Dirs.Override,
cfg,
)
// For development, disable caching. For production, enable it
if cfg.Server.Debug {
moonshark.HTTPServer.SetStaticCaching(0) // No caching in debug mode
} else {
moonshark.HTTPServer.SetStaticCaching(1 * time.Hour) // Cache for 1 hour in production
}
return moonshark, nil
}
// RunScript executes a Lua script in the sandbox environment
func (s *Moonshark) RunScript(scriptPath string) error {
scriptPath, err := filepath.Abs(scriptPath)
if err != nil {
return fmt.Errorf("failed to resolve script path: %v", err)
}
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
return fmt.Errorf("script file not found: %s", scriptPath)
}
logger.Info("Executing: %s", scriptPath)
resp, err := s.LuaRunner.RunScriptFile(scriptPath)
if err != nil {
return fmt.Errorf("execution failed: %v", err)
}
if resp != nil && resp.Body != nil {
logger.Info("Script result: %v", resp.Body)
} else {
logger.Info("Script executed successfully (no return value)")
}
return nil
}
// Start starts the HTTP server
func (s *Moonshark) Start() error {
if s.scriptMode {
return errors.New("cannot start server in script mode")
}
logger.Info("Surf's up on port %s!", color.Apply(strconv.Itoa(s.Config.Server.Port), color.Cyan))
go func() {
if err := s.HTTPServer.ListenAndServe(fmt.Sprintf(":%d", s.Config.Server.Port)); err != nil {
if err.Error() != "http: Server closed" {
logger.Error("Server error: %v", err)
}
}
}()
return nil
}
// Shutdown gracefully shuts down Moonshark
func (s *Moonshark) Shutdown() error {
logger.Info("Shutting down...")
if !s.scriptMode && s.HTTPServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.HTTPServer.Shutdown(ctx); err != nil {
logger.Error("HTTP server shutdown error: %v", err)
}
}
for _, cleanup := range s.cleanupFuncs {
if err := cleanup(); err != nil {
logger.Warning("Cleanup error: %v", err)
}
}
if s.LuaRunner != nil {
s.LuaRunner.Close()
}
if err := runner.CleanupEnv(); err != nil {
logger.Warning("Environment cleanup failed: %v", err)
}
logger.Info("Shutdown complete")
return nil
}
func dirExists(path string) bool {
info, err := os.Stat(path)
if err != nil {
return false
}
return info.IsDir()
}
func initLuaRouter(s *Moonshark) error {
if !dirExists(s.Config.Dirs.Routes) {
return fmt.Errorf("routes directory doesn't exist: %s", s.Config.Dirs.Routes)
}
var err error
s.LuaRouter, err = router.NewLuaRouter(s.Config.Dirs.Routes)
if err != nil {
if errors.Is(err, router.ErrRoutesCompilationErrors) {
// Non-fatal, some routes failed
logger.Warning("Some routes failed to compile")
if failedRoutes := s.LuaRouter.ReportFailedRoutes(); len(failedRoutes) > 0 {
for _, re := range failedRoutes {
logger.Error("Route %s %s: %v", re.Method, re.Path, re.Err)
}
}
} else {
return fmt.Errorf("lua router init failed: %v", err)
}
}
logger.Info("LuaRouter is g2g! %s", color.Set(s.Config.Dirs.Routes, color.Yellow))
return nil
}
func initRunner(s *Moonshark) error {
if !dirExists(s.Config.Dirs.Override) {
logger.Warning("Override directory not found... %s", color.Apply(s.Config.Dirs.Override, color.Yellow))
s.Config.Dirs.Override = ""
}
for _, dir := range s.Config.Dirs.Libs {
if !dirExists(dir) {
logger.Warning("Lib directory not found... %s", color.Apply(dir, color.Yellow))
}
}
// Initialize environment manager
if err := runner.InitEnv(s.Config.Dirs.Data); err != nil {
logger.Warning("Environment initialization failed: %v", err)
}
sessionManager := sessions.GlobalSessionManager
sessionManager.SetCookieOptions(
"MoonsharkSID",
"/",
"",
false,
true,
86400,
)
poolSize := s.Config.Runner.PoolSize
if s.scriptMode {
poolSize = 1 // Only need one state for script mode
}
runnerOpts := []runner.RunnerOption{
runner.WithPoolSize(poolSize),
runner.WithLibDirs(s.Config.Dirs.Libs...),
runner.WithFsDir(s.Config.Dirs.FS),
runner.WithDataDir(s.Config.Dirs.Data),
}
var err error
s.LuaRunner, err = runner.NewRunner(runnerOpts...)
if err != nil {
return fmt.Errorf("lua runner init failed: %v", err)
}
logger.Info("LuaRunner is g2g with %s states!", color.Apply(strconv.Itoa(poolSize), color.Yellow))
return nil
}
func setupWatchers(s *Moonshark) error {
manager := watchers.GetWatcherManager()
// Watch routes directory
routeWatcher, err := watchers.WatchLuaRouter(s.LuaRouter, s.LuaRunner, s.Config.Dirs.Routes)
if err != nil {
logger.Warning("Routes directory watch failed: %v", err)
} else {
routesDir := routeWatcher.GetDir()
s.cleanupFuncs = append(s.cleanupFuncs, func() error {
return manager.UnwatchDirectory(routesDir)
})
}
// Watch module directories
moduleWatchers, err := watchers.WatchLuaModules(s.LuaRunner, s.Config.Dirs.Libs)
if err != nil {
logger.Warning("Module directories watch failed: %v", err)
} else {
for _, watcher := range moduleWatchers {
dirPath := watcher.GetDir()
s.cleanupFuncs = append(s.cleanupFuncs, func() error {
return manager.UnwatchDirectory(dirPath)
})
}
plural := ""
if len(moduleWatchers) == 1 {
plural = "directory"
} else {
plural = "directories"
}
logger.Info("Watching %s module %s.", color.Apply(strconv.Itoa(len(moduleWatchers)), color.Yellow), plural)
}
return nil
}
func banner() {
banner := `
_____ _________.__ __
/ \ ____ ____ ____ / _____/| |__ _____ _______| | __
/ \ / \ / _ \ / _ \ / \ \_____ \ | | \\__ \\_ __ \ |/ /
/ Y ( <_> | <_> ) | \/ \| Y \/ __ \| | \/ <
\____|__ /\____/ \____/|___| /_______ /|___| (____ /__| |__|_ \
\/ \/ \/ \/ \/ \/
`
fmt.Println(color.Apply(banner, color.Blue))
}

5
metadata/version.go Normal file
View File

@ -0,0 +1,5 @@
package metadata
const (
Version = "1.0.0"
)

548
modules/crypto/crypto.go Normal file
View File

@ -0,0 +1,548 @@
package crypto
import (
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/hex"
"fmt"
"math/big"
"strings"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/google/uuid"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/scrypt"
)
func GetFunctionList() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"base64_encode": base64_encode,
"base64_decode": base64_decode,
"base64_url_encode": base64_url_encode,
"base64_url_decode": base64_url_decode,
"hex_encode": hex_encode,
"hex_decode": hex_decode,
"md5_hash": md5_hash,
"sha1_hash": sha1_hash,
"sha256_hash": sha256_hash,
"sha512_hash": sha512_hash,
"hmac_sha256": hmac_sha256,
"hmac_sha1": hmac_sha1,
"uuid_generate": uuid_generate,
"uuid_generate_v4": uuid_generate_v4,
"uuid_validate": uuid_validate,
"random_bytes": random_bytes,
"random_hex": random_hex,
"random_string": random_string,
"secure_compare": secure_compare,
"argon2_hash": argon2_hash,
"argon2_verify": argon2_verify,
"bcrypt_hash": bcrypt_hash,
"bcrypt_verify": bcrypt_verify,
"scrypt_hash": scrypt_hash,
"scrypt_verify": scrypt_verify,
"pbkdf2_hash": pbkdf2_hash,
"pbkdf2_verify": pbkdf2_verify,
"password_hash": password_hash,
"password_verify": password_verify,
}
}
func base64_encode(s *luajit.State) int {
str := s.ToString(1)
encoded := base64.StdEncoding.EncodeToString([]byte(str))
s.PushString(encoded)
return 1
}
func base64_decode(s *luajit.State) int {
str := s.ToString(1)
decoded, err := base64.StdEncoding.DecodeString(str)
if err != nil {
s.PushNil()
s.PushString("invalid base64 data")
return 2
}
s.PushString(string(decoded))
return 1
}
func base64_url_encode(s *luajit.State) int {
str := s.ToString(1)
encoded := base64.URLEncoding.EncodeToString([]byte(str))
s.PushString(encoded)
return 1
}
func base64_url_decode(s *luajit.State) int {
str := s.ToString(1)
decoded, err := base64.URLEncoding.DecodeString(str)
if err != nil {
s.PushNil()
s.PushString("invalid base64url data")
return 2
}
s.PushString(string(decoded))
return 1
}
func hex_encode(s *luajit.State) int {
str := s.ToString(1)
encoded := hex.EncodeToString([]byte(str))
s.PushString(encoded)
return 1
}
func hex_decode(s *luajit.State) int {
str := s.ToString(1)
decoded, err := hex.DecodeString(str)
if err != nil {
s.PushNil()
s.PushString("invalid hex data")
return 2
}
s.PushString(string(decoded))
return 1
}
func md5_hash(s *luajit.State) int {
str := s.ToString(1)
hash := md5.Sum([]byte(str))
s.PushString(hex.EncodeToString(hash[:]))
return 1
}
func sha1_hash(s *luajit.State) int {
str := s.ToString(1)
hash := sha1.Sum([]byte(str))
s.PushString(hex.EncodeToString(hash[:]))
return 1
}
func sha256_hash(s *luajit.State) int {
str := s.ToString(1)
hash := sha256.Sum256([]byte(str))
s.PushString(hex.EncodeToString(hash[:]))
return 1
}
func sha512_hash(s *luajit.State) int {
str := s.ToString(1)
hash := sha512.Sum512([]byte(str))
s.PushString(hex.EncodeToString(hash[:]))
return 1
}
func hmac_sha256(s *luajit.State) int {
message := s.ToString(1)
key := s.ToString(2)
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(message))
s.PushString(hex.EncodeToString(h.Sum(nil)))
return 1
}
func hmac_sha1(s *luajit.State) int {
message := s.ToString(1)
key := s.ToString(2)
h := hmac.New(sha1.New, []byte(key))
h.Write([]byte(message))
s.PushString(hex.EncodeToString(h.Sum(nil)))
return 1
}
func uuid_generate(s *luajit.State) int {
id := uuid.New()
s.PushString(id.String())
return 1
}
func uuid_generate_v4(s *luajit.State) int {
id := uuid.New()
s.PushString(id.String())
return 1
}
func uuid_validate(s *luajit.State) int {
str := s.ToString(1)
_, err := uuid.Parse(str)
s.PushBoolean(err == nil)
return 1
}
func random_bytes(s *luajit.State) int {
length := int(s.ToNumber(1))
if length < 0 || length > 65536 {
s.PushNil()
s.PushString("invalid length")
return 2
}
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
s.PushNil()
s.PushString("failed to generate random bytes")
return 2
}
s.PushString(string(bytes))
return 1
}
func random_hex(s *luajit.State) int {
length := int(s.ToNumber(1))
if length < 0 || length > 32768 {
s.PushNil()
s.PushString("invalid length")
return 2
}
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
s.PushNil()
s.PushString("failed to generate random bytes")
return 2
}
s.PushString(hex.EncodeToString(bytes))
return 1
}
func random_string(s *luajit.State) int {
length := int(s.ToNumber(1))
if length < 0 || length > 65536 {
s.PushNil()
s.PushString("invalid length")
return 2
}
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
if s.GetTop() >= 2 && !s.IsNil(2) {
charset = s.ToString(2)
}
if len(charset) == 0 {
s.PushNil()
s.PushString("empty charset")
return 2
}
result := make([]byte, length)
charsetLen := big.NewInt(int64(len(charset)))
for i := range result {
n, err := rand.Int(rand.Reader, charsetLen)
if err != nil {
s.PushNil()
s.PushString("failed to generate random number")
return 2
}
result[i] = charset[n.Int64()]
}
s.PushString(string(result))
return 1
}
func secure_compare(s *luajit.State) int {
a := s.ToString(1)
b := s.ToString(2)
s.PushBoolean(hmac.Equal([]byte(a), []byte(b)))
return 1
}
func argon2_hash(s *luajit.State) int {
password := s.ToString(1)
time := uint32(1)
memory := uint32(64 * 1024)
threads := uint8(4)
keyLen := uint32(32)
if s.GetTop() >= 2 && !s.IsNil(2) {
time = uint32(s.ToNumber(2))
}
if s.GetTop() >= 3 && !s.IsNil(3) {
memory = uint32(s.ToNumber(3))
}
if s.GetTop() >= 4 && !s.IsNil(4) {
threads = uint8(s.ToNumber(4))
}
if s.GetTop() >= 5 && !s.IsNil(5) {
keyLen = uint32(s.ToNumber(5))
}
salt := make([]byte, 16)
if _, err := rand.Read(salt); err != nil {
s.PushNil()
s.PushString("failed to generate salt")
return 2
}
hash := argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen)
encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
encodedHash := base64.RawStdEncoding.EncodeToString(hash)
result := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
memory, time, threads, encodedSalt, encodedHash)
s.PushString(result)
return 1
}
func argon2_verify(s *luajit.State) int {
password := s.ToString(1)
hash := s.ToString(2)
parts := strings.Split(hash, "$")
if len(parts) != 6 || parts[1] != "argon2id" {
s.PushBoolean(false)
return 1
}
var memory, time uint32
var threads uint8
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads); err != nil {
s.PushBoolean(false)
return 1
}
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
s.PushBoolean(false)
return 1
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
s.PushBoolean(false)
return 1
}
actualHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
s.PushBoolean(hmac.Equal(actualHash, expectedHash))
return 1
}
func bcrypt_hash(s *luajit.State) int {
password := s.ToString(1)
cost := 12
if s.GetTop() >= 2 && !s.IsNil(2) {
cost = int(s.ToNumber(2))
if cost < 4 || cost > 31 {
s.PushNil()
s.PushString("invalid cost (must be 4-31)")
return 2
}
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), cost)
if err != nil {
s.PushNil()
s.PushString("bcrypt hash failed")
return 2
}
s.PushString(string(hash))
return 1
}
func bcrypt_verify(s *luajit.State) int {
password := s.ToString(1)
hash := s.ToString(2)
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
s.PushBoolean(err == nil)
return 1
}
func scrypt_hash(s *luajit.State) int {
password := s.ToString(1)
N := 32768 // CPU cost
r := 8 // block size
p := 1 // parallelization
keyLen := 32 // key length
if s.GetTop() >= 2 && !s.IsNil(2) {
N = int(s.ToNumber(2))
}
if s.GetTop() >= 3 && !s.IsNil(3) {
r = int(s.ToNumber(3))
}
if s.GetTop() >= 4 && !s.IsNil(4) {
p = int(s.ToNumber(4))
}
if s.GetTop() >= 5 && !s.IsNil(5) {
keyLen = int(s.ToNumber(5))
}
salt := make([]byte, 16)
if _, err := rand.Read(salt); err != nil {
s.PushNil()
s.PushString("failed to generate salt")
return 2
}
hash, err := scrypt.Key([]byte(password), salt, N, r, p, keyLen)
if err != nil {
s.PushNil()
s.PushString("scrypt hash failed")
return 2
}
encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
encodedHash := base64.RawStdEncoding.EncodeToString(hash)
result := fmt.Sprintf("$scrypt$N=%d,r=%d,p=%d$%s$%s", N, r, p, encodedSalt, encodedHash)
s.PushString(result)
return 1
}
func scrypt_verify(s *luajit.State) int {
password := s.ToString(1)
hash := s.ToString(2)
parts := strings.Split(hash, "$")
if len(parts) != 5 || parts[1] != "scrypt" {
s.PushBoolean(false)
return 1
}
var N, r, p int
if _, err := fmt.Sscanf(parts[2], "N=%d,r=%d,p=%d", &N, &r, &p); err != nil {
s.PushBoolean(false)
return 1
}
salt, err := base64.RawStdEncoding.DecodeString(parts[3])
if err != nil {
s.PushBoolean(false)
return 1
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
s.PushBoolean(false)
return 1
}
actualHash, err := scrypt.Key([]byte(password), salt, N, r, p, len(expectedHash))
if err != nil {
s.PushBoolean(false)
return 1
}
s.PushBoolean(hmac.Equal(actualHash, expectedHash))
return 1
}
func pbkdf2_hash(s *luajit.State) int {
password := s.ToString(1)
iterations := 100000
keyLen := 32
if s.GetTop() >= 2 && !s.IsNil(2) {
iterations = int(s.ToNumber(2))
}
if s.GetTop() >= 3 && !s.IsNil(3) {
keyLen = int(s.ToNumber(3))
}
salt := make([]byte, 16)
if _, err := rand.Read(salt); err != nil {
s.PushNil()
s.PushString("failed to generate salt")
return 2
}
hash := pbkdf2.Key([]byte(password), salt, iterations, keyLen, sha256.New)
encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
encodedHash := base64.RawStdEncoding.EncodeToString(hash)
result := fmt.Sprintf("$pbkdf2-sha256$i=%d$%s$%s", iterations, encodedSalt, encodedHash)
s.PushString(result)
return 1
}
func pbkdf2_verify(s *luajit.State) int {
password := s.ToString(1)
hash := s.ToString(2)
parts := strings.Split(hash, "$")
if len(parts) != 5 || parts[1] != "pbkdf2-sha256" {
s.PushBoolean(false)
return 1
}
var iterations int
if _, err := fmt.Sscanf(parts[2], "i=%d", &iterations); err != nil {
s.PushBoolean(false)
return 1
}
salt, err := base64.RawStdEncoding.DecodeString(parts[3])
if err != nil {
s.PushBoolean(false)
return 1
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
s.PushBoolean(false)
return 1
}
actualHash := pbkdf2.Key([]byte(password), salt, iterations, len(expectedHash), sha256.New)
s.PushBoolean(hmac.Equal(actualHash, expectedHash))
return 1
}
func password_hash(s *luajit.State) int {
password := s.ToString(1)
algorithm := "argon2id" // default
if s.GetTop() >= 2 && !s.IsNil(2) {
algorithm = s.ToString(2)
}
switch algorithm {
case "argon2id":
s.PushString(password)
return argon2_hash(s)
case "bcrypt":
s.PushString(password)
if s.GetTop() >= 3 {
s.PushNumber(s.ToNumber(3))
}
return bcrypt_hash(s)
case "scrypt":
s.PushString(password)
return scrypt_hash(s)
case "pbkdf2":
s.PushString(password)
return pbkdf2_hash(s)
default:
s.PushNil()
s.PushString("unsupported algorithm: " + algorithm)
return 2
}
}
func password_verify(s *luajit.State) int {
hash := s.ToString(2)
// Auto-detect algorithm from hash format
if strings.HasPrefix(hash, "$argon2id$") {
return argon2_verify(s)
} else if strings.HasPrefix(hash, "$2a$") || strings.HasPrefix(hash, "$2b$") || strings.HasPrefix(hash, "$2y$") {
return bcrypt_verify(s)
} else if strings.HasPrefix(hash, "$scrypt$") {
return scrypt_verify(s)
} else if strings.HasPrefix(hash, "$pbkdf2-sha256$") {
return pbkdf2_verify(s)
}
s.PushBoolean(false)
return 1
}

530
modules/crypto/crypto.lua Normal file
View File

@ -0,0 +1,530 @@
local crypto = {}
-- ======================================================================
-- ENCODING / DECODING
-- ======================================================================
function crypto.base64_encode(data)
local result, err = moonshark.base64_encode(data)
if not result then
error(err)
end
return result
end
function crypto.base64_decode(data)
local result, err = moonshark.base64_decode(data)
if not result then
error(err)
end
return result
end
function crypto.base64_url_encode(data)
local result, err = moonshark.base64_url_encode(data)
if not result then
error(err)
end
return result
end
function crypto.base64_url_decode(data)
local result, err = moonshark.base64_url_decode(data)
if not result then
error(err)
end
return result
end
function crypto.hex_encode(data)
local result, err = moonshark.hex_encode(data)
if not result then
error(err)
end
return result
end
function crypto.hex_decode(data)
local result, err = moonshark.hex_decode(data)
if not result then
error(err)
end
return result
end
-- ======================================================================
-- HASHING FUNCTIONS
-- ======================================================================
function crypto.md5(data)
local result, err = moonshark.md5_hash(data)
if not result then
error(err)
end
return result
end
function crypto.sha1(data)
local result, err = moonshark.sha1_hash(data)
if not result then
error(err)
end
return result
end
function crypto.sha256(data)
local result, err = moonshark.sha256_hash(data)
if not result then
error(err)
end
return result
end
function crypto.sha512(data)
local result, err = moonshark.sha512_hash(data)
if not result then
error(err)
end
return result
end
-- Hash file contents
function crypto.hash_file(path, algorithm)
algorithm = algorithm or "sha256"
if not moonshark.file_exists(path) then
error("File not found: " .. path)
end
local content = moonshark.file_read(path)
if not content then
error("Failed to read file: " .. path)
end
if algorithm == "md5" then
return crypto.md5(content)
elseif algorithm == "sha1" then
return crypto.sha1(content)
elseif algorithm == "sha256" then
return crypto.sha256(content)
elseif algorithm == "sha512" then
return crypto.sha512(content)
else
error("Unsupported hash algorithm: " .. algorithm)
end
end
-- ======================================================================
-- HMAC FUNCTIONS
-- ======================================================================
function crypto.hmac_sha1(message, key)
local result, err = moonshark.hmac_sha1(message, key)
if not result then
error(err)
end
return result
end
function crypto.hmac_sha256(message, key)
local result, err = moonshark.hmac_sha256(message, key)
if not result then
error(err)
end
return result
end
-- ======================================================================
-- UUID FUNCTIONS
-- ======================================================================
function crypto.uuid()
return moonshark.uuid_generate()
end
function crypto.uuid_v4()
return moonshark.uuid_generate_v4()
end
function crypto.is_uuid(str)
return moonshark.uuid_validate(str)
end
-- ======================================================================
-- RANDOM GENERATORS
-- ======================================================================
function crypto.random_bytes(length)
local result, err = moonshark.random_bytes(length)
if not result then
error(err)
end
return result
end
function crypto.random_hex(length)
local result, err = moonshark.random_hex(length)
if not result then
error(err)
end
return result
end
function crypto.random_string(length, charset)
local result, err = moonshark.random_string(length, charset)
if not result then
error(err)
end
return result
end
-- Generate random alphanumeric string
function crypto.random_alphanumeric(length)
return crypto.random_string(length, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
end
-- Generate random password with mixed characters
function crypto.random_password(length, include_symbols)
length = length or 12
local charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
if include_symbols then
charset = charset .. "!@#$%^&*()_+-=[]{}|;:,.<>?"
end
return crypto.random_string(length, charset)
end
-- Generate cryptographically secure token
function crypto.token(length)
length = length or 32
return crypto.random_hex(length)
end
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
function crypto.secure_compare(a, b)
return moonshark.secure_compare(a, b)
end
-- Generate checksum for data integrity
function crypto.checksum(data, algorithm)
algorithm = algorithm or "sha256"
return crypto.hash_file and crypto[algorithm] and crypto[algorithm](data) or error("Invalid algorithm")
end
-- Verify data against checksum
function crypto.verify_checksum(data, expected, algorithm)
algorithm = algorithm or "sha256"
local actual = crypto[algorithm](data)
return crypto.secure_compare(actual, expected)
end
-- Simple encryption using XOR (not cryptographically secure)
function crypto.xor_encrypt(data, key)
local result = {}
local key_len = #key
for i = 1, #data do
local data_byte = string.byte(data, i)
local key_byte = string.byte(key, ((i - 1) % key_len) + 1)
table.insert(result, string.char(bit32 and bit32.bxor(data_byte, key_byte) or bit.bxor(data_byte, key_byte)))
end
return table.concat(result)
end
-- XOR decryption (same as encryption)
function crypto.xor_decrypt(data, key)
return crypto.xor_encrypt(data, key)
end
-- Generate hash chain for proof of work
function crypto.hash_chain(data, iterations, algorithm)
iterations = iterations or 1000
algorithm = algorithm or "sha256"
local result = data
for i = 1, iterations do
result = crypto[algorithm](result)
end
return result
end
-- Key derivation using PBKDF2-like approach (simplified)
function crypto.derive_key(password, salt, iterations, algorithm)
iterations = iterations or 10000
algorithm = algorithm or "sha256"
salt = salt or crypto.random_hex(16)
local derived = password .. salt
for i = 1, iterations do
derived = crypto[algorithm](derived)
end
return derived, salt
end
-- Generate nonce (number used once)
function crypto.nonce(length)
length = length or 16
return crypto.random_hex(length)
end
-- Create message authentication code
function crypto.mac(message, key, algorithm)
algorithm = algorithm or "sha256"
return crypto["hmac_" .. algorithm](message, key)
end
-- Verify message authentication code
function crypto.verify_mac(message, key, mac, algorithm)
algorithm = algorithm or "sha256"
local expected = crypto.mac(message, key, algorithm)
return crypto.secure_compare(expected, mac)
end
-- ======================================================================
-- CONVENIENCE FUNCTIONS
-- ======================================================================
-- One-shot encoding chain
function crypto.encode_chain(data, formats)
formats = formats or {"base64"}
local result = data
for _, format in ipairs(formats) do
if format == "base64" then
result = crypto.base64_encode(result)
elseif format == "base64url" then
result = crypto.base64_url_encode(result)
elseif format == "hex" then
result = crypto.hex_encode(result)
else
error("Unknown encoding format: " .. format)
end
end
return result
end
-- One-shot decoding chain (reverse order)
function crypto.decode_chain(data, formats)
formats = formats or {"base64"}
local result = data
-- Reverse the formats for decoding
for i = #formats, 1, -1 do
local format = formats[i]
if format == "base64" then
result = crypto.base64_decode(result)
elseif format == "base64url" then
result = crypto.base64_url_decode(result)
elseif format == "hex" then
result = crypto.hex_decode(result)
else
error("Unknown decoding format: " .. format)
end
end
return result
end
-- Hash multiple inputs
function crypto.hash_multiple(inputs, algorithm)
algorithm = algorithm or "sha256"
local combined = table.concat(inputs, "")
return crypto[algorithm](combined)
end
-- Create fingerprint from table data
function crypto.fingerprint(data, algorithm)
algorithm = algorithm or "sha256"
return crypto[algorithm](json.encode(data))
end
-- Simple data integrity check
function crypto.integrity_check(data)
return {
data = data,
hash = crypto.sha256(data),
timestamp = os.time(),
uuid = crypto.uuid()
}
end
-- Verify integrity check
function crypto.verify_integrity(check)
if not check.data or not check.hash then
return false
end
local expected = crypto.sha256(check.data)
return crypto.secure_compare(expected, check.hash)
end
-- ======================================================================
-- PASSWORD HASHING
-- ======================================================================
-- Generic password hashing (defaults to argon2id)
function crypto.hash_password(password, algorithm, options)
algorithm = algorithm or "argon2id"
options = options or {}
local result, err
if algorithm == "argon2id" then
local time = options.time or 1
local memory = options.memory or 65536 -- 64MB in KB
local threads = options.threads or 4
local keylen = options.keylen or 32
result, err = moonshark.argon2_hash(password, time, memory, threads, keylen)
elseif algorithm == "bcrypt" then
local cost = options.cost or 12
result, err = moonshark.bcrypt_hash(password, cost)
elseif algorithm == "scrypt" then
local N = options.N or 32768
local r = options.r or 8
local p = options.p or 1
local keylen = options.keylen or 32
result, err = moonshark.scrypt_hash(password, N, r, p, keylen)
elseif algorithm == "pbkdf2" then
local iterations = options.iterations or 100000
local keylen = options.keylen or 32
result, err = moonshark.pbkdf2_hash(password, iterations, keylen)
else
error("unsupported algorithm: " .. algorithm)
end
if not result then
error(err)
end
return result
end
-- Generic password verification (auto-detects algorithm)
function crypto.verify_password(password, hash)
return moonshark.password_verify(password, hash)
end
-- ======================================================================
-- ALGORITHM-SPECIFIC FUNCTIONS
-- ======================================================================
-- Argon2id hashing
function crypto.argon2_hash(password, options)
options = options or {}
local time = options.time or 1
local memory = options.memory or 65536
local threads = options.threads or 4
local keylen = options.keylen or 32
local result, err = moonshark.argon2_hash(password, time, memory, threads, keylen)
if not result then error(err) end
return result
end
function crypto.argon2_verify(password, hash)
return moonshark.argon2_verify(password, hash)
end
-- bcrypt hashing
function crypto.bcrypt_hash(password, cost)
cost = cost or 12
local result, err = moonshark.bcrypt_hash(password, cost)
if not result then error(err) end
return result
end
function crypto.bcrypt_verify(password, hash)
return moonshark.bcrypt_verify(password, hash)
end
-- scrypt hashing
function crypto.scrypt_hash(password, options)
options = options or {}
local N = options.N or 32768
local r = options.r or 8
local p = options.p or 1
local keylen = options.keylen or 32
local result, err = moonshark.scrypt_hash(password, N, r, p, keylen)
if not result then error(err) end
return result
end
function crypto.scrypt_verify(password, hash)
return moonshark.scrypt_verify(password, hash)
end
-- PBKDF2 hashing
function crypto.pbkdf2_hash(password, iterations, keylen)
iterations = iterations or 100000
keylen = keylen or 32
local result, err = moonshark.pbkdf2_hash(password, iterations, keylen)
if not result then error(err) end
return result
end
function crypto.pbkdf2_verify(password, hash)
return moonshark.pbkdf2_verify(password, hash)
end
-- ======================================================================
-- PASSWORD CONFIG PRESETS
-- ======================================================================
function crypto.hash_password_fast(password, algorithm)
algorithm = algorithm or "argon2id"
local options = {
argon2id = { time = 1, memory = 8192, threads = 1 },
bcrypt = { cost = 10 },
scrypt = { N = 16384, r = 8, p = 1 },
pbkdf2 = { iterations = 50000 }
}
return crypto.hash_password(password, algorithm, options[algorithm])
end
function crypto.hash_password_strong(password, algorithm)
algorithm = algorithm or "argon2id"
local options = {
argon2id = { time = 3, memory = 131072, threads = 4 },
bcrypt = { cost = 14 },
scrypt = { N = 65536, r = 8, p = 2 },
pbkdf2 = { iterations = 200000 }
}
return crypto.hash_password(password, algorithm, options[algorithm])
end
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
-- Detect algorithm from hash
function crypto.detect_algorithm(hash)
if hash:match("^%$argon2id%$") then
return "argon2id"
elseif hash:match("^%$2[aby]%$") then
return "bcrypt"
elseif hash:match("^%$scrypt%$") then
return "scrypt"
elseif hash:match("^%$pbkdf2%-sha256%$") then
return "pbkdf2"
else
return "unknown"
end
end
return crypto

37
modules/fs/fs.go Normal file
View File

@ -0,0 +1,37 @@
package fs
import (
"os"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func GetFunctionList() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"getcwd": getcwd,
"chdir": chdir,
}
}
func getcwd(s *luajit.State) int {
cwd, err := os.Getwd()
if err != nil {
s.PushNil()
s.PushString(err.Error())
return 2
}
s.PushString(cwd)
return 1
}
func chdir(s *luajit.State) int {
path := s.ToString(1)
err := os.Chdir(path)
if err != nil {
s.PushBoolean(false)
s.PushString(err.Error())
return 2
}
s.PushBoolean(true)
return 1
}

723
modules/fs/fs.lua Normal file
View File

@ -0,0 +1,723 @@
local fs = {}
local is_windows = package.config:sub(1,1) == '\\'
local path_sep = is_windows and '\\' or '/'
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
local function shell_escape(str)
if is_windows then
-- Windows: escape quotes and wrap in quotes
return '"' .. str:gsub('"', '""') .. '"'
else
-- Unix: escape shell metacharacters
return "'" .. str:gsub("'", "'\"'\"'") .. "'"
end
end
-- ======================================================================
-- FILE OPERATIONS
-- ======================================================================
function fs.exists(path)
local file = io.open(path, "r")
if file then
file:close()
return true
end
return false
end
function fs.size(path)
local file = io.open(path, "r")
if not file then return nil end
local size = file:seek("end")
file:close()
return size
end
function fs.is_dir(path)
if not fs.exists(path) then return false end
local cmd
if is_windows then
cmd = 'dir ' .. shell_escape(path) .. ' >nul 2>&1 && echo dir'
else
cmd = 'test -d ' .. shell_escape(path) .. ' && echo dir'
end
local handle = io.popen(cmd)
if handle then
local result = handle:read("*l")
handle:close()
return result == "dir"
end
return false
end
function fs.is_file(path)
return fs.exists(path) and not fs.is_dir(path)
end
function fs.read(path)
local file = io.open(path, "r")
if not file then
error("Failed to read file '" .. path .. "': file not found or permission denied")
end
local content = file:read("*all")
file:close()
if not content then
error("Failed to read file '" .. path .. "': read error")
end
return content
end
function fs.write(path, content)
if path == "" or path:find("\0") then
error("Failed to write file '" .. path .. "': invalid path")
end
local file = io.open(path, "w")
if not file then
error("Failed to write file '" .. path .. "': permission denied or invalid path")
end
local success = file:write(content)
file:close()
if not success then
error("Failed to write file '" .. path .. "': write error")
end
return true
end
function fs.append(path, content)
local file = io.open(path, "a")
if not file then
error("Failed to append to file '" .. path .. "': permission denied or invalid path")
end
local success = file:write(content)
file:close()
if not success then
error("Failed to append to file '" .. path .. "': write error")
end
return true
end
function fs.copy(src, dst)
local src_file = io.open(src, "rb")
if not src_file then
error("Failed to copy '" .. src .. "' to '" .. dst .. "': source file not found")
end
local dst_file = io.open(dst, "wb")
if not dst_file then
src_file:close()
error("Failed to copy '" .. src .. "' to '" .. dst .. "': cannot create destination")
end
local chunk_size = 8192
while true do
local chunk = src_file:read(chunk_size)
if not chunk then break end
if not dst_file:write(chunk) then
src_file:close()
dst_file:close()
error("Failed to copy '" .. src .. "' to '" .. dst .. "': write error")
end
end
src_file:close()
dst_file:close()
return true
end
function fs.move(src, dst)
local success = os.rename(src, dst)
if success then return true end
-- Fallback to copy + delete
fs.copy(src, dst)
fs.remove(src)
return true
end
function fs.remove(path)
local success = os.remove(path)
if not success then
error("Failed to remove '" .. path .. "': file not found or permission denied")
end
return true
end
function fs.mtime(path)
local cmd
if is_windows then
cmd = 'forfiles /p . /m ' .. shell_escape(fs.basename(path)) .. ' /c "cmd /c echo @fdate @ftime" 2>nul'
local handle = io.popen(cmd)
if handle then
local result = handle:read("*line")
handle:close()
if result then
-- Basic Windows timestamp parsing - fallback to file existence check
return fs.exists(path) and os.time() or nil
end
end
else
cmd = 'stat -c %Y ' .. shell_escape(path) .. ' 2>/dev/null'
local handle = io.popen(cmd)
if handle then
local result = handle:read("*line")
handle:close()
if result and result:match("^%d+$") then
return tonumber(result)
end
end
end
return nil
end
function fs.touch(path)
if fs.exists(path) then
-- Update timestamp
local content = fs.read(path)
fs.write(path, content)
else
-- Create empty file
fs.write(path, "")
end
return true
end
function fs.lines(path)
local lines = {}
local file = io.open(path, "r")
if not file then
error("Failed to read lines from '" .. path .. "': file not found")
end
for line in file:lines() do
lines[#lines + 1] = line
end
file:close()
return lines
end
-- ======================================================================
-- DIRECTORY OPERATIONS
-- ======================================================================
function fs.mkdir(path)
local cmd
if is_windows then
cmd = 'mkdir ' .. shell_escape(path) .. ' >nul 2>&1'
else
cmd = 'mkdir -p ' .. shell_escape(path)
end
local result = os.execute(cmd)
-- Handle different Lua version return values
local success = (result == 0) or (result == true)
if not success then
error("Failed to create directory '" .. path .. "'")
end
return true
end
function fs.rmdir(path)
local cmd
if is_windows then
cmd = 'rmdir /s /q ' .. shell_escape(path) .. ' >nul 2>&1'
else
cmd = 'rm -rf ' .. shell_escape(path)
end
local result = os.execute(cmd)
local success = (result == 0) or (result == true)
if not success then
error("Failed to remove directory '" .. path .. "'")
end
return true
end
function fs.list(path)
if not fs.exists(path) then
error("Failed to list directory '" .. path .. "': directory not found")
end
if not fs.is_dir(path) then
error("Failed to list directory '" .. path .. "': not a directory")
end
local entries = {}
local cmd
if is_windows then
cmd = 'dir /b ' .. shell_escape(path) .. ' 2>nul'
else
cmd = 'ls -1 ' .. shell_escape(path) .. ' 2>/dev/null'
end
local handle = io.popen(cmd)
if not handle then
error("Failed to list directory '" .. path .. "': command failed")
end
for name in handle:lines() do
local full_path = fs.join(path, name)
entries[#entries + 1] = {
name = name,
is_dir = fs.is_dir(full_path),
size = fs.is_file(full_path) and fs.size(full_path) or 0,
mtime = fs.mtime(full_path)
}
end
handle:close()
return entries
end
function fs.list_files(path)
local entries = fs.list(path)
local files = {}
for _, entry in ipairs(entries) do
if not entry.is_dir then
files[#files + 1] = entry
end
end
return files
end
function fs.list_dirs(path)
local entries = fs.list(path)
local dirs = {}
for _, entry in ipairs(entries) do
if entry.is_dir then
dirs[#dirs + 1] = entry
end
end
return dirs
end
function fs.list_names(path)
local entries = fs.list(path)
local names = {}
for _, entry in ipairs(entries) do
names[#names + 1] = entry.name
end
return names
end
-- ======================================================================
-- PATH OPERATIONS
-- ======================================================================
function fs.join(...)
local parts = {...}
if #parts == 0 then return "" end
if #parts == 1 then return parts[1] or "" end
local result = parts[1] or ""
for i = 2, #parts do
local part = parts[i]
if part and part ~= "" then
if result == "" then
result = part
elseif result:sub(-1) == path_sep then
result = result .. part
else
result = result .. path_sep .. part
end
end
end
return result
end
function fs.dirname(path)
if path == "" then return "." end
-- Remove trailing separators
path = path:gsub("[/\\]+$", "")
local pos = path:find("[/\\][^/\\]*$")
if pos then
local dir = path:sub(1, pos - 1)
return dir ~= "" and dir or path_sep
end
return "."
end
function fs.basename(path)
if path == "" then return "" end
-- Remove trailing separators
path = path:gsub("[/\\]+$", "")
return path:match("[^/\\]*$") or ""
end
function fs.ext(path)
local base = fs.basename(path)
local dot_pos = base:find("%.[^%.]*$")
return dot_pos and base:sub(dot_pos) or ""
end
function fs.abs(path)
if path == "" then path = "." end
if is_windows then
if path:match("^[A-Za-z]:") or path:match("^\\\\") then
return fs.clean(path)
end
else
if path:sub(1, 1) == "/" then
return fs.clean(path)
end
end
local cwd = fs.getcwd()
return fs.clean(fs.join(cwd, path))
end
function fs.clean(path)
if path == "" then return "." end
-- Normalize path separators
path = path:gsub("[/\\]+", path_sep)
-- Track if path was absolute
local is_absolute = false
if is_windows then
is_absolute = path:match("^[A-Za-z]:") or path:match("^\\\\")
else
is_absolute = path:sub(1, 1) == "/"
end
-- Split into components
local parts = {}
for part in path:gmatch("[^" .. path_sep:gsub("\\", "\\\\") .. "]+") do
if part == ".." then
if #parts > 0 and parts[#parts] ~= ".." then
if not is_absolute or #parts > 1 then
parts[#parts] = nil
end
elseif not is_absolute then
parts[#parts + 1] = ".."
end
elseif part ~= "." and part ~= "" then
parts[#parts + 1] = part
end
end
local result = table.concat(parts, path_sep)
if is_absolute then
if is_windows and path:match("^[A-Za-z]:") then
-- Windows drive letter
local drive = path:match("^[A-Za-z]:")
result = drive .. path_sep .. result
elseif is_windows and path:match("^\\\\") then
-- UNC path
result = "\\\\" .. result
else
-- Unix absolute
result = path_sep .. result
end
end
return result ~= "" and result or (is_absolute and path_sep or ".")
end
function fs.split(path)
local pos = path:find("[/\\][^/\\]*$")
if pos then
return path:sub(1, pos), path:sub(pos + 1)
end
return "", path
end
function fs.splitext(path)
local dir = fs.dirname(path)
local base = fs.basename(path)
local ext = fs.ext(path)
local name = base
if ext ~= "" then
name = base:sub(1, -(#ext + 1))
end
return dir, name, ext
end
-- ======================================================================
-- WORKING DIRECTORY
-- ======================================================================
function fs.getcwd()
local cwd, err = moonshark.getcwd()
if not cwd then
error("Failed to get current directory: " .. (err or "unknown error"))
end
return cwd
end
function fs.chdir(path)
local success, err = moonshark.chdir(path)
if not success then
error("Failed to change directory to '" .. path .. "': " .. (err or "unknown error"))
end
return true
end
-- ======================================================================
-- TEMPORARY FILES
-- ======================================================================
function fs.tempfile(prefix)
prefix = prefix or "tmp"
local temp_name = prefix .. "_" .. os.time() .. "_" .. math.random(10000)
local temp_path
if is_windows then
local temp_dir = os.getenv("TEMP") or os.getenv("TMP") or "C:\\temp"
temp_path = fs.join(temp_dir, temp_name)
else
temp_path = fs.join("/tmp", temp_name)
end
fs.write(temp_path, "")
return temp_path
end
function fs.tempdir(prefix)
prefix = prefix or "tmp"
local temp_name = prefix .. "_" .. os.time() .. "_" .. math.random(10000)
local temp_path
if is_windows then
local temp_dir = os.getenv("TEMP") or os.getenv("TMP") or "C:\\temp"
temp_path = fs.join(temp_dir, temp_name)
else
temp_path = fs.join("/tmp", temp_name)
end
fs.mkdir(temp_path)
return temp_path
end
-- ======================================================================
-- PATTERN MATCHING
-- ======================================================================
function fs.glob(pattern)
-- Simple validation to prevent obvious shell injection
if pattern:find("[;&|`$(){}]") then
return {}
end
local cmd
if is_windows then
cmd = 'for %f in (' .. pattern .. ') do @echo %f 2>nul'
else
cmd = 'ls -1d ' .. pattern .. ' 2>/dev/null'
end
local matches = {}
local handle = io.popen(cmd)
if handle then
for match in handle:lines() do
matches[#matches + 1] = match
end
handle:close()
end
return matches
end
function fs.walk(root)
local files = {}
local function walk_recursive(path)
if not fs.exists(path) then return end
files[#files + 1] = path
if fs.is_dir(path) then
local success, entries = pcall(fs.list, path)
if success then
for _, entry in ipairs(entries) do
walk_recursive(fs.join(path, entry.name))
end
end
end
end
walk_recursive(root)
return files
end
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
function fs.extension(path)
local ext = fs.ext(path)
return ext:sub(2) -- Remove leading dot
end
function fs.change_ext(path, new_ext)
local dir, name, _ = fs.splitext(path)
if not new_ext:match("^%.") then
new_ext = "." .. new_ext
end
if dir == "." then
return name .. new_ext
end
return fs.join(dir, name .. new_ext)
end
function fs.ensure_dir(path)
if not fs.exists(path) then
fs.mkdir(path)
elseif not fs.is_dir(path) then
error("Path exists but is not a directory: " .. path)
end
return true
end
function fs.size_human(path)
local size = fs.size(path)
if not size then return nil end
local units = {"B", "KB", "MB", "GB", "TB"}
local unit_index = 1
local size_float = size
while size_float >= 1024 and unit_index < #units do
size_float = size_float / 1024
unit_index = unit_index + 1
end
if unit_index == 1 then
return string.format("%d %s", size_float, units[unit_index])
else
return string.format("%.1f %s", size_float, units[unit_index])
end
end
function fs.is_safe_path(path)
path = fs.clean(path)
if path:match("%.%.") then return false end
if path:match("^[/\\]") then return false end
if path:match("^~") then return false end
return true
end
-- Aliases for convenience
fs.makedirs = fs.mkdir
fs.removedirs = fs.rmdir
function fs.copytree(src, dst)
if not fs.exists(src) then
error("Source directory does not exist: " .. src)
end
if not fs.is_dir(src) then
error("Source is not a directory: " .. src)
end
fs.mkdir(dst)
local entries = fs.list(src)
for _, entry in ipairs(entries) do
local src_path = fs.join(src, entry.name)
local dst_path = fs.join(dst, entry.name)
if entry.is_dir then
fs.copytree(src_path, dst_path)
else
fs.copy(src_path, dst_path)
end
end
return true
end
function fs.find(root, pattern, recursive)
recursive = recursive ~= false
local results = {}
local function search(dir)
local success, entries = pcall(fs.list, dir)
if not success then return end
for _, entry in ipairs(entries) do
local full_path = fs.join(dir, entry.name)
if not entry.is_dir and entry.name:match(pattern) then
results[#results + 1] = full_path
elseif entry.is_dir and recursive then
search(full_path)
end
end
end
search(root)
return results
end
function fs.tree(root, max_depth)
max_depth = max_depth or 10
local function build_tree(path, depth)
if depth > max_depth then return nil end
if not fs.exists(path) then return nil end
local node = {
name = fs.basename(path),
path = path,
is_dir = fs.is_dir(path)
}
if node.is_dir then
node.children = {}
local success, entries = pcall(fs.list, path)
if success then
for _, entry in ipairs(entries) do
local child_path = fs.join(path, entry.name)
local child = build_tree(child_path, depth + 1)
if child then
node.children[#node.children + 1] = child
end
end
end
else
node.size = fs.size(path)
node.mtime = fs.mtime(path)
end
return node
end
return build_tree(root, 1)
end
return fs

358
modules/http/http.go Normal file
View File

@ -0,0 +1,358 @@
package http
import (
"Moonshark/metadata"
"context"
"fmt"
"net"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/valyala/fasthttp"
)
var (
globalServer *fasthttp.Server
globalWorkerPool *WorkerPool
globalStateCreator StateCreator
globalMu sync.RWMutex
serverRunning bool
staticHandlers = make(map[string]*fasthttp.FS)
staticMu sync.RWMutex
)
func SetStateCreator(creator StateCreator) {
globalStateCreator = creator
}
func GetFunctionList() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"http_create_server": http_create_server,
"http_spawn_workers": http_spawn_workers,
"http_listen": http_listen,
"http_close_server": http_close_server,
"http_has_servers": http_has_servers,
"http_register_static": http_register_static,
}
}
func http_create_server(s *luajit.State) int {
globalMu.Lock()
defer globalMu.Unlock()
if globalServer != nil {
s.PushBoolean(true) // Already created
return 1
}
globalServer = &fasthttp.Server{
Name: "Moonshark/" + metadata.Version,
Handler: handleRequest,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
}
s.PushBoolean(true)
return 1
}
func http_spawn_workers(s *luajit.State) int {
globalMu.Lock()
defer globalMu.Unlock()
if globalWorkerPool != nil {
s.PushBoolean(true) // Already spawned
return 1
}
if globalStateCreator == nil {
s.PushBoolean(false)
s.PushString("state creator not set")
return 2
}
workerCount := max(runtime.NumCPU(), 2)
pool, err := NewWorkerPool(workerCount, s, globalStateCreator)
if err != nil {
s.PushBoolean(false)
s.PushString(fmt.Sprintf("failed to create worker pool: %v", err))
return 2
}
globalWorkerPool = pool
s.PushBoolean(true)
return 1
}
func http_listen(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("http_listen: %v", err)
}
addr := s.ToString(1)
globalMu.RLock()
server := globalServer
globalMu.RUnlock()
if server == nil {
s.PushBoolean(false)
s.PushString("no server created")
return 2
}
globalMu.Lock()
if serverRunning {
globalMu.Unlock()
s.PushBoolean(true) // Already running
return 1
}
serverRunning = true
globalMu.Unlock()
go func() {
if err := server.ListenAndServe(addr); err != nil {
fmt.Printf("HTTP server error: %v\n", err)
}
}()
time.Sleep(100 * time.Millisecond)
conn, err := net.Dial("tcp", addr)
if err != nil {
globalMu.Lock()
serverRunning = false
globalMu.Unlock()
s.PushBoolean(false)
s.PushString(fmt.Sprintf("failed to start server: %v", err))
return 2
}
conn.Close()
s.PushBoolean(true)
return 1
}
func http_close_server(s *luajit.State) int {
StopAllServers()
s.PushBoolean(true)
return 1
}
func http_has_servers(s *luajit.State) int {
globalMu.RLock()
running := serverRunning
globalMu.RUnlock()
s.PushBoolean(running)
return 1
}
func http_register_static(s *luajit.State) int {
if err := s.CheckMinArgs(2); err != nil {
return s.PushError("http_register_static: %v", err)
}
urlPrefix := s.ToString(1)
rootPath := s.ToString(2)
noCache := s.ToBoolean(3)
// Ensure prefix starts with /
if !strings.HasPrefix(urlPrefix, "/") {
urlPrefix = "/" + urlPrefix
}
// Convert to absolute path
absPath, err := filepath.Abs(rootPath)
if err != nil {
s.PushBoolean(false)
s.PushString(fmt.Sprintf("invalid path: %v", err))
return 2
}
RegisterStaticHandler(urlPrefix, absPath, noCache)
s.PushBoolean(true)
return 1
}
func HasActiveServers() bool {
globalMu.RLock()
defer globalMu.RUnlock()
return serverRunning
}
func WaitForServers() {
for HasActiveServers() {
time.Sleep(100 * time.Millisecond)
}
}
func handleRequest(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
// Fast path for likely static files (has extension)
if isLikelyStaticFile(path) && tryStaticHandler(ctx, path) {
return
}
// Try Lua routing
globalMu.RLock()
pool := globalWorkerPool
globalMu.RUnlock()
if pool != nil {
worker := pool.Get()
if worker != nil {
defer pool.Put(worker)
req := GetRequest()
defer PutRequest(req)
resp := GetResponse()
defer PutResponse(resp)
// Populate request
req.Method = string(ctx.Method())
req.Path = path
req.Body = string(ctx.Request.Body())
for key, value := range ctx.QueryArgs().All() {
req.Query[string(key)] = string(value)
}
for key, value := range ctx.Request.Header.All() {
req.Headers[string(key)] = string(value)
}
err := worker.HandleRequest(req, resp)
if err != nil {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetBodyString(fmt.Sprintf("Internal Server Error: %v", err))
return
}
// If Lua found a route, use it
if resp.StatusCode != 404 {
ctx.SetStatusCode(resp.StatusCode)
for key, value := range resp.Headers {
ctx.Response.Header.Set(key, value)
}
if resp.Body != "" {
ctx.SetBodyString(resp.Body)
}
return
}
}
}
if !isLikelyStaticFile(path) && tryStaticHandler(ctx, path) {
return
}
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetBodyString("Not Found")
}
func tryStaticHandler(ctx *fasthttp.RequestCtx, path string) bool {
staticMu.RLock()
defer staticMu.RUnlock()
for prefix, fs := range staticHandlers {
if after, ok := strings.CutPrefix(path, prefix); ok {
ctx.Request.URI().SetPath(after)
fs.NewRequestHandler()(ctx)
return true
}
}
return false
}
func isLikelyStaticFile(path string) bool {
// Check for file extension
lastSlash := strings.LastIndex(path, "/")
lastDot := strings.LastIndex(path, ".")
return lastDot > lastSlash && lastDot != -1
}
// RegisterStaticHandler adds a static file handler
func RegisterStaticHandler(urlPrefix, rootPath string, noCache bool) {
staticMu.Lock()
defer staticMu.Unlock()
var cacheDuration time.Duration
var compress bool
if noCache {
cacheDuration = 0
compress = false
} else {
cacheDuration = 3600 * time.Second
compress = true
}
fs := &fasthttp.FS{
Root: rootPath,
CompressRoot: rootPath + "/.cache",
IndexNames: []string{"index.html"},
GenerateIndexPages: false,
Compress: compress,
CompressBrotli: compress,
CompressZstd: compress,
CacheDuration: cacheDuration,
AcceptByteRange: true,
PathNotFound: func(ctx *fasthttp.RequestCtx) {
path := ctx.Path()
fmt.Printf("404 not found: %s\n", path)
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetBodyString("404 not found")
},
}
staticHandlers[urlPrefix] = fs
}
func StopAllServers() {
globalMu.Lock()
defer globalMu.Unlock()
// Start shutting down both in parallel
var wg sync.WaitGroup
// Close worker pool in parallel
if globalWorkerPool != nil {
wg.Add(1)
pool := globalWorkerPool
globalWorkerPool = nil
go func() {
defer wg.Done()
pool.Close()
}()
}
// Shutdown server with 100ms timeout
if globalServer != nil {
wg.Add(1)
server := globalServer
globalServer = nil
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := server.ShutdownWithContext(ctx); err != nil {
// Force close if graceful shutdown times out
server.CloseOnShutdown = true
_ = server.Shutdown()
}
}()
}
serverRunning = false
// Wait for both to complete
wg.Wait()
}

1241
modules/http/http.lua Normal file

File diff suppressed because it is too large Load Diff

346
modules/http/pool.go Normal file
View File

@ -0,0 +1,346 @@
package http
import (
"fmt"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
type StateCreator func() (*luajit.State, error)
type Request struct {
Method string
Path string
Query map[string]string
Headers map[string]string
Body string
}
type Response struct {
StatusCode int
Headers map[string]string
Body string
}
type Worker struct {
state *luajit.State
id int
}
type WorkerPool struct {
workers chan *Worker
masterState *luajit.State
stateCreator StateCreator
workerCount int
closed bool
mu sync.RWMutex
}
var (
requestPool = sync.Pool{
New: func() any {
return &Request{
Query: make(map[string]string),
Headers: make(map[string]string),
}
},
}
responsePool = sync.Pool{
New: func() any {
return &Response{
Headers: make(map[string]string),
}
},
}
)
func GetRequest() *Request {
req := requestPool.Get().(*Request)
for k := range req.Query {
delete(req.Query, k)
}
for k := range req.Headers {
delete(req.Headers, k)
}
req.Method = ""
req.Path = ""
req.Body = ""
return req
}
func PutRequest(req *Request) {
requestPool.Put(req)
}
func GetResponse() *Response {
resp := responsePool.Get().(*Response)
for k := range resp.Headers {
delete(resp.Headers, k)
}
resp.StatusCode = 200
resp.Body = ""
return resp
}
func PutResponse(resp *Response) {
responsePool.Put(resp)
}
func NewWorkerPool(size int, masterState *luajit.State, stateCreator StateCreator) (*WorkerPool, error) {
pool := &WorkerPool{
workers: make(chan *Worker, size),
masterState: masterState,
stateCreator: stateCreator,
workerCount: size,
}
for i := range size {
worker, err := pool.createWorker(i)
if err != nil {
pool.Close()
return nil, fmt.Errorf("failed to create worker %d: %w", i, err)
}
pool.workers <- worker
}
return pool, nil
}
func (p *WorkerPool) createWorker(id int) (*Worker, error) {
workerState, err := p.stateCreator()
if err != nil {
return nil, fmt.Errorf("failed to create worker state: %w", err)
}
return &Worker{
state: workerState,
id: id,
}, nil
}
func (p *WorkerPool) Get() *Worker {
p.mu.RLock()
if p.closed {
p.mu.RUnlock()
return nil
}
p.mu.RUnlock()
select {
case worker := <-p.workers:
return worker
default:
worker, err := p.createWorker(-1)
if err != nil {
return nil
}
return worker
}
}
func (p *WorkerPool) Put(worker *Worker) {
if worker == nil {
return
}
p.mu.RLock()
defer p.mu.RUnlock()
if p.closed {
worker.Close()
return
}
if worker.id == -1 {
worker.Close()
return
}
select {
case p.workers <- worker:
default:
worker.Close()
}
}
func (p *WorkerPool) SyncRoutes(routesData any) {
p.mu.RLock()
defer p.mu.RUnlock()
if p.closed {
return
}
// Sync routes to all workers
workers := make([]*Worker, 0, p.workerCount)
// Collect all workers
for {
select {
case worker := <-p.workers:
workers = append(workers, worker)
default:
goto syncWorkers
}
}
syncWorkers:
// Sync and return workers
for _, worker := range workers {
if worker.state != nil {
worker.state.PushValue(routesData)
worker.state.SetGlobal("_http_routes_data")
worker.state.GetGlobal("_http_sync_worker_routes")
if worker.state.IsFunction(-1) {
worker.state.Call(0, 0)
} else {
worker.state.Pop(1)
}
}
select {
case p.workers <- worker:
default:
worker.Close()
}
}
}
func (p *WorkerPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return
}
p.closed = true
// Collect all workers first
workers := make([]*Worker, 0, len(p.workers))
close(p.workers)
for worker := range p.workers {
workers = append(workers, worker)
}
// Close all workers in parallel
var wg sync.WaitGroup
for _, worker := range workers {
wg.Add(1)
go func(w *Worker) {
defer wg.Done()
w.Close()
}(worker)
}
wg.Wait()
}
func (w *Worker) Close() {
if w.state != nil {
w.state.Close()
w.state = nil
}
}
func (w *Worker) HandleRequest(req *Request, resp *Response) error {
if w.state == nil {
return fmt.Errorf("worker state is nil")
}
// Create request table
w.state.NewTable()
w.state.PushString("method")
w.state.PushString(req.Method)
w.state.SetTable(-3)
w.state.PushString("path")
w.state.PushString(req.Path)
w.state.SetTable(-3)
w.state.PushString("body")
w.state.PushString(req.Body)
w.state.SetTable(-3)
// Query params
w.state.PushString("query")
w.state.NewTable()
for k, v := range req.Query {
w.state.PushString(k)
w.state.PushString(v)
w.state.SetTable(-3)
}
w.state.SetTable(-3)
// Headers
w.state.PushString("headers")
w.state.NewTable()
for k, v := range req.Headers {
w.state.PushString(k)
w.state.PushString(v)
w.state.SetTable(-3)
}
w.state.SetTable(-3)
// Create response table
w.state.NewTable()
w.state.PushString("status")
w.state.PushNumber(200)
w.state.SetTable(-3)
w.state.PushString("body")
w.state.PushString("")
w.state.SetTable(-3)
w.state.PushString("headers")
w.state.NewTable()
w.state.SetTable(-3)
// Call _http_handle_request(req, res) - pure Lua routing
w.state.GetGlobal("_http_handle_request")
if !w.state.IsFunction(-1) {
w.state.Pop(3)
resp.StatusCode = 500
resp.Body = "HTTP handler not initialized"
return nil
}
w.state.PushCopy(-3) // request
w.state.PushCopy(-3) // response
if err := w.state.Call(2, 0); err != nil {
w.state.Pop(2)
resp.StatusCode = 500
resp.Body = fmt.Sprintf("Handler error: %v", err)
return nil
}
// Extract response
w.state.GetField(-1, "status")
if w.state.IsNumber(-1) {
resp.StatusCode = int(w.state.ToNumber(-1))
}
w.state.Pop(1)
w.state.GetField(-1, "body")
if w.state.IsString(-1) {
resp.Body = w.state.ToString(-1)
}
w.state.Pop(1)
w.state.GetField(-1, "headers")
if w.state.IsTable(-1) {
w.state.PushNil()
for w.state.Next(-2) {
if w.state.IsString(-2) && w.state.IsString(-1) {
resp.Headers[w.state.ToString(-2)] = w.state.ToString(-1)
}
w.state.Pop(1)
}
}
w.state.Pop(1)
w.state.Pop(2) // Clean up request and response tables
return nil
}

598
modules/json+/json.lua Normal file
View File

@ -0,0 +1,598 @@
json = {}
function json.encode(value)
local buffer = {}
local pos = 1
local function encode_string(s)
buffer[pos] = '"'
pos = pos + 1
local start = 1
for i = 1, #s do
local c = s:byte(i)
if c == 34 then -- "
if i > start then
buffer[pos] = s:sub(start, i - 1)
pos = pos + 1
end
buffer[pos] = '\\"'
pos = pos + 1
start = i + 1
elseif c == 92 then -- \
if i > start then
buffer[pos] = s:sub(start, i - 1)
pos = pos + 1
end
buffer[pos] = '\\\\'
pos = pos + 1
start = i + 1
elseif c < 32 then
if i > start then
buffer[pos] = s:sub(start, i - 1)
pos = pos + 1
end
if c == 8 then
buffer[pos] = '\\b'
elseif c == 9 then
buffer[pos] = '\\t'
elseif c == 10 then
buffer[pos] = '\\n'
elseif c == 12 then
buffer[pos] = '\\f'
elseif c == 13 then
buffer[pos] = '\\r'
else
buffer[pos] = ('\\u%04x'):format(c)
end
pos = pos + 1
start = i + 1
end
end
if start <= #s then
buffer[pos] = s:sub(start)
pos = pos + 1
end
buffer[pos] = '"'
pos = pos + 1
end
local function encode_value(v, depth)
local t = type(v)
if t == 'string' then
encode_string(v)
elseif t == 'number' then
if v ~= v then -- NaN
buffer[pos] = 'null'
elseif v == 1/0 or v == -1/0 then -- Infinity
buffer[pos] = 'null'
else
buffer[pos] = tostring(v)
end
pos = pos + 1
elseif t == 'boolean' then
buffer[pos] = v and 'true' or 'false'
pos = pos + 1
elseif t == 'table' then
if depth > 100 then error('circular reference') end
local is_array = true
local max_index = 0
local count = 0
for k, _ in pairs(v) do
count = count + 1
if type(k) ~= 'number' or k <= 0 or k % 1 ~= 0 then
is_array = false
break
end
if k > max_index then max_index = k end
end
if is_array and count == max_index then
buffer[pos] = '['
pos = pos + 1
for i = 1, max_index do
if i > 1 then
buffer[pos] = ','
pos = pos + 1
end
encode_value(v[i], depth + 1)
end
buffer[pos] = ']'
pos = pos + 1
else
buffer[pos] = '{'
pos = pos + 1
local first = true
for k, val in pairs(v) do
if not first then
buffer[pos] = ','
pos = pos + 1
end
first = false
encode_string(tostring(k))
buffer[pos] = ':'
pos = pos + 1
encode_value(val, depth + 1)
end
buffer[pos] = '}'
pos = pos + 1
end
else
buffer[pos] = 'null'
pos = pos + 1
end
end
encode_value(value, 0)
return table.concat(buffer)
end
function json.decode(str)
local pos = 1
local len = #str
local function skip_whitespace()
while pos <= len do
local c = str:byte(pos)
if c ~= 32 and c ~= 9 and c ~= 10 and c ~= 13 then break end
pos = pos + 1
end
end
local function decode_string()
local start = pos + 1
pos = pos + 1
while pos <= len do
local c = str:byte(pos)
if c == 34 then -- "
local result = str:sub(start, pos - 1)
pos = pos + 1
if result:find('\\') then
result = result:gsub('\\(.)', {
['"'] = '"',
['\\'] = '\\',
['/'] = '/',
['b'] = '\b',
['f'] = '\f',
['n'] = '\n',
['r'] = '\r',
['t'] = '\t'
})
result = result:gsub('\\u(%x%x%x%x)', function(hex)
return string.char(tonumber(hex, 16))
end)
end
return result
elseif c == 92 then -- \
pos = pos + 2
else
pos = pos + 1
end
end
error('unterminated string')
end
local function decode_number()
local start = pos
local c = str:byte(pos)
if c == 45 then pos = pos + 1 end -- -
c = str:byte(pos)
if not c or c < 48 or c > 57 then error('invalid number') end
if c == 48 then
pos = pos + 1
else
while pos <= len do
c = str:byte(pos)
if c < 48 or c > 57 then break end
pos = pos + 1
end
end
if pos <= len and str:byte(pos) == 46 then -- .
pos = pos + 1
local found_digit = false
while pos <= len do
c = str:byte(pos)
if c < 48 or c > 57 then break end
found_digit = true
pos = pos + 1
end
if not found_digit then error('invalid number') end
end
if pos <= len then
c = str:byte(pos)
if c == 101 or c == 69 then -- e or E
pos = pos + 1
if pos <= len then
c = str:byte(pos)
if c == 43 or c == 45 then pos = pos + 1 end -- + or -
end
local found_digit = false
while pos <= len do
c = str:byte(pos)
if c < 48 or c > 57 then break end
found_digit = true
pos = pos + 1
end
if not found_digit then error('invalid number') end
end
end
return tonumber(str:sub(start, pos - 1))
end
local function decode_value()
skip_whitespace()
if pos > len then error('unexpected end') end
local c = str:byte(pos)
if c == 34 then -- "
return decode_string()
elseif c == 123 then -- {
local result = {}
pos = pos + 1
skip_whitespace()
if pos <= len and str:byte(pos) == 125 then -- }
pos = pos + 1
return result
end
while true do
skip_whitespace()
if pos > len or str:byte(pos) ~= 34 then error('expected string key') end
local key = decode_string()
skip_whitespace()
if pos > len or str:byte(pos) ~= 58 then error('expected :') end
pos = pos + 1
result[key] = decode_value()
skip_whitespace()
if pos > len then error('unexpected end') end
c = str:byte(pos)
if c == 125 then -- }
pos = pos + 1
return result
elseif c == 44 then -- ,
pos = pos + 1
else
error('expected , or }')
end
end
elseif c == 91 then -- [
local result = {}
local index = 1
pos = pos + 1
skip_whitespace()
if pos <= len and str:byte(pos) == 93 then -- ]
pos = pos + 1
return result
end
while true do
result[index] = decode_value()
index = index + 1
skip_whitespace()
if pos > len then error('unexpected end') end
c = str:byte(pos)
if c == 93 then -- ]
pos = pos + 1
return result
elseif c == 44 then -- ,
pos = pos + 1
else
error('expected , or ]')
end
end
elseif c == 116 then -- true
if str:sub(pos, pos + 3) == 'true' then
pos = pos + 4
return true
end
error('invalid literal')
elseif c == 102 then -- false
if str:sub(pos, pos + 4) == 'false' then
pos = pos + 5
return false
end
error('invalid literal')
elseif c == 110 then -- null
if str:sub(pos, pos + 3) == 'null' then
pos = pos + 4
return nil
end
error('invalid literal')
elseif (c >= 48 and c <= 57) or c == 45 then -- 0-9 or -
return decode_number()
else
error('unexpected character')
end
end
local result = decode_value()
skip_whitespace()
if pos <= len then error('unexpected content after JSON') end
return result
end
function json.load_file(filename)
local file = io.open(filename, "r")
if not file then
error("Cannot open file: " .. filename)
end
local content = file:read("*all")
file:close()
return json.decode(content)
end
function json.save_file(filename, data)
local file = io.open(filename, "w")
if not file then
error("Cannot write to file: " .. filename)
end
file:write(json.encode(data))
file:close()
end
function json.merge(...)
local result = {}
local n = select("#", ...)
for i = 1, n do
local obj = select(i, ...)
if type(obj) == "table" then
for k, v in pairs(obj) do
result[k] = v
end
end
end
return result
end
function json.extract(data, path)
local current = data
local start = 1
local len = #path
while start <= len do
local dot_pos = path:find(".", start, true)
local part = dot_pos and path:sub(start, dot_pos - 1) or path:sub(start)
if type(current) ~= "table" then
return nil
end
local bracket_start, bracket_end = part:find("^%[(%d+)%]$")
if bracket_start then
local index = tonumber(part:sub(2, -2)) + 1
current = current[index]
else
current = current[part]
end
if current == nil then
return nil
end
start = dot_pos and dot_pos + 1 or len + 1
end
return current
end
function json.pretty(value, indent)
local buffer = {}
local pos = 1
indent = indent or " "
local function encode_string(s)
buffer[pos] = '"'
pos = pos + 1
local start = 1
for i = 1, #s do
local c = s:byte(i)
if c == 34 then -- "
if i > start then
buffer[pos] = s:sub(start, i - 1)
pos = pos + 1
end
buffer[pos] = '\\"'
pos = pos + 1
start = i + 1
elseif c == 92 then -- \
if i > start then
buffer[pos] = s:sub(start, i - 1)
pos = pos + 1
end
buffer[pos] = '\\\\'
pos = pos + 1
start = i + 1
elseif c < 32 then
if i > start then
buffer[pos] = s:sub(start, i - 1)
pos = pos + 1
end
if c == 8 then
buffer[pos] = '\\b'
elseif c == 9 then
buffer[pos] = '\\t'
elseif c == 10 then
buffer[pos] = '\\n'
elseif c == 12 then
buffer[pos] = '\\f'
elseif c == 13 then
buffer[pos] = '\\r'
else
buffer[pos] = ('\\u%04x'):format(c)
end
pos = pos + 1
start = i + 1
end
end
if start <= #s then
buffer[pos] = s:sub(start)
pos = pos + 1
end
buffer[pos] = '"'
pos = pos + 1
end
local function encode_value(v, depth)
local t = type(v)
local current_indent = string.rep(indent, depth)
local next_indent = string.rep(indent, depth + 1)
if t == 'string' then
encode_string(v)
elseif t == 'number' then
if v ~= v then -- NaN
buffer[pos] = 'null'
elseif v == 1/0 or v == -1/0 then -- Infinity
buffer[pos] = 'null'
else
buffer[pos] = tostring(v)
end
pos = pos + 1
elseif t == 'boolean' then
buffer[pos] = v and 'true' or 'false'
pos = pos + 1
elseif t == 'table' then
if depth > 100 then error('circular reference') end
local is_array = true
local max_index = 0
local count = 0
for k, _ in pairs(v) do
count = count + 1
if type(k) ~= 'number' or k <= 0 or k % 1 ~= 0 then
is_array = false
break
end
if k > max_index then max_index = k end
end
if is_array and count == max_index then
buffer[pos] = '[\n'
pos = pos + 1
for i = 1, max_index do
buffer[pos] = next_indent
pos = pos + 1
encode_value(v[i], depth + 1)
if i < max_index then
buffer[pos] = ','
pos = pos + 1
end
buffer[pos] = '\n'
pos = pos + 1
end
buffer[pos] = current_indent .. ']'
pos = pos + 1
else
buffer[pos] = '{\n'
pos = pos + 1
local keys = {}
for k in pairs(v) do
keys[#keys + 1] = k
end
for i, k in ipairs(keys) do
buffer[pos] = next_indent
pos = pos + 1
encode_string(tostring(k))
buffer[pos] = ': '
pos = pos + 1
encode_value(v[k], depth + 1)
if i < #keys then
buffer[pos] = ','
pos = pos + 1
end
buffer[pos] = '\n'
pos = pos + 1
end
buffer[pos] = current_indent .. '}'
pos = pos + 1
end
else
buffer[pos] = 'null'
pos = pos + 1
end
end
encode_value(value, 0)
return table.concat(buffer)
end
function json.validate(data, schema)
local function validate_value(value, schema_value)
local value_type = type(value)
local schema_type = schema_value.type
if schema_type and value_type ~= schema_type then
return false, "Expected " .. schema_type .. ", got " .. value_type
end
if schema_type == "table" and schema_value.properties then
local required = schema_value.required
for prop, prop_schema in pairs(schema_value.properties) do
local prop_value = value[prop]
if required and required[prop] and prop_value == nil then
return false, "Missing required property: " .. prop
end
if prop_value ~= nil then
local valid, err = validate_value(prop_value, prop_schema)
if not valid then
return false, "Property " .. prop .. ": " .. err
end
end
end
end
return true
end
return validate_value(data, schema)
end

505
modules/kv/kv.go Normal file
View File

@ -0,0 +1,505 @@
package kv
import (
"bufio"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/goccy/go-json"
)
var (
stores = make(map[string]*Store)
mutex sync.RWMutex
)
type Store struct {
data map[string]string
expires map[string]int64
filename string
mutex sync.RWMutex
}
func GetFunctionList() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"kv_open": kv_open,
"kv_get": kv_get,
"kv_set": kv_set,
"kv_delete": kv_delete,
"kv_clear": kv_clear,
"kv_has": kv_has,
"kv_size": kv_size,
"kv_keys": kv_keys,
"kv_values": kv_values,
"kv_save": kv_save,
"kv_close": kv_close,
"kv_increment": kv_increment,
"kv_append": kv_append,
"kv_expire": kv_expire,
"kv_cleanup_expired": kv_cleanup_expired,
}
}
// kv_open(name, filename) -> boolean
func kv_open(s *luajit.State) int {
name := s.ToString(1)
filename := s.ToString(2)
if name == "" {
s.PushBoolean(false)
return 1
}
mutex.Lock()
defer mutex.Unlock()
if store, exists := stores[name]; exists {
if filename != "" && store.filename != filename {
store.filename = filename
}
s.PushBoolean(true)
return 1
}
store := &Store{
data: make(map[string]string),
expires: make(map[string]int64),
filename: filename,
}
if filename != "" {
store.load()
}
stores[name] = store
s.PushBoolean(true)
return 1
}
// kv_get(name, key, default) -> value or default
func kv_get(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
hasDefault := s.GetTop() >= 3
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
if hasDefault {
s.PushCopy(3)
} else {
s.PushNil()
}
return 1
}
store.mutex.RLock()
value, found := store.data[key]
store.mutex.RUnlock()
if found {
s.PushString(value)
} else if hasDefault {
s.PushCopy(3)
} else {
s.PushNil()
}
return 1
}
// kv_set(name, key, value) -> boolean
func kv_set(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
value := s.ToString(3)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushBoolean(false)
return 1
}
store.mutex.Lock()
store.data[key] = value
store.mutex.Unlock()
s.PushBoolean(true)
return 1
}
// kv_delete(name, key) -> boolean
func kv_delete(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushBoolean(false)
return 1
}
store.mutex.Lock()
_, existed := store.data[key]
delete(store.data, key)
delete(store.expires, key)
store.mutex.Unlock()
s.PushBoolean(existed)
return 1
}
// kv_clear(name) -> boolean
func kv_clear(s *luajit.State) int {
name := s.ToString(1)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushBoolean(false)
return 1
}
store.mutex.Lock()
store.data = make(map[string]string)
store.expires = make(map[string]int64)
store.mutex.Unlock()
s.PushBoolean(true)
return 1
}
// kv_has(name, key) -> boolean
func kv_has(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushBoolean(false)
return 1
}
store.mutex.RLock()
_, found := store.data[key]
store.mutex.RUnlock()
s.PushBoolean(found)
return 1
}
// kv_size(name) -> number
func kv_size(s *luajit.State) int {
name := s.ToString(1)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushNumber(0)
return 1
}
store.mutex.RLock()
size := len(store.data)
store.mutex.RUnlock()
s.PushNumber(float64(size))
return 1
}
// kv_keys(name) -> table
func kv_keys(s *luajit.State) int {
name := s.ToString(1)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.NewTable()
return 1
}
store.mutex.RLock()
keys := make([]string, 0, len(store.data))
for k := range store.data {
keys = append(keys, k)
}
store.mutex.RUnlock()
s.PushValue(keys)
return 1
}
// kv_values(name) -> table
func kv_values(s *luajit.State) int {
name := s.ToString(1)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.NewTable()
return 1
}
store.mutex.RLock()
values := make([]string, 0, len(store.data))
for _, v := range store.data {
values = append(values, v)
}
store.mutex.RUnlock()
s.PushValue(values)
return 1
}
// kv_save(name) -> boolean
func kv_save(s *luajit.State) int {
name := s.ToString(1)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists || store.filename == "" {
s.PushBoolean(false)
return 1
}
err := store.save()
s.PushBoolean(err == nil)
return 1
}
// kv_close(name) -> boolean
func kv_close(s *luajit.State) int {
name := s.ToString(1)
mutex.Lock()
defer mutex.Unlock()
store, exists := stores[name]
if !exists {
s.PushBoolean(false)
return 1
}
if store.filename != "" {
store.save()
}
delete(stores, name)
s.PushBoolean(true)
return 1
}
// kv_increment(name, key, delta) -> number
func kv_increment(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
delta := 1.0
if s.GetTop() >= 3 {
delta = s.ToNumber(3)
}
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushNumber(0)
return 1
}
store.mutex.Lock()
current, _ := strconv.ParseFloat(store.data[key], 64)
newValue := current + delta
store.data[key] = strconv.FormatFloat(newValue, 'g', -1, 64)
store.mutex.Unlock()
s.PushNumber(newValue)
return 1
}
// kv_append(name, key, value, separator) -> boolean
func kv_append(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
value := s.ToString(3)
separator := ""
if s.GetTop() >= 4 {
separator = s.ToString(4)
}
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushBoolean(false)
return 1
}
store.mutex.Lock()
current := store.data[key]
if current == "" {
store.data[key] = value
} else {
store.data[key] = current + separator + value
}
store.mutex.Unlock()
s.PushBoolean(true)
return 1
}
// kv_expire(name, key, ttl) -> boolean
func kv_expire(s *luajit.State) int {
name := s.ToString(1)
key := s.ToString(2)
ttl := s.ToNumber(3)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushBoolean(false)
return 1
}
store.mutex.Lock()
store.expires[key] = time.Now().Unix() + int64(ttl)
store.mutex.Unlock()
s.PushBoolean(true)
return 1
}
// kv_cleanup_expired(name) -> number
func kv_cleanup_expired(s *luajit.State) int {
name := s.ToString(1)
mutex.RLock()
store, exists := stores[name]
mutex.RUnlock()
if !exists {
s.PushNumber(0)
return 1
}
currentTime := time.Now().Unix()
deleted := 0
store.mutex.Lock()
for key, expireTime := range store.expires {
if currentTime >= expireTime {
delete(store.data, key)
delete(store.expires, key)
deleted++
}
}
store.mutex.Unlock()
s.PushNumber(float64(deleted))
return 1
}
func (store *Store) load() error {
if store.filename == "" {
return nil
}
file, err := os.Open(store.filename)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
defer file.Close()
if strings.HasSuffix(store.filename, ".json") {
decoder := json.NewDecoder(file)
return decoder.Decode(&store.data)
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 {
store.data[parts[0]] = parts[1]
}
}
return scanner.Err()
}
func (store *Store) save() error {
if store.filename == "" {
return fmt.Errorf("no filename specified")
}
file, err := os.Create(store.filename)
if err != nil {
return err
}
defer file.Close()
store.mutex.RLock()
defer store.mutex.RUnlock()
if strings.HasSuffix(store.filename, ".json") {
encoder := json.NewEncoder(file)
encoder.SetIndent("", "\t")
return encoder.Encode(store.data)
}
for key, value := range store.data {
if _, err := fmt.Fprintf(file, "%s=%s\n", key, value); err != nil {
return err
}
}
return nil
}
// CloseAllStores saves and closes all open stores
func CloseAllStores() {
mutex.Lock()
defer mutex.Unlock()
for name, store := range stores {
if store.filename != "" {
store.save()
}
delete(stores, name)
}
}

196
modules/kv/kv.lua Normal file
View File

@ -0,0 +1,196 @@
local kv = {}
-- ======================================================================
-- BASIC KEY-VALUE OPERATIONS
-- ======================================================================
function kv.open(name, filename)
if type(name) ~= "string" then error("kv.open: store name must be a string", 2) end
if filename ~= nil and type(filename) ~= "string" then error("kv.open: filename must be a string", 2) end
filename = filename or ""
return moonshark.kv_open(name, filename)
end
function kv.get(name, key, default)
if type(name) ~= "string" then error("kv.get: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.get: key must be a string", 2) end
if default ~= nil then
return moonshark.kv_get(name, key, default)
else
return moonshark.kv_get(name, key)
end
end
function kv.set(name, key, value)
if type(name) ~= "string" then error("kv.set: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.set: key must be a string", 2) end
if type(value) ~= "string" then error("kv.set: value must be a string", 2) end
return moonshark.kv_set(name, key, value)
end
function kv.delete(name, key)
if type(name) ~= "string" then error("kv.delete: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.delete: key must be a string", 2) end
return moonshark.kv_delete(name, key)
end
function kv.clear(name)
if type(name) ~= "string" then error("kv.clear: store name must be a string", 2) end
return moonshark.kv_clear(name)
end
function kv.has(name, key)
if type(name) ~= "string" then error("kv.has: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.has: key must be a string", 2) end
return moonshark.kv_has(name, key)
end
function kv.size(name)
if type(name) ~= "string" then error("kv.size: store name must be a string", 2) end
return moonshark.kv_size(name)
end
function kv.keys(name)
if type(name) ~= "string" then error("kv.keys: store name must be a string", 2) end
return moonshark.kv_keys(name)
end
function kv.values(name)
if type(name) ~= "string" then error("kv.values: store name must be a string", 2) end
return moonshark.kv_values(name)
end
function kv.save(name)
if type(name) ~= "string" then error("kv.save: store name must be a string", 2) end
return moonshark.kv_save(name)
end
function kv.close(name)
if type(name) ~= "string" then error("kv.close: store name must be a string", 2) end
return moonshark.kv_close(name)
end
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
function kv.increment(name, key, delta)
if type(name) ~= "string" then error("kv.increment: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.increment: key must be a string", 2) end
delta = delta or 1
if type(delta) ~= "number" then error("kv.increment: delta must be a number", 2) end
return moonshark.kv_increment(name, key, delta)
end
function kv.append(name, key, value, separator)
if type(name) ~= "string" then error("kv.append: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.append: key must be a string", 2) end
if type(value) ~= "string" then error("kv.append: value must be a string", 2) end
separator = separator or ""
if type(separator) ~= "string" then error("kv.append: separator must be a string", 2) end
return moonshark.kv_append(name, key, value, separator)
end
function kv.expire(name, key, ttl)
if type(name) ~= "string" then error("kv.expire: store name must be a string", 2) end
if type(key) ~= "string" then error("kv.expire: key must be a string", 2) end
if type(ttl) ~= "number" or ttl <= 0 then error("kv.expire: TTL must be a positive number", 2) end
return moonshark.kv_expire(name, key, ttl)
end
function kv.cleanup_expired(name)
if type(name) ~= "string" then error("kv.cleanup_expired: store name must be a string", 2) end
return moonshark.kv_cleanup_expired(name)
end
-- ======================================================================
-- OBJECT-ORIENTED INTERFACE
-- ======================================================================
local Store = {}
Store.__index = Store
function kv.create(name, filename)
if type(name) ~= "string" then error("kv.create: store name must be a string", 2) end
if filename ~= nil and type(filename) ~= "string" then error("kv.create: filename must be a string", 2) end
local success = kv.open(name, filename)
if not success then
error("kv.create: failed to open store '" .. name .. "'", 2)
end
return setmetatable({name = name}, Store)
end
function Store:get(key, default)
return kv.get(self.name, key, default)
end
function Store:set(key, value)
return kv.set(self.name, key, value)
end
function Store:delete(key)
return kv.delete(self.name, key)
end
function Store:clear()
return kv.clear(self.name)
end
function Store:has(key)
return kv.has(self.name, key)
end
function Store:size()
return kv.size(self.name)
end
function Store:keys()
return kv.keys(self.name)
end
function Store:values()
return kv.values(self.name)
end
function Store:save()
return kv.save(self.name)
end
function Store:close()
return kv.close(self.name)
end
function Store:increment(key, delta)
return kv.increment(self.name, key, delta)
end
function Store:append(key, value, separator)
return kv.append(self.name, key, value, separator)
end
function Store:expire(key, ttl)
return kv.expire(self.name, key, ttl)
end
function Store:cleanup_expired()
return kv.cleanup_expired(self.name)
end
return kv

732
modules/math+/math.lua Normal file
View File

@ -0,0 +1,732 @@
-- ======================================================================
-- ENHANCED CONSTANTS (higher precision)
-- ======================================================================
math.pi = 3.14159265358979323846 -- Replace with higher precision
math.tau = 6.28318530717958647693 -- 2*pi, useful for full rotations
math.e = 2.71828182845904523536
math.phi = 1.61803398874989484820 -- Golden ratio (1 + sqrt(5)) / 2
math.sqrt2 = 1.41421356237309504880
math.sqrt3 = 1.73205080756887729353
math.ln2 = 0.69314718055994530942 -- Natural log of 2
math.ln10 = 2.30258509299404568402 -- Natural log of 10
math.infinity = 1/0
math.nan = 0/0
-- ======================================================================
-- EXTENDED FUNCTIONS
-- ======================================================================
-- Cube root that handles negative numbers correctly
function math.cbrt(x)
return x < 0 and -(-x)^(1/3) or x^(1/3)
end
-- Euclidean distance - more accurate than sqrt(x*x + y*y)
function math.hypot(x, y)
return math.sqrt(x * x + y * y)
end
-- IEEE 754 NaN check
function math.isnan(x)
return x ~= x
end
-- Check if number is finite
function math.isfinite(x)
return x > -math.infinity and x < math.infinity
end
-- Mathematical sign function
function math.sign(x)
return x > 0 and 1 or (x < 0 and -1 or 0)
end
-- Constrain value to range
function math.clamp(x, min, max)
return x < min and min or (x > max and max or x)
end
-- Linear interpolation
function math.lerp(a, b, t)
return a + (b - a) * t
end
-- Smooth interpolation using Hermite polynomial
function math.smoothstep(a, b, t)
t = math.clamp((t - a) / (b - a), 0, 1)
return t * t * (3 - 2 * t)
end
-- Map value from input range to output range
function math.map(x, in_min, in_max, out_min, out_max)
return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min
end
-- Round to nearest integer
function math.round(x)
return x >= 0 and math.floor(x + 0.5) or math.ceil(x - 0.5)
end
-- Round to specified decimal places
function math.roundto(x, decimals)
local mult = 10 ^ (decimals or 0)
return math.floor(x * mult + 0.5) / mult
end
-- Normalize angle to [-π, π] range
function math.normalize_angle(angle)
return angle - 2 * math.pi * math.floor((angle + math.pi) / (2 * math.pi))
end
-- 2D Euclidean distance
function math.distance(x1, y1, x2, y2)
local dx, dy = x2 - x1, y2 - y1
return math.sqrt(dx * dx + dy * dy)
end
-- Factorial with bounds checking
function math.factorial(n)
if n < 0 or n ~= math.floor(n) or n > 170 then
return nil
end
local result = 1
for i = 2, n do
result = result * i
end
return result
end
-- Greatest common divisor using Euclidean algorithm
function math.gcd(a, b)
a, b = math.floor(math.abs(a)), math.floor(math.abs(b))
while b ~= 0 do
a, b = b, a % b
end
return a
end
-- Least common multiple
function math.lcm(a, b)
if a == 0 or b == 0 then return 0 end
return math.abs(a * b) / math.gcd(a, b)
end
-- ======================================================================
-- ENHANCED RANDOM FUNCTIONS
-- ======================================================================
-- Random float in range
function math.randomf(min, max)
if not min and not max then
return math.random()
elseif not max then
max = min
min = 0
end
return min + math.random() * (max - min)
end
-- Random integer in range
function math.randint(min, max)
if not max then
max = min
min = 1
end
return math.floor(math.random() * (max - min + 1) + min)
end
-- Random boolean with probability
function math.randboolean(p)
p = p or 0.5
return math.random() < p
end
-- ======================================================================
-- STATISTICS FUNCTIONS
-- ======================================================================
function math.sum(t)
if type(t) ~= "table" then return 0 end
local sum = 0
for i=1, #t do
if type(t[i]) == "number" then
sum = sum + t[i]
end
end
return sum
end
function math.mean(t)
if type(t) ~= "table" or #t == 0 then return 0 end
local sum = 0
local count = 0
for i=1, #t do
if type(t[i]) == "number" then
sum = sum + t[i]
count = count + 1
end
end
return count > 0 and sum / count or 0
end
function math.median(t)
if type(t) ~= "table" or #t == 0 then return 0 end
local nums = {}
local count = 0
for i=1, #t do
if type(t[i]) == "number" then
count = count + 1
nums[count] = t[i]
end
end
if count == 0 then return 0 end
table.sort(nums)
if count % 2 == 0 then
return (nums[count/2] + nums[count/2 + 1]) / 2
else
return nums[math.ceil(count/2)]
end
end
function math.variance(t)
if type(t) ~= "table" then return 0 end
local count = 0
local m = math.mean(t)
local sum = 0
for i=1, #t do
if type(t[i]) == "number" then
local dev = t[i] - m
sum = sum + dev * dev
count = count + 1
end
end
return count > 1 and sum / count or 0
end
function math.stdev(t)
return math.sqrt(math.variance(t))
end
function math.pvariance(t)
if type(t) ~= "table" then return 0 end
local count = 0
local m = math.mean(t)
local sum = 0
for i=1, #t do
if type(t[i]) == "number" then
local dev = t[i] - m
sum = sum + dev * dev
count = count + 1
end
end
return count > 0 and sum / count or 0
end
function math.pstdev(t)
return math.sqrt(math.pvariance(t))
end
function math.mode(t)
if type(t) ~= "table" or #t == 0 then return nil end
local counts = {}
local most_frequent = nil
local max_count = 0
for i=1, #t do
local v = t[i]
counts[v] = (counts[v] or 0) + 1
if counts[v] > max_count then
max_count = counts[v]
most_frequent = v
end
end
return most_frequent
end
function math.minmax(t)
if type(t) ~= "table" or #t == 0 then return nil, nil end
local min, max
for i=1, #t do
if type(t[i]) == "number" then
min = t[i]
max = t[i]
break
end
end
if min == nil then return nil, nil end
for i=1, #t do
if type(t[i]) == "number" then
if t[i] < min then min = t[i] end
if t[i] > max then max = t[i] end
end
end
return min, max
end
-- ======================================================================
-- 2D VECTOR OPERATIONS
-- ======================================================================
math.vec2 = {
new = function(x, y)
return {x = x or 0, y = y or 0}
end,
copy = function(v)
return {x = v.x, y = v.y}
end,
add = function(a, b)
return {x = a.x + b.x, y = a.y + b.y}
end,
sub = function(a, b)
return {x = a.x - b.x, y = a.y - b.y}
end,
mul = function(a, b)
if type(b) == "number" then
return {x = a.x * b, y = a.y * b}
end
return {x = a.x * b.x, y = a.y * b.y}
end,
div = function(a, b)
if type(b) == "number" then
local inv = 1 / b
return {x = a.x * inv, y = a.y * inv}
end
return {x = a.x / b.x, y = a.y / b.y}
end,
dot = function(a, b)
return a.x * b.x + a.y * b.y
end,
length = function(v)
return math.sqrt(v.x * v.x + v.y * v.y)
end,
length_squared = function(v)
return v.x * v.x + v.y * v.y
end,
distance = function(a, b)
local dx, dy = b.x - a.x, b.y - a.y
return math.sqrt(dx * dx + dy * dy)
end,
distance_squared = function(a, b)
local dx, dy = b.x - a.x, b.y - a.y
return dx * dx + dy * dy
end,
normalize = function(v)
local len = math.sqrt(v.x * v.x + v.y * v.y)
if len > 1e-10 then
local inv_len = 1 / len
return {x = v.x * inv_len, y = v.y * inv_len}
end
return {x = 0, y = 0}
end,
rotate = function(v, angle)
local c, s = math.cos(angle), math.sin(angle)
return {
x = v.x * c - v.y * s,
y = v.x * s + v.y * c
}
end,
angle = function(v)
return math.atan2(v.y, v.x)
end,
lerp = function(a, b, t)
t = math.clamp(t, 0, 1)
return {
x = a.x + (b.x - a.x) * t,
y = a.y + (b.y - a.y) * t
}
end,
reflect = function(v, normal)
local dot = v.x * normal.x + v.y * normal.y
return {
x = v.x - 2 * dot * normal.x,
y = v.y - 2 * dot * normal.y
}
end
}
-- ======================================================================
-- 3D VECTOR OPERATIONS
-- ======================================================================
math.vec3 = {
new = function(x, y, z)
return {x = x or 0, y = y or 0, z = z or 0}
end,
copy = function(v)
return {x = v.x, y = v.y, z = v.z}
end,
add = function(a, b)
return {x = a.x + b.x, y = a.y + b.y, z = a.z + b.z}
end,
sub = function(a, b)
return {x = a.x - b.x, y = a.y - b.y, z = a.z - b.z}
end,
mul = function(a, b)
if type(b) == "number" then
return {x = a.x * b, y = a.y * b, z = a.z * b}
end
return {x = a.x * b.x, y = a.y * b.y, z = a.z * b.z}
end,
div = function(a, b)
if type(b) == "number" then
local inv = 1 / b
return {x = a.x * inv, y = a.y * inv, z = a.z * inv}
end
return {x = a.x / b.x, y = a.y / b.y, z = a.z / b.z}
end,
dot = function(a, b)
return a.x * b.x + a.y * b.y + a.z * b.z
end,
cross = function(a, b)
return {
x = a.y * b.z - a.z * b.y,
y = a.z * b.x - a.x * b.z,
z = a.x * b.y - a.y * b.x
}
end,
length = function(v)
return math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
end,
length_squared = function(v)
return v.x * v.x + v.y * v.y + v.z * v.z
end,
distance = function(a, b)
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
return math.sqrt(dx * dx + dy * dy + dz * dz)
end,
distance_squared = function(a, b)
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
return dx * dx + dy * dy + dz * dz
end,
normalize = function(v)
local len = math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
if len > 1e-10 then
local inv_len = 1 / len
return {x = v.x * inv_len, y = v.y * inv_len, z = v.z * inv_len}
end
return {x = 0, y = 0, z = 0}
end,
lerp = function(a, b, t)
t = math.clamp(t, 0, 1)
return {
x = a.x + (b.x - a.x) * t,
y = a.y + (b.y - a.y) * t,
z = a.z + (b.z - a.z) * t
}
end,
reflect = function(v, normal)
local dot = v.x * normal.x + v.y * normal.y + v.z * normal.z
return {
x = v.x - 2 * dot * normal.x,
y = v.y - 2 * dot * normal.y,
z = v.z - 2 * dot * normal.z
}
end
}
-- ======================================================================
-- MATRIX OPERATIONS
-- ======================================================================
math.mat2 = {
new = function(a, b, c, d)
return {
{a or 1, b or 0},
{c or 0, d or 1}
}
end,
identity = function()
return {{1, 0}, {0, 1}}
end,
mul = function(a, b)
return {
{
a[1][1] * b[1][1] + a[1][2] * b[2][1],
a[1][1] * b[1][2] + a[1][2] * b[2][2]
},
{
a[2][1] * b[1][1] + a[2][2] * b[2][1],
a[2][1] * b[1][2] + a[2][2] * b[2][2]
}
}
end,
det = function(m)
return m[1][1] * m[2][2] - m[1][2] * m[2][1]
end,
inverse = function(m)
local det = m[1][1] * m[2][2] - m[1][2] * m[2][1]
if math.abs(det) < 1e-10 then
return nil
end
local inv_det = 1 / det
return {
{m[2][2] * inv_det, -m[1][2] * inv_det},
{-m[2][1] * inv_det, m[1][1] * inv_det}
}
end,
rotation = function(angle)
local cos, sin = math.cos(angle), math.sin(angle)
return {
{cos, -sin},
{sin, cos}
}
end,
transform = function(m, v)
return {
x = m[1][1] * v.x + m[1][2] * v.y,
y = m[2][1] * v.x + m[2][2] * v.y
}
end,
scale = function(sx, sy)
sy = sy or sx
return {
{sx, 0},
{0, sy}
}
end
}
math.mat3 = {
identity = function()
return {
{1, 0, 0},
{0, 1, 0},
{0, 0, 1}
}
end,
transform = function(x, y, angle, sx, sy)
sx = sx or 1
sy = sy or sx
local cos, sin = math.cos(angle), math.sin(angle)
return {
{cos * sx, -sin * sy, x},
{sin * sx, cos * sy, y},
{0, 0, 1}
}
end,
mul = function(a, b)
local result = {
{0, 0, 0},
{0, 0, 0},
{0, 0, 0}
}
for i = 1, 3 do
for j = 1, 3 do
for k = 1, 3 do
result[i][j] = result[i][j] + a[i][k] * b[k][j]
end
end
end
return result
end,
transform_point = function(m, v)
local x = m[1][1] * v.x + m[1][2] * v.y + m[1][3]
local y = m[2][1] * v.x + m[2][2] * v.y + m[2][3]
local w = m[3][1] * v.x + m[3][2] * v.y + m[3][3]
if math.abs(w) < 1e-10 then
return {x = 0, y = 0}
end
return {x = x / w, y = y / w}
end,
translation = function(x, y)
return {
{1, 0, x},
{0, 1, y},
{0, 0, 1}
}
end,
rotation = function(angle)
local cos, sin = math.cos(angle), math.sin(angle)
return {
{cos, -sin, 0},
{sin, cos, 0},
{0, 0, 1}
}
end,
scale = function(sx, sy)
sy = sy or sx
return {
{sx, 0, 0},
{0, sy, 0},
{0, 0, 1}
}
end,
det = function(m)
return m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2]) -
m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1]) +
m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1])
end
}
-- ======================================================================
-- GEOMETRY FUNCTIONS
-- ======================================================================
math.geometry = {
point_line_distance = function(px, py, x1, y1, x2, y2)
local dx, dy = x2 - x1, y2 - y1
local len_sq = dx * dx + dy * dy
if len_sq < 1e-10 then
return math.distance(px, py, x1, y1)
end
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
t = math.clamp(t, 0, 1)
local nearestX = x1 + t * dx
local nearestY = y1 + t * dy
return math.distance(px, py, nearestX, nearestY)
end,
point_in_polygon = function(px, py, vertices)
local inside = false
local n = #vertices / 2
for i = 1, n do
local x1, y1 = vertices[i*2-1], vertices[i*2]
local x2, y2
if i == n then
x2, y2 = vertices[1], vertices[2]
else
x2, y2 = vertices[i*2+1], vertices[i*2+2]
end
if ((y1 > py) ~= (y2 > py)) and
(px < (x2 - x1) * (py - y1) / (y2 - y1) + x1) then
inside = not inside
end
end
return inside
end,
triangle_area = function(x1, y1, x2, y2, x3, y3)
return math.abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2)
end,
point_in_triangle = function(px, py, x1, y1, x2, y2, x3, y3)
local area = math.geometry.triangle_area(x1, y1, x2, y2, x3, y3)
local area1 = math.geometry.triangle_area(px, py, x2, y2, x3, y3)
local area2 = math.geometry.triangle_area(x1, y1, px, py, x3, y3)
local area3 = math.geometry.triangle_area(x1, y1, x2, y2, px, py)
return math.abs(area - (area1 + area2 + area3)) < 1e-10
end,
line_intersect = function(x1, y1, x2, y2, x3, y3, x4, y4)
local d = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
if math.abs(d) < 1e-10 then
return false, nil, nil
end
local ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / d
local ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / d
if ua >= 0 and ua <= 1 and ub >= 0 and ub <= 1 then
local x = x1 + ua * (x2 - x1)
local y = y1 + ua * (y2 - y1)
return true, x, y
end
return false, nil, nil
end,
closest_point_on_segment = function(px, py, x1, y1, x2, y2)
local dx, dy = x2 - x1, y2 - y1
local len_sq = dx * dx + dy * dy
if len_sq < 1e-10 then
return x1, y1
end
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
t = math.clamp(t, 0, 1)
return x1 + t * dx, y1 + t * dy
end
}
-- ======================================================================
-- INTERPOLATION FUNCTIONS
-- ======================================================================
math.interpolation = {
bezier = function(t, p0, p1, p2, p3)
t = math.clamp(t, 0, 1)
local t2 = t * t
local t3 = t2 * t
local mt = 1 - t
local mt2 = mt * mt
local mt3 = mt2 * mt
return p0 * mt3 + 3 * p1 * mt2 * t + 3 * p2 * mt * t2 + p3 * t3
end,
catmull_rom = function(t, p0, p1, p2, p3)
t = math.clamp(t, 0, 1)
local t2 = t * t
local t3 = t2 * t
return 0.5 * (
(2 * p1) +
(-p0 + p2) * t +
(2 * p0 - 5 * p1 + 4 * p2 - p3) * t2 +
(-p0 + 3 * p1 - 3 * p2 + p3) * t3
)
end,
hermite = function(t, p0, p1, m0, m1)
t = math.clamp(t, 0, 1)
local t2 = t * t
local t3 = t2 * t
local h00 = 2 * t3 - 3 * t2 + 1
local h10 = t3 - 2 * t2 + t
local h01 = -2 * t3 + 3 * t2
local h11 = t3 - t2
return h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
end,
quadratic_bezier = function(t, p0, p1, p2)
t = math.clamp(t, 0, 1)
local mt = 1 - t
return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2
end,
step = function(t, edge, x)
return t < edge and 0 or x
end,
smootherstep = function(edge0, edge1, x)
local t = math.clamp((x - edge0) / (edge1 - edge0), 0, 1)
return t * t * t * (t * (t * 6 - 15) + 10)
end
}

814
modules/mysql/mysql.lua Normal file
View File

@ -0,0 +1,814 @@
local mysql = {}
local Connection = {}
Connection.__index = Connection
function Connection:close()
if self._id then
local ok = moonshark.sql_close(self._id)
self._id = nil
return ok
end
return false
end
function Connection:ping()
if not self._id then
error("Connection is closed")
end
return moonshark.sql_ping(self._id)
end
function Connection:query(query_str, ...)
if not self._id then
error("Connection is closed")
end
return moonshark.sql_query(self._id, query_str:normalize_whitespace(), ...)
end
function Connection:exec(query_str, ...)
if not self._id then
error("Connection is closed")
end
return moonshark.sql_exec(self._id, query_str:normalize_whitespace(), ...)
end
function Connection:query_row(query_str, ...)
local results = self:query(query_str, ...)
return results and #results > 0 and results[1] or nil
end
function Connection:query_value(query_str, ...)
local row = self:query_row(query_str, ...)
if row then
for _, value in pairs(row) do
return value
end
end
return nil
end
function Connection:begin()
local result = self:exec("BEGIN")
if result then
return {
conn = self,
active = true,
commit = function(tx)
if tx.active then
tx.active = false
return tx.conn:exec("COMMIT")
end
return false
end,
rollback = function(tx)
if tx.active then
tx.active = false
return tx.conn:exec("ROLLBACK")
end
return false
end,
savepoint = function(tx, name)
if not tx.active then error("Transaction is not active") end
if name:is_blank() then error("Savepoint name cannot be empty") end
return tx.conn:exec("SAVEPOINT {{name}}":parse({name = name}))
end,
rollback_to = function(tx, name)
if not tx.active then error("Transaction is not active") end
if name:is_blank() then error("Savepoint name cannot be empty") end
return tx.conn:exec("ROLLBACK TO SAVEPOINT {{name}}":parse({name = name}))
end,
query = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query(query_str, ...)
end,
exec = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:exec(query_str, ...)
end,
query_row = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query_row(query_str, ...)
end,
query_value = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query_value(query_str, ...)
end
}
end
return nil
end
function Connection:insert(table_name, data)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys = table.keys(data)
local values = table.values(data)
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
table = table_name,
columns = keys:join(", "),
placeholders = placeholders
})
return self:exec(query, unpack(values))
end
function Connection:upsert(table_name, data, update_data)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys = table.keys(data)
local values = table.values(data)
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
-- Use update_data if provided, otherwise update with same data
local update_source = update_data or data
local updates = table.map(table.keys(update_source), function(key)
return key .. " = VALUES(" .. key .. ")"
end)
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}}) ON DUPLICATE KEY UPDATE {{updates}}":parse({
table = table_name,
columns = keys:join(", "),
placeholders = placeholders,
updates = updates:join(", ")
})
return self:exec(query, unpack(values))
end
function Connection:replace(table_name, data)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys = table.keys(data)
local values = table.values(data)
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
local query = "REPLACE INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
table = table_name,
columns = keys:join(", "),
placeholders = placeholders
})
return self:exec(query, unpack(values))
end
function Connection:update(table_name, data, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
if where_clause:is_blank() then
error("WHERE clause cannot be empty for UPDATE")
end
local keys = table.keys(data)
local values = table.values(data)
local sets = table.map(keys, function(key) return key .. " = ?" end)
local query = "UPDATE {{table}} SET {{sets}} WHERE {{where}}":parse({
table = table_name,
sets = sets:join(", "),
where = where_clause
})
table.extend(values, {...})
return self:exec(query, unpack(values))
end
function Connection:delete(table_name, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
if where_clause:is_blank() then
error("WHERE clause cannot be empty for DELETE")
end
local query = "DELETE FROM {{table}} WHERE {{where}}":parse({
table = table_name,
where = where_clause
})
return self:exec(query, ...)
end
function Connection:select(table_name, columns, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
columns = columns or "*"
if type(columns) == "table" then
columns = table.concat(columns, ", ")
end
if where_clause and not where_clause:is_blank() then
local query = "SELECT {{columns}} FROM {{table}} WHERE {{where}}":parse({
columns = columns,
table = table_name,
where = where_clause
})
return self:query(query, ...)
else
local query = "SELECT {{columns}} FROM {{table}}":parse({
columns = columns,
table = table_name
})
return self:query(query)
end
end
-- MySQL schema helpers
function Connection:database_exists(database_name)
if database_name:is_blank() then return false end
return self:query_value("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?",
database_name:trim()) ~= nil
end
function Connection:table_exists(table_name, database_name)
if table_name:is_blank() then return false end
database_name = database_name or self:current_database()
if not database_name then return false end
return self:query_value("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
database_name:trim(), table_name:trim()) ~= nil
end
function Connection:column_exists(table_name, column_name, database_name)
if table_name:is_blank() or column_name:is_blank() then return false end
database_name = database_name or self:current_database()
if not database_name then return false end
return self:query_value([[
SELECT COLUMN_NAME FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?
]], database_name:trim(), table_name:trim(), column_name:trim()) ~= nil
end
function Connection:create_database(database_name, charset, collation)
if database_name:is_blank() then
error("Database name cannot be empty")
end
local charset_clause = charset and " CHARACTER SET " .. charset or ""
local collation_clause = collation and " COLLATE " .. collation or ""
return self:exec("CREATE DATABASE IF NOT EXISTS {{database}}{{charset}}{{collation}}":parse({
database = database_name,
charset = charset_clause,
collation = collation_clause
}))
end
function Connection:drop_database(database_name)
if database_name:is_blank() then
error("Database name cannot be empty")
end
return self:exec("DROP DATABASE IF EXISTS {{database}}":parse({database = database_name}))
end
function Connection:create_table(table_name, schema, engine, charset)
if table_name:is_blank() or schema:is_blank() then
error("Table name and schema cannot be empty")
end
local engine_clause = engine and " ENGINE=" .. engine:upper() or ""
local charset_clause = charset and " CHARACTER SET " .. charset or ""
return self:exec("CREATE TABLE IF NOT EXISTS {{table}} ({{schema}}){{engine}}{{charset}}":parse({
table = table_name,
schema = schema:trim(),
engine = engine_clause,
charset = charset_clause
}))
end
function Connection:drop_table(table_name)
if table_name:is_blank() then
error("Table name cannot be empty")
end
return self:exec("DROP TABLE IF EXISTS {{table}}":parse({table = table_name}))
end
function Connection:add_column(table_name, column_def, position)
if table_name:is_blank() or column_def:is_blank() then
error("Table name and column definition cannot be empty")
end
local position_clause = position and " " .. position or ""
return self:exec("ALTER TABLE {{table}} ADD COLUMN {{column}}{{position}}":parse({
table = table_name,
column = column_def:trim(),
position = position_clause
}))
end
function Connection:drop_column(table_name, column_name)
if table_name:is_blank() or column_name:is_blank() then
error("Table name and column name cannot be empty")
end
return self:exec("ALTER TABLE {{table}} DROP COLUMN {{column}}":parse({
table = table_name,
column = column_name
}))
end
function Connection:modify_column(table_name, column_def)
if table_name:is_blank() or column_def:is_blank() then
error("Table name and column definition cannot be empty")
end
return self:exec("ALTER TABLE {{table}} MODIFY COLUMN {{column}}":parse({
table = table_name,
column = column_def:trim()
}))
end
function Connection:rename_table(old_name, new_name)
if old_name:is_blank() or new_name:is_blank() then
error("Old and new table names cannot be empty")
end
return self:exec("RENAME TABLE {{old}} TO {{new}}":parse({old = old_name, new = new_name}))
end
function Connection:create_index(index_name, table_name, columns, unique, type)
if index_name:is_blank() or table_name:is_blank() then
error("Index name and table name cannot be empty")
end
local unique_clause = unique and "UNIQUE " or ""
local type_clause = type and " USING " .. type:upper() or ""
local columns_str = type(columns) == "table" and table.concat(columns, ", ") or tostring(columns)
return self:exec("CREATE {{unique}}INDEX {{index}} ON {{table}} ({{columns}}){{type}}":parse({
unique = unique_clause,
index = index_name,
table = table_name,
columns = columns_str,
type = type_clause
}))
end
function Connection:drop_index(index_name, table_name)
if index_name:is_blank() or table_name:is_blank() then
error("Index name and table name cannot be empty")
end
return self:exec("DROP INDEX {{index}} ON {{table}}":parse({
index = index_name,
table = table_name
}))
end
-- MySQL maintenance functions
function Connection:optimize(table_name)
local table_clause = table_name and " " .. table_name or ""
return self:query("OPTIMIZE TABLE{{table}}":parse({table = table_clause}))
end
function Connection:repair(table_name)
if table_name:is_blank() then
error("Table name cannot be empty for REPAIR")
end
return self:query("REPAIR TABLE {{table}}":parse({table = table_name}))
end
function Connection:check_table(table_name, options)
if table_name:is_blank() then
error("Table name cannot be empty for CHECK")
end
local options_clause = ""
if options then
local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"}
local options_upper = options:upper()
if table.contains(valid_options, options_upper) then
options_clause = " " .. options_upper
end
end
return self:query("CHECK TABLE {{table}}{{options}}":parse({
table = table_name,
options = options_clause
}))
end
function Connection:analyze_table(table_name)
if table_name:is_blank() then
error("Table name cannot be empty for ANALYZE")
end
return self:query("ANALYZE TABLE {{table}}":parse({table = table_name}))
end
-- MySQL settings and introspection
function Connection:show(what)
if what:is_blank() then
error("SHOW parameter cannot be empty")
end
return self:query("SHOW {{what}}":parse({what = what:upper()}))
end
function Connection:current_database()
return self:query_value("SELECT DATABASE() AS db")
end
function Connection:version()
return self:query_value("SELECT VERSION() AS version")
end
function Connection:connection_id()
return self:query_value("SELECT CONNECTION_ID()")
end
function Connection:list_databases()
return self:query("SHOW DATABASES")
end
function Connection:list_tables(database_name)
if database_name and not database_name:is_blank() then
return self:query("SHOW TABLES FROM {{database}}":parse({database = database_name}))
else
return self:query("SHOW TABLES")
end
end
function Connection:describe_table(table_name)
if table_name:is_blank() then
error("Table name cannot be empty")
end
return self:query("DESCRIBE {{table}}":parse({table = table_name}))
end
function Connection:show_create_table(table_name)
if table_name:is_blank() then
error("Table name cannot be empty")
end
return self:query("SHOW CREATE TABLE {{table}}":parse({table = table_name}))
end
function Connection:show_indexes(table_name)
if table_name:is_blank() then
error("Table name cannot be empty")
end
return self:query("SHOW INDEXES FROM {{table}}":parse({table = table_name}))
end
function Connection:show_table_status(table_name)
if table_name and not table_name:is_blank() then
return self:query("SHOW TABLE STATUS LIKE ?", table_name)
else
return self:query("SHOW TABLE STATUS")
end
end
-- MySQL user and privilege management
function Connection:create_user(username, password, host)
if username:is_blank() or password:is_blank() then
error("Username and password cannot be empty")
end
host = host or "%"
return self:exec("CREATE USER '{{username}}'@'{{host}}' IDENTIFIED BY ?":parse({
username = username,
host = host
}), password)
end
function Connection:drop_user(username, host)
if username:is_blank() then
error("Username cannot be empty")
end
host = host or "%"
return self:exec("DROP USER IF EXISTS '{{username}}'@'{{host}}'":parse({
username = username,
host = host
}))
end
function Connection:grant(privileges, database, table_name, username, host)
if privileges:is_blank() or database:is_blank() or username:is_blank() then
error("Privileges, database, and username cannot be empty")
end
host = host or "%"
table_name = table_name or "*"
local object = database .. "." .. table_name
return self:exec("GRANT {{privileges}} ON {{object}} TO '{{username}}'@'{{host}}'":parse({
privileges = privileges:upper(),
object = object,
username = username,
host = host
}))
end
function Connection:revoke(privileges, database, table_name, username, host)
if privileges:is_blank() or database:is_blank() or username:is_blank() then
error("Privileges, database, and username cannot be empty")
end
host = host or "%"
table_name = table_name or "*"
local object = database .. "." .. table_name
return self:exec("REVOKE {{privileges}} ON {{object}} FROM '{{username}}'@'{{host}}'":parse({
privileges = privileges:upper(),
object = object,
username = username,
host = host
}))
end
function Connection:flush_privileges()
return self:exec("FLUSH PRIVILEGES")
end
-- MySQL variables and configuration
function Connection:set_variable(name, value, global)
if name:is_blank() then
error("Variable name cannot be empty")
end
local scope = global and "GLOBAL " or "SESSION "
return self:exec("SET {{scope}}{{name}} = ?":parse({scope = scope, name = name}), value)
end
function Connection:get_variable(name, global)
if name:is_blank() then
error("Variable name cannot be empty")
end
local scope = global and "global." or "session."
return self:query_value("SELECT @@{{scope}}{{name}}":parse({scope = scope, name = name}))
end
function Connection:show_variables(pattern)
if pattern and not pattern:is_blank() then
return self:query("SHOW VARIABLES LIKE ?", pattern)
else
return self:query("SHOW VARIABLES")
end
end
function Connection:show_status(pattern)
if pattern and not pattern:is_blank() then
return self:query("SHOW STATUS LIKE ?", pattern)
else
return self:query("SHOW STATUS")
end
end
-- Connection management
function mysql.connect(dsn)
if dsn:is_blank() then
error("DSN cannot be empty")
end
local conn_id = moonshark.sql_connect("mysql", dsn:trim())
if conn_id then
return setmetatable({_id = conn_id}, Connection)
end
return nil
end
mysql.open = mysql.connect
-- Quick execution functions
function mysql.query(dsn, query_str, ...)
local conn = mysql.connect(dsn)
if not conn then
error("Failed to connect to MySQL database")
end
local results = conn:query(query_str, ...)
conn:close()
return results
end
function mysql.exec(dsn, query_str, ...)
local conn = mysql.connect(dsn)
if not conn then
error("Failed to connect to MySQL database")
end
local result = conn:exec(query_str, ...)
conn:close()
return result
end
function mysql.query_row(dsn, query_str, ...)
local results = mysql.query(dsn, query_str, ...)
return results and #results > 0 and results[1] or nil
end
function mysql.query_value(dsn, query_str, ...)
local row = mysql.query_row(dsn, query_str, ...)
if row then
for _, value in pairs(row) do
return value
end
end
return nil
end
-- Migration helpers
function mysql.migrate(dsn, migrations, database_name)
local conn = mysql.connect(dsn)
if not conn then
error("Failed to connect to MySQL database for migration")
end
-- Use specified database if provided
if database_name and not database_name:is_blank() then
conn:exec("USE {{database}}":parse({database = database_name}))
end
-- Create migrations table
conn:create_table("_migrations", "id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) UNIQUE NOT NULL, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP")
local tx = conn:begin()
if not tx then
conn:close()
error("Failed to begin migration transaction")
end
for _, migration in ipairs(migrations) do
if not migration.name or migration.name:is_blank() then
tx:rollback()
conn:close()
error("Migration must have a non-empty name")
end
-- Check if migration already applied
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", migration.name:trim())
if not existing then
local ok, err = pcall(function()
if type(migration.up) == "string" then
conn:exec(migration.up)
elseif type(migration.up) == "function" then
migration.up(conn)
else
error("Migration 'up' must be string or function")
end
end)
if ok then
conn:exec("INSERT INTO _migrations (name) VALUES (?)", migration.name:trim())
print("Applied migration: {{name}}":parse({name = migration.name}))
else
tx:rollback()
conn:close()
error("Migration '{{name}}' failed: {{error}}":parse({
name = migration.name,
error = err or "unknown error"
}))
end
end
end
tx:commit()
conn:close()
return true
end
-- Result processing utilities
function mysql.to_array(results, column_name)
if not results or table.is_empty(results) then return {} end
if column_name:is_blank() then error("Column name cannot be empty") end
return table.map(results, function(row) return row[column_name] end)
end
function mysql.to_map(results, key_column, value_column)
if not results or table.is_empty(results) then return {} end
if key_column:is_blank() then error("Key column name cannot be empty") end
local map = {}
for _, row in ipairs(results) do
local key = row[key_column]
map[key] = value_column and row[value_column] or row
end
return map
end
function mysql.group_by(results, column_name)
if not results or table.is_empty(results) then return {} end
if column_name:is_blank() then error("Column name cannot be empty") end
return table.group_by(results, function(row) return row[column_name] end)
end
function mysql.print_results(results)
if not results or table.is_empty(results) then
print("No results")
return
end
local columns = table.keys(results[1])
table.sort(columns)
-- Calculate column widths
local widths = {}
for _, col in ipairs(columns) do
widths[col] = col:length()
for _, row in ipairs(results) do
local value = tostring(row[col] or "")
widths[col] = math.max(widths[col], value:length())
end
end
-- Print header and separator
local header_parts = table.map(columns, function(col) return col:pad_right(widths[col]) end)
local separator_parts = table.map(columns, function(col) return string.repeat_("-", widths[col]) end)
print(table.concat(header_parts, " | "))
print(table.concat(separator_parts, "-+-"))
-- Print rows
for _, row in ipairs(results) do
local value_parts = table.map(columns, function(col)
local value = tostring(row[col] or "")
return value:pad_right(widths[col])
end)
print(table.concat(value_parts, " | "))
end
end
-- MySQL-specific utilities
function mysql.escape_string(str_val)
if type(str_val) ~= "string" then
return tostring(str_val)
end
return str_val:replace("'", "\\'")
end
function mysql.escape_identifier(name)
if name:is_blank() then
error("Identifier name cannot be empty")
end
return "`{{name}}`":parse({name = name:replace("`", "``")})
end
-- DSN builder helper
function mysql.build_dsn(options)
if type(options) ~= "table" then
error("Options must be a table")
end
local parts = {}
if options.username and not options.username:is_blank() then
table.insert(parts, options.username)
if options.password and not options.password:is_blank() then
parts[#parts] = parts[#parts] .. ":" .. options.password
end
parts[#parts] = parts[#parts] .. "@"
end
if options.protocol and not options.protocol:is_blank() then
local host_part = options.protocol .. "("
if options.host and not options.host:is_blank() then
host_part = host_part .. options.host
if options.port then
host_part = host_part .. ":" .. tostring(options.port)
end
end
table.insert(parts, host_part .. ")")
elseif options.host and not options.host:is_blank() then
local host_part = "tcp(" .. options.host
if options.port then
host_part = host_part .. ":" .. tostring(options.port)
end
table.insert(parts, host_part .. ")")
end
if options.database and not options.database:is_blank() then
table.insert(parts, "/" .. options.database)
end
-- Add parameters
local params = {}
if options.charset and not options.charset:is_blank() then
table.insert(params, "charset=" .. options.charset)
end
if options.parseTime ~= nil then
table.insert(params, "parseTime=" .. tostring(options.parseTime))
end
if options.timeout and not options.timeout:is_blank() then
table.insert(params, "timeout=" .. options.timeout)
end
if options.tls and not options.tls:is_blank() then
table.insert(params, "tls=" .. options.tls)
end
if #params > 0 then
table.insert(parts, "?" .. table.concat(params, "&"))
end
return table.concat(parts, "")
end
return mysql

View File

@ -0,0 +1,688 @@
local postgres = {}
local Connection = {}
Connection.__index = Connection
function Connection:close()
if self._id then
local ok = moonshark.sql_close(self._id)
self._id = nil
return ok
end
return false
end
function Connection:ping()
if not self._id then
error("Connection is closed")
end
return moonshark.sql_ping(self._id)
end
function Connection:query(query_str, ...)
if not self._id then
error("Connection is closed")
end
return moonshark.sql_query(self._id, query_str:normalize_whitespace(), ...)
end
function Connection:exec(query_str, ...)
if not self._id then
error("Connection is closed")
end
return moonshark.sql_exec(self._id, query_str:normalize_whitespace(), ...)
end
function Connection:query_row(query_str, ...)
local results = self:query(query_str, ...)
return results and #results > 0 and results[1] or nil
end
function Connection:query_value(query_str, ...)
local row = self:query_row(query_str, ...)
if row then
for _, value in pairs(row) do
return value
end
end
return nil
end
function Connection:begin()
local result = self:exec("BEGIN")
if result then
return {
conn = self,
active = true,
commit = function(tx)
if tx.active then
tx.active = false
return tx.conn:exec("COMMIT")
end
return false
end,
rollback = function(tx)
if tx.active then
tx.active = false
return tx.conn:exec("ROLLBACK")
end
return false
end,
savepoint = function(tx, name)
if not tx.active then error("Transaction is not active") end
if name:is_blank() then error("Savepoint name cannot be empty") end
return tx.conn:exec("SAVEPOINT {{name}}":parse({name = name}))
end,
rollback_to = function(tx, name)
if not tx.active then error("Transaction is not active") end
if name:is_blank() then error("Savepoint name cannot be empty") end
return tx.conn:exec("ROLLBACK TO SAVEPOINT {{name}}":parse({name = name}))
end,
query = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query(query_str, ...)
end,
exec = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:exec(query_str, ...)
end,
query_row = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query_row(query_str, ...)
end,
query_value = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query_value(query_str, ...)
end
}
end
return nil
end
-- Build PostgreSQL parameters ($1, $2, etc.)
local function build_postgres_params(data)
local keys = table.keys(data)
local values = table.values(data)
local placeholders = {}
for i = 1, #keys do
placeholders[i] = "$" .. i
end
return keys, values, placeholders
end
function Connection:insert(table_name, data, returning)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys, values, placeholders = build_postgres_params(data)
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
table = table_name,
columns = keys:join(", "),
placeholders = table.concat(placeholders, ", ")
})
if returning and not returning:is_blank() then
query = query .. " RETURNING " .. returning
return self:query(query, unpack(values))
else
return self:exec(query, unpack(values))
end
end
function Connection:upsert(table_name, data, conflict_columns, returning)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys, values, placeholders = build_postgres_params(data)
local updates = table.map(keys, function(key) return key .. " = EXCLUDED." .. key end)
local conflict_clause = ""
if conflict_columns then
if type(conflict_columns) == "string" then
conflict_clause = "(" .. conflict_columns .. ")"
else
conflict_clause = "(" .. table.concat(conflict_columns, ", ") .. ")"
end
end
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}}) ON CONFLICT {{conflict}} DO UPDATE SET {{updates}}":parse({
table = table_name,
columns = keys:join(", "),
placeholders = table.concat(placeholders, ", "),
conflict = conflict_clause,
updates = updates:join(", ")
})
if returning and not returning:is_blank() then
query = query .. " RETURNING " .. returning
return self:query(query, unpack(values))
else
return self:exec(query, unpack(values))
end
end
function Connection:update(table_name, data, where_clause, returning, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
if where_clause:is_blank() then
error("WHERE clause cannot be empty for UPDATE")
end
local keys = table.keys(data)
local values = table.values(data)
local param_count = #keys
-- Build SET clause with numbered parameters
local sets = {}
for i, key in ipairs(keys) do
sets[i] = key .. " = $" .. i
end
-- Handle WHERE parameters
local where_args = {...}
local where_clause_final = where_clause
for i = 1, #where_args do
param_count = param_count + 1
values[#values + 1] = where_args[i]
where_clause_final = where_clause_final:replace("?", "$" .. param_count, 1)
end
local query = "UPDATE {{table}} SET {{sets}} WHERE {{where}}":parse({
table = table_name,
sets = table.concat(sets, ", "),
where = where_clause_final
})
if returning and not returning:is_blank() then
query = query .. " RETURNING " .. returning
return self:query(query, unpack(values))
else
return self:exec(query, unpack(values))
end
end
function Connection:delete(table_name, where_clause, returning, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
if where_clause:is_blank() then
error("WHERE clause cannot be empty for DELETE")
end
local where_args = {...}
local values = {}
local where_clause_final = where_clause
for i = 1, #where_args do
values[i] = where_args[i]
where_clause_final = where_clause_final:replace("?", "$" .. i, 1)
end
local query = "DELETE FROM {{table}} WHERE {{where}}":parse({
table = table_name,
where = where_clause_final
})
if returning and not returning:is_blank() then
query = query .. " RETURNING " .. returning
return self:query(query, unpack(values))
else
return self:exec(query, unpack(values))
end
end
function Connection:select(table_name, columns, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
columns = columns or "*"
if type(columns) == "table" then
columns = table.concat(columns, ", ")
end
if where_clause and not where_clause:is_blank() then
local where_args = {...}
local values = {}
local where_clause_final = where_clause
for i = 1, #where_args do
values[i] = where_args[i]
where_clause_final = where_clause_final:replace("?", "$" .. i, 1)
end
local query = "SELECT {{columns}} FROM {{table}} WHERE {{where}}":parse({
columns = columns,
table = table_name,
where = where_clause_final
})
return self:query(query, unpack(values))
else
local query = "SELECT {{columns}} FROM {{table}}":parse({
columns = columns,
table = table_name
})
return self:query(query)
end
end
-- Schema helpers
function Connection:table_exists(table_name, schema_name)
if table_name:is_blank() then return false end
schema_name = schema_name or "public"
return self:query_value("SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2",
schema_name:trim(), table_name:trim()) ~= nil
end
function Connection:column_exists(table_name, column_name, schema_name)
if table_name:is_blank() or column_name:is_blank() then return false end
schema_name = schema_name or "public"
return self:query_value([[
SELECT column_name FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
]], schema_name:trim(), table_name:trim(), column_name:trim()) ~= nil
end
function Connection:create_table(table_name, schema)
if table_name:is_blank() or schema:is_blank() then
error("Table name and schema cannot be empty")
end
return self:exec("CREATE TABLE IF NOT EXISTS {{table}} ({{schema}})":parse({
table = table_name,
schema = schema:trim()
}))
end
function Connection:drop_table(table_name, cascade)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local cascade_clause = cascade and " CASCADE" or ""
return self:exec("DROP TABLE IF EXISTS {{table}}{{cascade}}":parse({
table = table_name,
cascade = cascade_clause
}))
end
function Connection:add_column(table_name, column_def)
if table_name:is_blank() or column_def:is_blank() then
error("Table name and column definition cannot be empty")
end
return self:exec("ALTER TABLE {{table}} ADD COLUMN IF NOT EXISTS {{column}}":parse({
table = table_name,
column = column_def:trim()
}))
end
function Connection:drop_column(table_name, column_name, cascade)
if table_name:is_blank() or column_name:is_blank() then
error("Table name and column name cannot be empty")
end
local cascade_clause = cascade and " CASCADE" or ""
return self:exec("ALTER TABLE {{table}} DROP COLUMN IF EXISTS {{column}}{{cascade}}":parse({
table = table_name,
column = column_name,
cascade = cascade_clause
}))
end
function Connection:create_index(index_name, table_name, columns, unique, method)
if index_name:is_blank() or table_name:is_blank() then
error("Index name and table name cannot be empty")
end
local unique_clause = unique and "UNIQUE " or ""
local method_clause = method and " USING " .. method:upper() or ""
local columns_str = type(columns) == "table" and table.concat(columns, ", ") or tostring(columns)
return self:exec("CREATE {{unique}}INDEX IF NOT EXISTS {{index}} ON {{table}}{{method}} ({{columns}})":parse({
unique = unique_clause,
index = index_name,
table = table_name,
method = method_clause,
columns = columns_str
}))
end
function Connection:drop_index(index_name, cascade)
if index_name:is_blank() then
error("Index name cannot be empty")
end
local cascade_clause = cascade and " CASCADE" or ""
return self:exec("DROP INDEX IF EXISTS {{index}}{{cascade}}":parse({
index = index_name,
cascade = cascade_clause
}))
end
-- PostgreSQL-specific functions
function Connection:vacuum(table_name, analyze)
local analyze_clause = analyze and " ANALYZE" or ""
local table_clause = table_name and " " .. table_name or ""
return self:exec("VACUUM{{analyze}}{{table}}":parse({
analyze = analyze_clause,
table = table_clause
}))
end
function Connection:analyze(table_name)
local table_clause = table_name and " " .. table_name or ""
return self:exec("ANALYZE{{table}}":parse({table = table_clause}))
end
function Connection:reindex(name, type)
if name:is_blank() then
error("Name cannot be empty for REINDEX")
end
type = (type or "INDEX"):upper()
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
if not table.contains(valid_types, type) then
error("Invalid REINDEX type: " .. type)
end
return self:exec("REINDEX {{type}} {{name}}":parse({type = type, name = name}))
end
function Connection:show(setting)
if setting:is_blank() then
error("Setting name cannot be empty")
end
return self:query_value("SHOW {{setting}}":parse({setting = setting}))
end
function Connection:set(setting, value)
if setting:is_blank() then
error("Setting name cannot be empty")
end
return self:exec("SET {{setting}} = {{value}}":parse({
setting = setting,
value = tostring(value)
}))
end
function Connection:current_database()
return self:query_value("SELECT current_database()")
end
function Connection:current_schema()
return self:query_value("SELECT current_schema()")
end
function Connection:version()
return self:query_value("SELECT version()")
end
function Connection:list_schemas()
return self:query("SELECT schema_name FROM information_schema.schemata ORDER BY schema_name")
end
function Connection:list_tables(schema_name)
schema_name = schema_name or "public"
return self:query("SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename", schema_name:trim())
end
function Connection:describe_table(table_name, schema_name)
if table_name:is_blank() then
error("Table name cannot be empty")
end
schema_name = schema_name or "public"
return self:query([[
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
]], schema_name:trim(), table_name:trim())
end
-- JSON/JSONB helpers
function Connection:json_extract(column, path)
if column:is_blank() or path:is_blank() then
error("Column and path cannot be empty")
end
return "{{column}}->'{{path}}'":parse({column = column, path = path})
end
function Connection:json_extract_text(column, path)
if column:is_blank() or path:is_blank() then
error("Column and path cannot be empty")
end
return "{{column}}->>'{{path}}'":parse({column = column, path = path})
end
function Connection:jsonb_contains(column, value)
if column:is_blank() or value:is_blank() then
error("Column and value cannot be empty")
end
return "{{column}} @> '{{value}}'":parse({column = column, value = value})
end
function Connection:jsonb_contained_by(column, value)
if column:is_blank() or value:is_blank() then
error("Column and value cannot be empty")
end
return "{{column}} <@ '{{value}}'":parse({column = column, value = value})
end
-- Array helpers
function Connection:array_contains(column, value)
if column:is_blank() then
error("Column cannot be empty")
end
return "$1 = ANY({{column}})":parse({column = column})
end
function Connection:array_length(column)
if column:is_blank() then
error("Column cannot be empty")
end
return "array_length({{column}}, 1)":parse({column = column})
end
-- Connection management
function postgres.parse_dsn(dsn)
if dsn:is_blank() then
return nil, "DSN cannot be empty"
end
local parts = {}
for pair in dsn:trim():gmatch("[^%s]+") do
local key, value = pair:match("([^=]+)=(.+)")
if key and value then
parts[key:trim()] = value:trim()
end
end
return parts
end
function postgres.connect(dsn)
if dsn:is_blank() then
error("DSN cannot be empty")
end
local conn_id = moonshark.sql_connect("postgres", dsn:trim())
if conn_id then
return setmetatable({_id = conn_id}, Connection)
end
return nil
end
postgres.open = postgres.connect
-- Quick execution functions
function postgres.query(dsn, query_str, ...)
local conn = postgres.connect(dsn)
if not conn then
error("Failed to connect to PostgreSQL database")
end
local results = conn:query(query_str, ...)
conn:close()
return results
end
function postgres.exec(dsn, query_str, ...)
local conn = postgres.connect(dsn)
if not conn then
error("Failed to connect to PostgreSQL database")
end
local result = conn:exec(query_str, ...)
conn:close()
return result
end
function postgres.query_row(dsn, query_str, ...)
local results = postgres.query(dsn, query_str, ...)
return results and #results > 0 and results[1] or nil
end
function postgres.query_value(dsn, query_str, ...)
local row = postgres.query_row(dsn, query_str, ...)
if row then
for _, value in pairs(row) do
return value
end
end
return nil
end
-- Migration helpers
function postgres.migrate(dsn, migrations, schema)
schema = schema or "public"
local conn = postgres.connect(dsn)
if not conn then
error("Failed to connect to PostgreSQL database for migration")
end
conn:create_table("_migrations", "id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()")
local tx = conn:begin()
if not tx then
conn:close()
error("Failed to begin migration transaction")
end
for _, migration in ipairs(migrations) do
if not migration.name or migration.name:is_blank() then
tx:rollback()
conn:close()
error("Migration must have a non-empty name")
end
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1", migration.name:trim())
if not existing then
local ok, err = pcall(function()
if type(migration.up) == "string" then
conn:exec(migration.up)
elseif type(migration.up) == "function" then
migration.up(conn)
else
error("Migration 'up' must be string or function")
end
end)
if ok then
conn:exec("INSERT INTO _migrations (name) VALUES ($1)", migration.name:trim())
print("Applied migration: {{name}}":parse({name = migration.name}))
else
tx:rollback()
conn:close()
error("Migration '{{name}}' failed: {{error}}":parse({
name = migration.name,
error = err or "unknown error"
}))
end
end
end
tx:commit()
conn:close()
return true
end
-- Result processing utilities
function postgres.to_array(results, column_name)
if not results or table.is_empty(results) then return {} end
if column_name:is_blank() then error("Column name cannot be empty") end
return table.map(results, function(row) return row[column_name] end)
end
function postgres.to_map(results, key_column, value_column)
if not results or table.is_empty(results) then return {} end
if key_column:is_blank() then error("Key column name cannot be empty") end
local map = {}
for _, row in ipairs(results) do
local key = row[key_column]
map[key] = value_column and row[value_column] or row
end
return map
end
function postgres.group_by(results, column_name)
if not results or table.is_empty(results) then return {} end
if column_name:is_blank() then error("Column name cannot be empty") end
return table.group_by(results, function(row) return row[column_name] end)
end
function postgres.print_results(results)
if not results or table.is_empty(results) then
print("No results")
return
end
local columns = table.keys(results[1])
table.sort(columns)
-- Calculate column widths
local widths = {}
for _, col in ipairs(columns) do
widths[col] = col:length()
for _, row in ipairs(results) do
local value = tostring(row[col] or "")
widths[col] = math.max(widths[col], value:length())
end
end
-- Print header and separator
local header_parts = table.map(columns, function(col) return col:pad_right(widths[col]) end)
local separator_parts = table.map(columns, function(col) return string.repeat_("-", widths[col]) end)
print(table.concat(header_parts, " | "))
print(table.concat(separator_parts, "-+-"))
-- Print rows
for _, row in ipairs(results) do
local value_parts = table.map(columns, function(col)
local value = tostring(row[col] or "")
return value:pad_right(widths[col])
end)
print(table.concat(value_parts, " | "))
end
end
function postgres.escape_identifier(name)
if name:is_blank() then
error("Identifier name cannot be empty")
end
return '"{{name}}"':parse({name = name:replace('"', '""')})
end
function postgres.escape_literal(value)
if type(value) == "string" then
return "'{{value}}'":parse({value = value:replace("'", "''")})
end
return tostring(value)
end
return postgres

155
modules/registry.go Normal file
View File

@ -0,0 +1,155 @@
package modules
import (
"embed"
"fmt"
"maps"
"strings"
"Moonshark/modules/crypto"
"Moonshark/modules/fs"
"Moonshark/modules/http"
"Moonshark/modules/kv"
"Moonshark/modules/sql"
lua_string "Moonshark/modules/string+"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
var Global *Registry
//go:embed **/*.lua
var embeddedModules embed.FS
type Registry struct {
modules map[string]string
globalModules map[string]string // globalName -> moduleSource
goFuncs map[string]luajit.GoFunction
}
func New() *Registry {
r := &Registry{
modules: make(map[string]string),
globalModules: make(map[string]string),
goFuncs: make(map[string]luajit.GoFunction),
}
maps.Copy(r.goFuncs, lua_string.GetFunctionList())
maps.Copy(r.goFuncs, crypto.GetFunctionList())
maps.Copy(r.goFuncs, fs.GetFunctionList())
maps.Copy(r.goFuncs, http.GetFunctionList())
maps.Copy(r.goFuncs, sql.GetFunctionList())
maps.Copy(r.goFuncs, kv.GetFunctionList())
r.loadEmbeddedModules()
return r
}
func (r *Registry) loadEmbeddedModules() {
dirs, _ := embeddedModules.ReadDir(".")
for _, dir := range dirs {
if !dir.IsDir() {
continue
}
dirName := dir.Name()
isGlobal := strings.HasSuffix(dirName, "+")
var moduleName, globalName string
if isGlobal {
moduleName = strings.TrimSuffix(dirName, "+")
globalName = moduleName
} else {
moduleName = dirName
}
modulePath := fmt.Sprintf("%s/%s.lua", dirName, moduleName)
if source, err := embeddedModules.ReadFile(modulePath); err == nil {
r.modules[moduleName] = string(source)
if isGlobal {
r.globalModules[globalName] = string(source)
}
}
}
}
func (r *Registry) InstallInState(state *luajit.State) error {
// Create moonshark global table with Go functions
state.NewTable()
for name, fn := range r.goFuncs {
if err := state.PushGoFunction(fn); err != nil {
return fmt.Errorf("failed to register Go function '%s': %w", name, err)
}
state.SetField(-2, name)
}
state.SetGlobal("moonshark")
// Auto-enhance all global modules
for globalName, source := range r.globalModules {
if err := r.enhanceGlobal(state, globalName, source); err != nil {
return fmt.Errorf("failed to enhance %s global: %w", globalName, err)
}
}
// Backup original require and install custom one
state.GetGlobal("require")
state.SetGlobal("_require_original")
return state.RegisterGoFunction("require", func(s *luajit.State) int {
if err := s.CheckMinArgs(1); err != nil {
return s.PushError("require: %v", err)
}
moduleName, err := s.SafeToString(1)
if err != nil {
return s.PushError("require: module name must be a string")
}
// Return global if this module enhances a global
if _, isGlobal := r.globalModules[moduleName]; isGlobal {
s.GetGlobal(moduleName)
return 1
}
// Check built-in modules
if source, exists := r.modules[moduleName]; exists {
if err := s.LoadString(source); err != nil {
return s.PushError("require: failed to load module '%s': %v", moduleName, err)
}
if err := s.Call(0, 1); err != nil {
return s.PushError("require: failed to execute module '%s': %v", moduleName, err)
}
return 1
}
// Fall back to original require
s.GetGlobal("_require_original")
if s.IsFunction(-1) {
s.PushString(moduleName)
if err := s.Call(1, 1); err != nil {
return s.PushError("require: %v", err)
}
return 1
}
return s.PushError("require: module '%s' not found", moduleName)
})
}
func (r *Registry) enhanceGlobal(state *luajit.State, globalName, source string) error {
// Execute the module - it directly modifies the global
if err := state.LoadString(source); err != nil {
return fmt.Errorf("failed to load %s module: %w", globalName, err)
}
if err := state.Call(0, 0); err != nil { // 0 results expected
return fmt.Errorf("failed to execute %s module: %w", globalName, err)
}
return nil
}
func Initialize() error {
Global = New()
return nil
}

205
modules/sql/mysql.go Normal file
View File

@ -0,0 +1,205 @@
package sql
import (
"context"
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
)
// MySQLDriver implements the Driver interface for MySQL
type MySQLDriver struct{}
func (d *MySQLDriver) Name() string {
return "mysql"
}
func (d *MySQLDriver) Open(dsn string) (Connection, error) {
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("mysql: failed to open database: %w", err)
}
// Test the connection
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("mysql: failed to ping database: %w", err)
}
return &MySQLConnection{db: db}, nil
}
// MySQLConnection implements the Connection interface
type MySQLConnection struct {
db *sql.DB
}
func (c *MySQLConnection) Close() error {
return c.db.Close()
}
func (c *MySQLConnection) Ping(ctx context.Context) error {
return c.db.PingContext(ctx)
}
func (c *MySQLConnection) Begin(ctx context.Context) (Transaction, error) {
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("mysql: failed to begin transaction: %w", err)
}
return &MySQLTransaction{tx: tx}, nil
}
func (c *MySQLConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
rows, err := c.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("mysql: query failed: %w", err)
}
return &MySQLRows{rows: rows}, nil
}
func (c *MySQLConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
row := c.db.QueryRowContext(ctx, query, args...)
return &MySQLRow{row: row}
}
func (c *MySQLConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
result, err := c.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("mysql: exec failed: %w", err)
}
return &MySQLResult{result: result}, nil
}
func (c *MySQLConnection) Prepare(ctx context.Context, query string) (Statement, error) {
stmt, err := c.db.PrepareContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("mysql: failed to prepare statement: %w", err)
}
return &MySQLStatement{stmt: stmt}, nil
}
// MySQLTransaction implements the Transaction interface
type MySQLTransaction struct {
tx *sql.Tx
}
func (t *MySQLTransaction) Commit() error {
return t.tx.Commit()
}
func (t *MySQLTransaction) Rollback() error {
return t.tx.Rollback()
}
func (t *MySQLTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
rows, err := t.tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("mysql: transaction query failed: %w", err)
}
return &MySQLRows{rows: rows}, nil
}
func (t *MySQLTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
row := t.tx.QueryRowContext(ctx, query, args...)
return &MySQLRow{row: row}
}
func (t *MySQLTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
result, err := t.tx.ExecContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("mysql: transaction exec failed: %w", err)
}
return &MySQLResult{result: result}, nil
}
func (t *MySQLTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
stmt, err := t.tx.PrepareContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("mysql: failed to prepare transaction statement: %w", err)
}
return &MySQLStatement{stmt: stmt}, nil
}
// MySQLRows implements the Rows interface
type MySQLRows struct {
rows *sql.Rows
}
func (r *MySQLRows) Next() bool {
return r.rows.Next()
}
func (r *MySQLRows) Scan(dest ...any) error {
return r.rows.Scan(dest...)
}
func (r *MySQLRows) Columns() ([]string, error) {
return r.rows.Columns()
}
func (r *MySQLRows) Close() error {
return r.rows.Close()
}
func (r *MySQLRows) Err() error {
return r.rows.Err()
}
// MySQLRow implements the Row interface
type MySQLRow struct {
row *sql.Row
}
func (r *MySQLRow) Scan(dest ...any) error {
return r.row.Scan(dest...)
}
// MySQLResult implements the Result interface
type MySQLResult struct {
result sql.Result
}
func (r *MySQLResult) LastInsertId() (int64, error) {
return r.result.LastInsertId()
}
func (r *MySQLResult) RowsAffected() (int64, error) {
return r.result.RowsAffected()
}
// MySQLStatement implements the Statement interface
type MySQLStatement struct {
stmt *sql.Stmt
}
func (s *MySQLStatement) Close() error {
return s.stmt.Close()
}
func (s *MySQLStatement) Query(ctx context.Context, args ...any) (Rows, error) {
rows, err := s.stmt.QueryContext(ctx, args...)
if err != nil {
return nil, fmt.Errorf("mysql: statement query failed: %w", err)
}
return &MySQLRows{rows: rows}, nil
}
func (s *MySQLStatement) QueryRow(ctx context.Context, args ...any) Row {
row := s.stmt.QueryRowContext(ctx, args...)
return &MySQLRow{row: row}
}
func (s *MySQLStatement) Exec(ctx context.Context, args ...any) (Result, error) {
result, err := s.stmt.ExecContext(ctx, args...)
if err != nil {
return nil, fmt.Errorf("mysql: statement exec failed: %w", err)
}
return &MySQLResult{result: result}, nil
}
func init() {
// Register MySQL driver on import
RegisterDriver("mysql", &MySQLDriver{})
}

234
modules/sql/postgres.go Normal file
View File

@ -0,0 +1,234 @@
package sql
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
// PostgresDriver implements the Driver interface for PostgreSQL
type PostgresDriver struct{}
func (d *PostgresDriver) Name() string {
return "postgres"
}
func (d *PostgresDriver) Open(dsn string) (Connection, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("postgres: failed to parse config: %w", err)
}
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, fmt.Errorf("postgres: failed to create pool: %w", err)
}
return &PostgresConnection{pool: pool}, nil
}
// PostgresConnection implements the Connection interface
type PostgresConnection struct {
pool *pgxpool.Pool
}
func (c *PostgresConnection) Close() error {
c.pool.Close()
return nil
}
func (c *PostgresConnection) Ping(ctx context.Context) error {
return c.pool.Ping(ctx)
}
func (c *PostgresConnection) Begin(ctx context.Context) (Transaction, error) {
tx, err := c.pool.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("postgres: failed to begin transaction: %w", err)
}
return &PostgresTransaction{tx: tx}, nil
}
func (c *PostgresConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
rows, err := c.pool.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("postgres: query failed: %w", err)
}
return &PostgresRows{rows: rows}, nil
}
func (c *PostgresConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
row := c.pool.QueryRow(ctx, query, args...)
return &PostgresRow{row: row}
}
func (c *PostgresConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
tag, err := c.pool.Exec(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("postgres: exec failed: %w", err)
}
return &PostgresResult{tag: tag}, nil
}
func (c *PostgresConnection) Prepare(ctx context.Context, query string) (Statement, error) {
// pgx doesn't have explicit prepared statements like database/sql
// We'll store the query and use it with the pool
return &PostgresStatement{pool: c.pool, query: query}, nil
}
// PostgresTransaction implements the Transaction interface
type PostgresTransaction struct {
tx pgx.Tx
}
func (t *PostgresTransaction) Commit() error {
return t.tx.Commit(context.Background())
}
func (t *PostgresTransaction) Rollback() error {
return t.tx.Rollback(context.Background())
}
func (t *PostgresTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
rows, err := t.tx.Query(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("postgres: transaction query failed: %w", err)
}
return &PostgresRows{rows: rows}, nil
}
func (t *PostgresTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
row := t.tx.QueryRow(ctx, query, args...)
return &PostgresRow{row: row}
}
func (t *PostgresTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
tag, err := t.tx.Exec(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("postgres: transaction exec failed: %w", err)
}
return &PostgresResult{tag: tag}, nil
}
func (t *PostgresTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
return &PostgresStatement{tx: t.tx, query: query}, nil
}
// PostgresRows implements the Rows interface
type PostgresRows struct {
rows pgx.Rows
}
func (r *PostgresRows) Next() bool {
return r.rows.Next()
}
func (r *PostgresRows) Scan(dest ...any) error {
return r.rows.Scan(dest...)
}
func (r *PostgresRows) Columns() ([]string, error) {
fields := r.rows.FieldDescriptions()
columns := make([]string, len(fields))
for i, field := range fields {
columns[i] = field.Name
}
return columns, nil
}
func (r *PostgresRows) Close() error {
r.rows.Close()
return nil
}
func (r *PostgresRows) Err() error {
return r.rows.Err()
}
// PostgresRow implements the Row interface
type PostgresRow struct {
row pgx.Row
}
func (r *PostgresRow) Scan(dest ...any) error {
return r.row.Scan(dest...)
}
// PostgresResult implements the Result interface
type PostgresResult struct {
tag pgconn.CommandTag
}
func (r *PostgresResult) LastInsertId() (int64, error) {
// PostgreSQL doesn't have AUTO_INCREMENT like MySQL
// Users should use RETURNING clause or sequences
return 0, fmt.Errorf("postgres: LastInsertId not supported, use RETURNING clause")
}
func (r *PostgresResult) RowsAffected() (int64, error) {
return r.tag.RowsAffected(), nil
}
// PostgresStatement implements the Statement interface
type PostgresStatement struct {
pool *pgxpool.Pool
tx pgx.Tx
query string
}
func (s *PostgresStatement) Close() error {
// pgx doesn't require explicit statement cleanup
return nil
}
func (s *PostgresStatement) Query(ctx context.Context, args ...any) (Rows, error) {
var rows pgx.Rows
var err error
if s.tx != nil {
rows, err = s.tx.Query(ctx, s.query, args...)
} else {
rows, err = s.pool.Query(ctx, s.query, args...)
}
if err != nil {
return nil, fmt.Errorf("postgres: statement query failed: %w", err)
}
return &PostgresRows{rows: rows}, nil
}
func (s *PostgresStatement) QueryRow(ctx context.Context, args ...any) Row {
var row pgx.Row
if s.tx != nil {
row = s.tx.QueryRow(ctx, s.query, args...)
} else {
row = s.pool.QueryRow(ctx, s.query, args...)
}
return &PostgresRow{row: row}
}
func (s *PostgresStatement) Exec(ctx context.Context, args ...any) (Result, error) {
var tag pgconn.CommandTag
var err error
if s.tx != nil {
tag, err = s.tx.Exec(ctx, s.query, args...)
} else {
tag, err = s.pool.Exec(ctx, s.query, args...)
}
if err != nil {
return nil, fmt.Errorf("postgres: statement exec failed: %w", err)
}
return &PostgresResult{tag: tag}, nil
}
func init() {
// Register PostgreSQL driver on import
RegisterDriver("postgres", &PostgresDriver{})
}

377
modules/sql/sql.go Normal file
View File

@ -0,0 +1,377 @@
package sql
import (
"context"
"fmt"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Driver interface for SQL database implementations
type Driver interface {
Open(dsn string) (Connection, error)
Name() string
}
// Connection represents a database connection
type Connection interface {
Close() error
Ping(ctx context.Context) error
Begin(ctx context.Context) (Transaction, error)
Query(ctx context.Context, query string, args ...any) (Rows, error)
QueryRow(ctx context.Context, query string, args ...any) Row
Exec(ctx context.Context, query string, args ...any) (Result, error)
Prepare(ctx context.Context, query string) (Statement, error)
}
// Transaction represents a database transaction
type Transaction interface {
Commit() error
Rollback() error
Query(ctx context.Context, query string, args ...any) (Rows, error)
QueryRow(ctx context.Context, query string, args ...any) Row
Exec(ctx context.Context, query string, args ...any) (Result, error)
Prepare(ctx context.Context, query string) (Statement, error)
}
// Rows represents query result rows
type Rows interface {
Next() bool
Scan(dest ...any) error
Columns() ([]string, error)
Close() error
Err() error
}
// Row represents a single query result row
type Row interface {
Scan(dest ...any) error
}
// Result represents the result of an executed statement
type Result interface {
LastInsertId() (int64, error)
RowsAffected() (int64, error)
}
// Statement represents a prepared statement
type Statement interface {
Close() error
Query(ctx context.Context, args ...any) (Rows, error)
QueryRow(ctx context.Context, args ...any) Row
Exec(ctx context.Context, args ...any) (Result, error)
}
// Registry manages database drivers and connections
type Registry struct {
mu sync.RWMutex
drivers map[string]Driver
conns map[string]Connection
nextID int
}
var global = &Registry{
drivers: make(map[string]Driver),
conns: make(map[string]Connection),
}
// RegisterDriver registers a database driver
func RegisterDriver(name string, driver Driver) {
global.mu.Lock()
defer global.mu.Unlock()
global.drivers[name] = driver
}
// GetDriver returns a registered driver
func GetDriver(name string) (Driver, bool) {
global.mu.RLock()
defer global.mu.RUnlock()
driver, exists := global.drivers[name]
return driver, exists
}
// Connect opens a new database connection
func Connect(driverName, dsn string) (string, error) {
driver, exists := GetDriver(driverName)
if !exists {
return "", fmt.Errorf("unknown driver: %s", driverName)
}
conn, err := driver.Open(dsn)
if err != nil {
return "", err
}
global.mu.Lock()
defer global.mu.Unlock()
id := fmt.Sprintf("%s_%d", driverName, global.nextID)
global.nextID++
global.conns[id] = conn
return id, nil
}
// GetConnection retrieves a connection by ID
func GetConnection(id string) (Connection, bool) {
global.mu.RLock()
defer global.mu.RUnlock()
conn, exists := global.conns[id]
return conn, exists
}
// CloseConnection closes and removes a connection
func CloseConnection(id string) error {
global.mu.Lock()
defer global.mu.Unlock()
conn, exists := global.conns[id]
if !exists {
return fmt.Errorf("connection not found: %s", id)
}
err := conn.Close()
delete(global.conns, id)
return err
}
func CloseAllConnections() {
global.mu.Lock()
defer global.mu.Unlock()
for id, conn := range global.conns {
conn.Close()
delete(global.conns, id)
}
}
// Lua function implementations
func luaConnect(s *luajit.State) int {
if err := s.CheckExactArgs(2); err != nil {
return s.PushError("connect: %v", err)
}
driver, err := s.SafeToString(1)
if err != nil {
return s.PushError("connect: driver must be a string")
}
dsn, err := s.SafeToString(2)
if err != nil {
return s.PushError("connect: dsn must be a string")
}
connID, err := Connect(driver, dsn)
if err != nil {
return s.PushError("connect: %v", err)
}
s.PushString(connID)
return 1
}
func luaClose(s *luajit.State) int {
if err := s.CheckExactArgs(1); err != nil {
return s.PushError("close: %v", err)
}
connID, err := s.SafeToString(1)
if err != nil {
return s.PushError("close: connection id must be a string")
}
if err := CloseConnection(connID); err != nil {
return s.PushError("close: %v", err)
}
s.PushBoolean(true)
return 1
}
func luaPing(s *luajit.State) int {
if err := s.CheckExactArgs(1); err != nil {
return s.PushError("ping: %v", err)
}
connID, err := s.SafeToString(1)
if err != nil {
return s.PushError("ping: connection id must be a string")
}
conn, exists := GetConnection(connID)
if !exists {
return s.PushError("ping: connection not found")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := conn.Ping(ctx); err != nil {
return s.PushError("ping: %v", err)
}
s.PushBoolean(true)
return 1
}
func luaQuery(s *luajit.State) int {
if err := s.CheckMinArgs(2); err != nil {
return s.PushError("query: %v", err)
}
connID, err := s.SafeToString(1)
if err != nil {
return s.PushError("query: connection id must be a string")
}
query, err := s.SafeToString(2)
if err != nil {
return s.PushError("query: query must be a string")
}
conn, exists := GetConnection(connID)
if !exists {
return s.PushError("query: connection not found")
}
// Collect arguments
args := make([]any, s.GetTop()-2)
for i := 3; i <= s.GetTop(); i++ {
val, err := s.ToValue(i)
if err != nil {
args[i-3] = nil
} else {
args[i-3] = val
}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
rows, err := conn.Query(ctx, query, args...)
if err != nil {
return s.PushError("query: %v", err)
}
defer rows.Close()
// Get column names
columns, err := rows.Columns()
if err != nil {
return s.PushError("query: failed to get columns: %v", err)
}
// Build result array
s.CreateTable(0, 0)
rowIndex := 1
for rows.Next() {
// Create values slice for scanning
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return s.PushError("query: scan error: %v", err)
}
// Create row table
s.CreateTable(0, len(columns))
for i, col := range columns {
s.PushString(col)
if err := s.PushValue(values[i]); err != nil {
s.PushNil()
}
s.SetTable(-3)
}
// Add to result array
s.PushNumber(float64(rowIndex))
s.PushCopy(-2)
s.SetTable(-4)
s.Pop(1) // Remove row table copy
rowIndex++
}
if err := rows.Err(); err != nil {
return s.PushError("query: %v", err)
}
return 1
}
func luaExec(s *luajit.State) int {
if err := s.CheckMinArgs(2); err != nil {
return s.PushError("exec: %v", err)
}
connID, err := s.SafeToString(1)
if err != nil {
return s.PushError("exec: connection id must be a string")
}
query, err := s.SafeToString(2)
if err != nil {
return s.PushError("exec: query must be a string")
}
conn, exists := GetConnection(connID)
if !exists {
return s.PushError("exec: connection not found")
}
// Collect arguments
args := make([]any, s.GetTop()-2)
for i := 3; i <= s.GetTop(); i++ {
val, err := s.ToValue(i)
if err != nil {
args[i-3] = nil
} else {
args[i-3] = val
}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
result, err := conn.Exec(ctx, query, args...)
if err != nil {
return s.PushError("exec: %v", err)
}
// Return result info
s.CreateTable(0, 2)
lastID, _ := result.LastInsertId()
s.PushString("last_insert_id")
s.PushNumber(float64(lastID))
s.SetTable(-3)
affected, _ := result.RowsAffected()
s.PushString("rows_affected")
s.PushNumber(float64(affected))
s.SetTable(-3)
return 1
}
// GetFunctionList returns all Lua-callable functions
func GetFunctionList() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"sql_connect": luaConnect,
"sql_close": luaClose,
"sql_ping": luaPing,
"sql_query": luaQuery,
"sql_exec": luaExec,
}
}
func init() {
// Register SQLite driver on import
RegisterDriver("sqlite", &SQLiteDriver{})
}

384
modules/sql/sqlite.go Normal file
View File

@ -0,0 +1,384 @@
package sql
import (
"context"
"fmt"
"zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
)
// SQLiteDriver implements the Driver interface for SQLite
type SQLiteDriver struct{}
func (d *SQLiteDriver) Name() string {
return "sqlite"
}
func (d *SQLiteDriver) Open(dsn string) (Connection, error) {
conn, err := sqlite.OpenConn(dsn, sqlite.OpenReadWrite|sqlite.OpenCreate)
if err != nil {
return nil, fmt.Errorf("sqlite: failed to open database: %w", err)
}
return &SQLiteConnection{conn: conn}, nil
}
// SQLiteConnection implements the Connection interface
type SQLiteConnection struct {
conn *sqlite.Conn
}
func (c *SQLiteConnection) Close() error {
return c.conn.Close()
}
func (c *SQLiteConnection) Ping(ctx context.Context) error {
return sqlitex.ExecuteTransient(c.conn, "SELECT 1", nil)
}
func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) {
if err := sqlitex.ExecuteTransient(c.conn, "BEGIN", nil); err != nil {
return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err)
}
return &SQLiteTransaction{conn: c.conn}, nil
}
func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
stmt, err := c.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("sqlite: failed to prepare query: %w", err)
}
if err := c.bindArgs(stmt, args...); err != nil {
stmt.Finalize()
return nil, err
}
return &SQLiteRows{stmt: stmt, hasNext: true}, nil
}
func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
rows, err := c.Query(ctx, query, args...)
if err != nil {
return &SQLiteRow{err: err}
}
return &SQLiteRow{rows: rows.(*SQLiteRows)}
}
func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
stmt, err := c.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
}
defer stmt.Finalize()
if err := c.bindArgs(stmt, args...); err != nil {
return nil, err
}
hasRow, err := stmt.Step()
if err != nil {
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
}
// Consume all rows if any
for hasRow {
hasRow, err = stmt.Step()
if err != nil {
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
}
}
return &SQLiteResult{
lastInsertID: c.conn.LastInsertRowID(),
rowsAffected: c.conn.Changes(),
}, nil
}
func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) {
stmt, err := c.conn.Prepare(query)
if err != nil {
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
}
return &SQLiteStatement{stmt: stmt, conn: c.conn}, nil
}
func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error {
for i, arg := range args {
paramIndex := i + 1
if arg == nil {
stmt.BindNull(paramIndex)
continue
}
switch v := arg.(type) {
case int:
stmt.BindInt64(paramIndex, int64(v))
case int64:
stmt.BindInt64(paramIndex, v)
case float64:
stmt.BindFloat(paramIndex, v)
case string:
stmt.BindText(paramIndex, v)
case bool:
if v {
stmt.BindInt64(paramIndex, 1)
} else {
stmt.BindInt64(paramIndex, 0)
}
case []byte:
stmt.BindBytes(paramIndex, v)
default:
return fmt.Errorf("sqlite: unsupported parameter type: %T", arg)
}
}
return nil
}
// SQLiteTransaction implements the Transaction interface
type SQLiteTransaction struct {
conn *sqlite.Conn
}
func (t *SQLiteTransaction) Commit() error {
return sqlitex.ExecuteTransient(t.conn, "COMMIT", nil)
}
func (t *SQLiteTransaction) Rollback() error {
return sqlitex.ExecuteTransient(t.conn, "ROLLBACK", nil)
}
func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
conn := &SQLiteConnection{conn: t.conn}
return conn.Query(ctx, query, args...)
}
func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
conn := &SQLiteConnection{conn: t.conn}
return conn.QueryRow(ctx, query, args...)
}
func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
conn := &SQLiteConnection{conn: t.conn}
return conn.Exec(ctx, query, args...)
}
func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
conn := &SQLiteConnection{conn: t.conn}
return conn.Prepare(ctx, query)
}
// SQLiteRows implements the Rows interface
type SQLiteRows struct {
stmt *sqlite.Stmt
hasNext bool
err error
}
func (r *SQLiteRows) Next() bool {
if r.err != nil {
return false
}
if !r.hasNext {
return false
}
var err error
r.hasNext, err = r.stmt.Step()
if err != nil {
r.err = err
return false
}
return r.hasNext
}
func (r *SQLiteRows) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
for i, d := range dest {
if i >= r.stmt.ColumnCount() {
break
}
switch ptr := d.(type) {
case *any:
*ptr = r.getValue(i)
case *string:
*ptr = r.stmt.ColumnText(i)
case *int:
*ptr = int(r.stmt.ColumnInt64(i))
case *int64:
*ptr = r.stmt.ColumnInt64(i)
case *float64:
*ptr = r.stmt.ColumnFloat(i)
case *bool:
*ptr = r.stmt.ColumnInt64(i) != 0
case *[]byte:
if r.stmt.ColumnType(i) == sqlite.TypeBlob {
// Get blob size first
size := r.stmt.ColumnBytes(i, nil)
if size == 0 {
*ptr = []byte{}
} else {
buf := make([]byte, size)
r.stmt.ColumnBytes(i, buf)
*ptr = buf
}
} else {
// Convert text to bytes
*ptr = []byte(r.stmt.ColumnText(i))
}
default:
return fmt.Errorf("sqlite: unsupported scan destination type: %T", d)
}
}
return nil
}
func (r *SQLiteRows) getValue(index int) any {
switch r.stmt.ColumnType(index) {
case sqlite.TypeInteger:
return r.stmt.ColumnInt64(index)
case sqlite.TypeFloat:
return r.stmt.ColumnFloat(index)
case sqlite.TypeText:
return r.stmt.ColumnText(index)
case sqlite.TypeBlob:
// For blob columns, we need to handle this differently
// First, get the size by calling with nil buffer
size := r.stmt.ColumnBytes(index, nil)
if size == 0 {
return []byte{}
}
// Now allocate buffer and get the actual data
buf := make([]byte, size)
r.stmt.ColumnBytes(index, buf)
return buf
case sqlite.TypeNull:
return nil
default:
return r.stmt.ColumnText(index)
}
}
func (r *SQLiteRows) Columns() ([]string, error) {
if r.err != nil {
return nil, r.err
}
columns := make([]string, r.stmt.ColumnCount())
for i := range columns {
columns[i] = r.stmt.ColumnName(i)
}
return columns, nil
}
func (r *SQLiteRows) Close() error {
if r.stmt != nil {
return r.stmt.Finalize()
}
return nil
}
func (r *SQLiteRows) Err() error {
return r.err
}
// SQLiteRow implements the Row interface
type SQLiteRow struct {
rows *SQLiteRows
err error
}
func (r *SQLiteRow) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
if r.rows == nil {
return fmt.Errorf("sqlite: no rows available")
}
if !r.rows.Next() {
if r.rows.Err() != nil {
return r.rows.Err()
}
return fmt.Errorf("sqlite: no rows in result set")
}
return r.rows.Scan(dest...)
}
// SQLiteResult implements the Result interface
type SQLiteResult struct {
lastInsertID int64
rowsAffected int
}
func (r *SQLiteResult) LastInsertId() (int64, error) {
return r.lastInsertID, nil
}
func (r *SQLiteResult) RowsAffected() (int64, error) {
return int64(r.rowsAffected), nil
}
// SQLiteStatement implements the Statement interface
type SQLiteStatement struct {
stmt *sqlite.Stmt
conn *sqlite.Conn
}
func (s *SQLiteStatement) Close() error {
return s.stmt.Finalize()
}
func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) {
conn := &SQLiteConnection{conn: s.conn}
if err := conn.bindArgs(s.stmt, args...); err != nil {
return nil, err
}
return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil
}
func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row {
rows, err := s.Query(ctx, args...)
if err != nil {
return &SQLiteRow{err: err}
}
return &SQLiteRow{rows: rows.(*SQLiteRows)}
}
func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) {
conn := &SQLiteConnection{conn: s.conn}
if err := conn.bindArgs(s.stmt, args...); err != nil {
return nil, err
}
hasRow, err := s.stmt.Step()
if err != nil {
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
}
// Consume all rows if any
for hasRow {
hasRow, err = s.stmt.Step()
if err != nil {
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
}
}
return &SQLiteResult{
lastInsertID: s.conn.LastInsertRowID(),
rowsAffected: s.conn.Changes(),
}, nil
}

502
modules/sqlite/sqlite.lua Normal file
View File

@ -0,0 +1,502 @@
local sqlite = {}
local Connection = {}
Connection.__index = Connection
function Connection:close()
if self._id then
local ok = moonshark.sql_close(self._id)
self._id = nil
return ok
end
return false
end
function Connection:ping()
if not self._id then
error("Connection is closed")
end
return moonshark.sql_ping(self._id)
end
function Connection:query(query_str, ...)
if not self._id then
error("Connection is closed")
end
return moonshark.sql_query(self._id, query_str:normalize_whitespace(), ...)
end
function Connection:exec(query_str, ...)
if not self._id then
error("Connection is closed")
end
return moonshark.sql_exec(self._id, query_str:normalize_whitespace(), ...)
end
function Connection:query_row(query_str, ...)
local results = self:query(query_str, ...)
return results and #results > 0 and results[1] or nil
end
function Connection:query_value(query_str, ...)
local row = self:query_row(query_str, ...)
if row then
for _, value in pairs(row) do
return value
end
end
return nil
end
function Connection:begin()
local result = self:exec("BEGIN")
if result then
return {
conn = self,
active = true,
commit = function(tx)
if tx.active then
tx.active = false
return tx.conn:exec("COMMIT")
end
return false
end,
rollback = function(tx)
if tx.active then
tx.active = false
return tx.conn:exec("ROLLBACK")
end
return false
end,
query = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query(query_str, ...)
end,
exec = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:exec(query_str, ...)
end,
query_row = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query_row(query_str, ...)
end,
query_value = function(tx, query_str, ...)
if not tx.active then error("Transaction is not active") end
return tx.conn:query_value(query_str, ...)
end
}
end
return nil
end
function Connection:insert(table_name, data)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys = table.keys(data)
local values = table.values(data)
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
table = table_name,
columns = keys:join(", "),
placeholders = placeholders
})
return self:exec(query, unpack(values))
end
function Connection:upsert(table_name, data, conflict_columns)
if table_name:is_blank() then
error("Table name cannot be empty")
end
local keys = table.keys(data)
local values = table.values(data)
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
local updates = table.map(keys, function(key) return key .. " = excluded." .. key end):join(", ")
local conflict_clause = ""
if conflict_columns then
if type(conflict_columns) == "string" then
conflict_clause = "(" .. conflict_columns .. ")"
else
conflict_clause = "(" .. table.concat(conflict_columns, ", ") .. ")"
end
end
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}}) ON CONFLICT {{conflict}} DO UPDATE SET {{updates}}":parse({
table = table_name,
columns = keys:join(", "),
placeholders = placeholders,
conflict = conflict_clause,
updates = updates
})
return self:exec(query, unpack(values))
end
function Connection:update(table_name, data, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
if where_clause:is_blank() then
error("WHERE clause cannot be empty for UPDATE")
end
local keys = table.keys(data)
local values = table.values(data)
local sets = table.map(keys, function(key) return key .. " = ?" end):join(", ")
local query = "UPDATE {{table}} SET {{sets}} WHERE {{where}}":parse({
table = table_name,
sets = sets,
where = where_clause
})
table.extend(values, {...})
return self:exec(query, unpack(values))
end
function Connection:delete(table_name, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
if where_clause:is_blank() then
error("WHERE clause cannot be empty for DELETE")
end
local query = "DELETE FROM {{table}} WHERE {{where}}":parse({
table = table_name,
where = where_clause
})
return self:exec(query, ...)
end
function Connection:select(table_name, columns, where_clause, ...)
if table_name:is_blank() then
error("Table name cannot be empty")
end
columns = columns or "*"
if type(columns) == "table" then
columns = table.concat(columns, ", ")
end
if where_clause and not where_clause:is_blank() then
local query = "SELECT {{columns}} FROM {{table}} WHERE {{where}}":parse({
columns = columns,
table = table_name,
where = where_clause
})
return self:query(query, ...)
else
local query = "SELECT {{columns}} FROM {{table}}":parse({
columns = columns,
table = table_name
})
return self:query(query)
end
end
function Connection:table_exists(table_name)
if table_name:is_blank() then return false end
return self:query_value("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table_name:trim()) ~= nil
end
function Connection:column_exists(table_name, column_name)
if table_name:is_blank() or column_name:is_blank() then return false end
local result = self:query("PRAGMA table_info({{table}})":parse({table = table_name}))
if result then
return table.any(result, function(row)
return row.name:iequals(column_name:trim())
end)
end
return false
end
function Connection:create_table(table_name, schema)
if table_name:is_blank() or schema:is_blank() then
error("Table name and schema cannot be empty")
end
local query = "CREATE TABLE IF NOT EXISTS {{table}} ({{schema}})":parse({
table = table_name,
schema = schema:trim()
})
return self:exec(query)
end
function Connection:drop_table(table_name)
if table_name:is_blank() then
error("Table name cannot be empty")
end
return self:exec("DROP TABLE IF EXISTS {{table}}":parse({table = table_name}))
end
function Connection:add_column(table_name, column_def)
if table_name:is_blank() or column_def:is_blank() then
error("Table name and column definition cannot be empty")
end
local query = "ALTER TABLE {{table}} ADD COLUMN {{column}}":parse({
table = table_name,
column = column_def:trim()
})
return self:exec(query)
end
function Connection:create_index(index_name, table_name, columns, unique)
if index_name:is_blank() or table_name:is_blank() then
error("Index name and table name cannot be empty")
end
local unique_clause = unique and "UNIQUE " or ""
local columns_str = type(columns) == "table" and table.concat(columns, ", ") or tostring(columns)
local query = "CREATE {{unique}}INDEX IF NOT EXISTS {{index}} ON {{table}} ({{columns}})":parse({
unique = unique_clause,
index = index_name,
table = table_name,
columns = columns_str
})
return self:exec(query)
end
function Connection:drop_index(index_name)
if index_name:is_blank() then
error("Index name cannot be empty")
end
return self:exec("DROP INDEX IF EXISTS {{index}}":parse({index = index_name}))
end
-- SQLite-specific functions
function Connection:vacuum()
return self:exec("VACUUM")
end
function Connection:analyze()
return self:exec("ANALYZE")
end
function Connection:integrity_check()
return self:query("PRAGMA integrity_check")
end
function Connection:foreign_keys(enabled)
local value = enabled and "ON" or "OFF"
return self:exec("PRAGMA foreign_keys = {{value}}":parse({value = value}))
end
function Connection:journal_mode(mode)
mode = (mode or "WAL"):upper()
local valid_modes = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}
if not table.contains(valid_modes, mode) then
error("Invalid journal mode: " .. mode)
end
return self:query("PRAGMA journal_mode = {{mode}}":parse({mode = mode}))
end
function Connection:synchronous(level)
level = (level or "NORMAL"):upper()
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
if not table.contains(valid_levels, level) then
error("Invalid synchronous level: " .. level)
end
return self:exec("PRAGMA synchronous = {{level}}":parse({level = level}))
end
function Connection:cache_size(size)
size = size or -64000
if type(size) ~= "number" then
error("Cache size must be a number")
end
return self:exec("PRAGMA cache_size = {{size}}":parse({size = tostring(size)}))
end
function Connection:temp_store(mode)
mode = (mode or "MEMORY"):upper()
local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
if not table.contains(valid_modes, mode) then
error("Invalid temp_store mode: " .. mode)
end
return self:exec("PRAGMA temp_store = {{mode}}":parse({mode = mode}))
end
-- Connection management
function sqlite.open(database_path)
database_path = database_path or ":memory:"
if database_path ~= ":memory:" and database_path:is_blank() then
database_path = ":memory:"
end
local conn_id = moonshark.sql_connect("sqlite", database_path:trim())
if conn_id then
return setmetatable({_id = conn_id}, Connection)
end
return nil
end
sqlite.connect = sqlite.open
-- Quick execution functions
function sqlite.query(database_path, query_str, ...)
local conn = sqlite.open(database_path)
if not conn then
error("Failed to open SQLite database: {{path}}":parse({path = database_path or ":memory:"}))
end
local results = conn:query(query_str, ...)
conn:close()
return results
end
function sqlite.exec(database_path, query_str, ...)
local conn = sqlite.open(database_path)
if not conn then
error("Failed to open SQLite database: {{path}}":parse({path = database_path or ":memory:"}))
end
local result = conn:exec(query_str, ...)
conn:close()
return result
end
function sqlite.query_row(database_path, query_str, ...)
local results = sqlite.query(database_path, query_str, ...)
return results and #results > 0 and results[1] or nil
end
function sqlite.query_value(database_path, query_str, ...)
local row = sqlite.query_row(database_path, query_str, ...)
if row then
for _, value in pairs(row) do
return value
end
end
return nil
end
-- Migration helpers
function sqlite.migrate(database_path, migrations)
local conn = sqlite.open(database_path)
if not conn then
error("Failed to open SQLite database for migration")
end
conn:create_table("_migrations", "id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP")
local tx = conn:begin()
if not tx then
conn:close()
error("Failed to begin migration transaction")
end
for _, migration in ipairs(migrations) do
if not migration.name or migration.name:is_blank() then
tx:rollback()
conn:close()
error("Migration must have a non-empty name")
end
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", migration.name:trim())
if not existing then
local ok, err = pcall(function()
if type(migration.up) == "string" then
conn:exec(migration.up)
elseif type(migration.up) == "function" then
migration.up(conn)
else
error("Migration 'up' must be string or function")
end
end)
if ok then
conn:exec("INSERT INTO _migrations (name) VALUES (?)", migration.name:trim())
print("Applied migration: {{name}}":parse({name = migration.name}))
else
tx:rollback()
conn:close()
error("Migration '{{name}}' failed: {{error}}":parse({
name = migration.name,
error = err or "unknown error"
}))
end
end
end
tx:commit()
conn:close()
return true
end
-- Result processing utilities
function sqlite.to_array(results, column_name)
if not results or table.is_empty(results) then return {} end
if column_name:is_blank() then error("Column name cannot be empty") end
return table.map(results, function(row) return row[column_name] end)
end
function sqlite.to_map(results, key_column, value_column)
if not results or table.is_empty(results) then return {} end
if key_column:is_blank() then error("Key column name cannot be empty") end
local map = {}
for _, row in ipairs(results) do
local key = row[key_column]
map[key] = value_column and row[value_column] or row
end
return map
end
function sqlite.group_by(results, column_name)
if not results or table.is_empty(results) then return {} end
if column_name:is_blank() then error("Column name cannot be empty") end
return table.group_by(results, function(row) return row[column_name] end)
end
function sqlite.print_results(results)
if not results or table.is_empty(results) then
print("No results")
return
end
local columns = table.keys(results[1])
table.sort(columns)
-- Calculate column widths
local widths = {}
for _, col in ipairs(columns) do
widths[col] = col:length()
for _, row in ipairs(results) do
local value = tostring(row[col] or "")
widths[col] = math.max(widths[col], value:length())
end
end
-- Print header and separator
local header_parts = table.map(columns, function(col) return col:pad_right(widths[col]) end)
local separator_parts = table.map(columns, function(col) return string.repeat_("-", widths[col]) end)
print(table.concat(header_parts, " | "))
print(table.concat(separator_parts, "-+-"))
-- Print rows
for _, row in ipairs(results) do
local value_parts = table.map(columns, function(col)
local value = tostring(row[col] or "")
return value:pad_right(widths[col])
end)
print(table.concat(value_parts, " | "))
end
end
return sqlite

113
modules/string+/string.go Normal file
View File

@ -0,0 +1,113 @@
package string
import (
"regexp"
"unicode/utf8"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
func GetFunctionList() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"string_slice": string_slice,
"string_reverse": string_reverse,
"string_length": string_length,
"string_byte_length": string_byte_length,
"regex_replace": regex_replace,
"string_is_valid_utf8": string_is_valid_utf8,
}
}
func string_slice(s *luajit.State) int {
str := s.ToString(1)
start := int(s.ToNumber(2))
if !utf8.ValidString(str) {
s.PushNil()
s.PushString("invalid UTF-8")
return 2
}
runes := []rune(str)
length := len(runes)
startIdx := max(start-1, 0) // Convert from 1-indexed
if startIdx >= length {
s.PushString("")
return 1
}
endIdx := length
if s.GetTop() >= 3 && !s.IsNil(3) {
end := int(s.ToNumber(3))
if end < 0 {
endIdx = length + end + 1
} else {
endIdx = end
}
if endIdx < 0 {
endIdx = 0
}
if endIdx > length {
endIdx = length
}
}
if startIdx >= endIdx {
s.PushString("")
return 1
}
s.PushString(string(runes[startIdx:endIdx]))
return 1
}
func string_reverse(s *luajit.State) int {
str := s.ToString(1)
if !utf8.ValidString(str) {
s.PushNil()
s.PushString("invalid UTF-8")
return 2
}
runes := []rune(str)
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
runes[i], runes[j] = runes[j], runes[i]
}
s.PushString(string(runes))
return 1
}
func string_length(s *luajit.State) int {
str := s.ToString(1)
s.PushNumber(float64(utf8.RuneCountInString(str)))
return 1
}
func string_byte_length(s *luajit.State) int {
str := s.ToString(1)
s.PushNumber(float64(len(str)))
return 1
}
func regex_replace(s *luajit.State) int {
pattern := s.ToString(1)
str := s.ToString(2)
replacement := s.ToString(3)
re, err := regexp.Compile(pattern)
if err != nil {
s.PushString(str)
return 1
}
result := re.ReplaceAllString(str, replacement)
s.PushString(result)
return 1
}
func string_is_valid_utf8(s *luajit.State) int {
str := s.ToString(1)
s.PushBoolean(utf8.ValidString(str))
return 1
}

1079
modules/string+/string.lua Normal file

File diff suppressed because it is too large Load Diff

1157
modules/table+/table.lua Normal file

File diff suppressed because it is too large Load Diff

168
moonshark.go Normal file
View File

@ -0,0 +1,168 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"Moonshark/modules/http"
"Moonshark/modules/sql"
"Moonshark/state"
)
var (
watchFlag = flag.Bool("watch", false, "Watch script files for changes and restart")
wFlag = flag.Bool("w", false, "Watch script files for changes and restart")
)
func main() {
flag.Parse()
if flag.NArg() < 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [--watch|-w] <script.lua>\n", filepath.Base(os.Args[0]))
os.Exit(1)
}
scriptPath := flag.Arg(0)
watchMode := *watchFlag || *wFlag
if watchMode {
runWithWatcher(scriptPath)
} else {
runOnce(scriptPath)
}
}
func runOnce(scriptPath string) {
luaState, err := state.NewFromScript(scriptPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
defer luaState.Close()
if err := luaState.ExecuteFile(scriptPath); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
if http.HasActiveServers() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
fmt.Println("HTTP servers running. Press Ctrl+C to exit.")
go func() {
<-sigChan
fmt.Println("\nShutting down...")
// Close main state first (saves KV stores)
luaState.Close()
// Then stop servers (closes worker states)
http.StopAllServers()
sql.CloseAllConnections()
os.Exit(0)
}()
http.WaitForServers()
}
}
func runWithWatcher(scriptPath string) {
watcher, err := NewFileWatcher(500) // 500ms debounce
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create file watcher: %v\n", err)
os.Exit(1)
}
defer watcher.Close()
if err := watcher.DiscoverRequiredFiles(scriptPath); err != nil {
fmt.Fprintf(os.Stderr, "Failed to watch files: %v\n", err)
os.Exit(1)
}
restartCh := watcher.Start()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
fmt.Printf("Starting %s in watch mode...\n", scriptPath)
var hadError bool
firstRun := true
for {
// If we had an error on the last run and this isn't the first run,
// wait for file changes before retrying
if hadError && !firstRun {
fmt.Println("Waiting for file changes before retrying...")
select {
case <-restartCh:
fmt.Println("Files changed, retrying...")
case <-sigChan:
fmt.Println("\nExiting...")
return
}
}
firstRun = false
hadError = false
// Clear cache before each run
state.ClearCache()
// Create and run state
luaState, err := state.NewFromScript(scriptPath)
if err != nil {
log.Printf("Error creating state: %v", err)
hadError = true
continue
}
if err := luaState.ExecuteFile(scriptPath); err != nil {
log.Printf("Execution error: %v", err)
luaState.Close()
hadError = true
continue
}
// If not a long-running process, wait for changes and restart
if !http.HasActiveServers() {
fmt.Println("Script completed. Waiting for changes...")
luaState.Close()
select {
case <-restartCh:
fmt.Println("Files changed, restarting...")
continue
case <-sigChan:
fmt.Println("\nExiting...")
return
}
}
// Long-running process - wait for restart signal or exit signal
fmt.Println("HTTP servers running. Watching for file changes...")
select {
case <-restartCh:
fmt.Println("Files changed, restarting...")
http.StopAllServers()
luaState.Close()
sql.CloseAllConnections()
time.Sleep(100 * time.Millisecond) // Brief pause for cleanup
continue
case <-sigChan:
fmt.Println("\nShutting down...")
http.StopAllServers()
luaState.Close()
sql.CloseAllConnections()
return
}
}
}

View File

@ -1,90 +0,0 @@
package router
import (
"os"
"path/filepath"
"strings"
)
// buildRoutes scans the routes directory and builds the routing tree
func (r *LuaRouter) buildRoutes() error {
r.failedRoutes = make(map[string]*RouteError)
r.middlewareFiles = make(map[string][]string)
// First pass: collect all middleware files
err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
return err
}
if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
if err != nil {
return err
}
fsPath := "/"
if relDir != "." {
fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
}
// Use filesystem path for middleware (includes groups)
r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
}
return nil
})
if err != nil {
return err
}
// Second pass: build routes with combined middleware + handler
return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
return err
}
fileName := strings.TrimSuffix(info.Name(), ".lua")
// Skip middleware files (already processed)
if fileName == "middleware" {
return nil
}
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
if err != nil {
return err
}
fsPath := "/"
if relDir != "." {
fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
}
pathInfo := parsePathWithGroups(fsPath)
// Handle index.lua files
if fileName == "index" {
for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"} {
root := r.routes[method]
node := r.findOrCreateNode(root, pathInfo.urlPath)
node.indexFile = path
node.modTime = info.ModTime()
node.fsPath = pathInfo.fsPath
r.compileWithMiddleware(node, pathInfo.fsPath, path)
}
return nil
}
// Handle method files
method := strings.ToUpper(fileName)
root, exists := r.routes[method]
if !exists {
return nil
}
r.addRoute(root, pathInfo, path, info.ModTime())
return nil
})
}

View File

@ -1,62 +0,0 @@
package router
import (
"encoding/binary"
"hash/fnv"
"time"
"github.com/VictoriaMetrics/fastcache"
)
// hashString generates a hash for a string
func hashString(s string) uint64 {
h := fnv.New64a()
h.Write([]byte(s))
return h.Sum64()
}
// uint64ToBytes converts a uint64 to bytes for cache key
func uint64ToBytes(n uint64) []byte {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, n)
return b
}
// getCacheKey generates a cache key for a method and path
func getCacheKey(method, path string) []byte {
key := hashString(method + ":" + path)
return uint64ToBytes(key)
}
// getBytecodeKey generates a cache key for a handler path
func getBytecodeKey(handlerPath string) []byte {
key := hashString(handlerPath)
return uint64ToBytes(key)
}
// ClearCache clears all caches
func (r *LuaRouter) ClearCache() {
r.routeCache.Reset()
r.bytecodeCache.Reset()
r.middlewareCache = make(map[string][]byte)
r.sourceCache = make(map[string][]byte)
r.sourceMtimes = make(map[string]time.Time)
}
// GetCacheStats returns statistics about the cache
func (r *LuaRouter) GetCacheStats() map[string]any {
var routeStats fastcache.Stats
var bytecodeStats fastcache.Stats
r.routeCache.UpdateStats(&routeStats)
r.bytecodeCache.UpdateStats(&bytecodeStats)
return map[string]any{
"routeEntries": routeStats.EntriesCount,
"routeBytes": routeStats.BytesSize,
"routeCollisions": routeStats.Collisions,
"bytecodeEntries": bytecodeStats.EntriesCount,
"bytecodeBytes": bytecodeStats.BytesSize,
"bytecodeCollisions": bytecodeStats.Collisions,
}
}

View File

@ -1,176 +0,0 @@
package router
import (
"os"
"strings"
"time"
)
// compileWithMiddleware combines middleware and handler source, then compiles
func (r *LuaRouter) compileWithMiddleware(n *node, fsPath, scriptPath string) error {
if scriptPath == "" {
return nil
}
// Check if we need to recompile by comparing modification times
sourceKey := r.getSourceCacheKey(fsPath, scriptPath)
needsRecompile := false
// Check handler modification time
handlerInfo, err := os.Stat(scriptPath)
if err != nil {
n.err = err
return err
}
lastCompiled, exists := r.sourceMtimes[sourceKey]
if !exists || handlerInfo.ModTime().After(lastCompiled) {
needsRecompile = true
}
// Check middleware modification times
if !needsRecompile {
middlewareChain := r.getMiddlewareChain(fsPath)
for _, mwPath := range middlewareChain {
mwInfo, err := os.Stat(mwPath)
if err != nil {
n.err = err
return err
}
if mwInfo.ModTime().After(lastCompiled) {
needsRecompile = true
break
}
}
}
// Use cached bytecode if available and fresh
if !needsRecompile {
if bytecode, exists := r.sourceCache[sourceKey]; exists {
bytecodeKey := getBytecodeKey(scriptPath)
r.bytecodeCache.Set(bytecodeKey, bytecode)
return nil
}
}
// Build combined source
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
if err != nil {
n.err = err
return err
}
// Compile combined source using shared state
r.compileStateMu.Lock()
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
r.compileStateMu.Unlock()
if err != nil {
n.err = err
return err
}
// Cache everything
bytecodeKey := getBytecodeKey(scriptPath)
r.bytecodeCache.Set(bytecodeKey, bytecode)
r.sourceCache[sourceKey] = bytecode
r.sourceMtimes[sourceKey] = time.Now()
n.err = nil
return nil
}
// buildCombinedSource builds the combined middleware + handler source
func (r *LuaRouter) buildCombinedSource(fsPath, scriptPath string) (string, error) {
var combinedSource strings.Builder
// Get middleware chain using filesystem path
middlewareChain := r.getMiddlewareChain(fsPath)
// Add middleware in order
for _, mwPath := range middlewareChain {
content, err := r.getFileContent(mwPath)
if err != nil {
return "", err
}
combinedSource.WriteString("-- Middleware: ")
combinedSource.WriteString(mwPath)
combinedSource.WriteString("\n")
combinedSource.Write(content)
combinedSource.WriteString("\n")
}
// Add main handler
content, err := r.getFileContent(scriptPath)
if err != nil {
return "", err
}
combinedSource.WriteString("-- Handler: ")
combinedSource.WriteString(scriptPath)
combinedSource.WriteString("\n")
combinedSource.Write(content)
return combinedSource.String(), nil
}
// getFileContent reads file content with caching
func (r *LuaRouter) getFileContent(path string) ([]byte, error) {
// Check cache first
if content, exists := r.middlewareCache[path]; exists {
// Verify file hasn't changed
info, err := os.Stat(path)
if err == nil {
if cachedTime, exists := r.sourceMtimes[path]; exists && !info.ModTime().After(cachedTime) {
return content, nil
}
}
}
// Read from disk
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}
// Cache it
r.middlewareCache[path] = content
r.sourceMtimes[path] = time.Now()
return content, nil
}
// getSourceCacheKey generates a unique key for combined source
func (r *LuaRouter) getSourceCacheKey(fsPath, scriptPath string) string {
middlewareChain := r.getMiddlewareChain(fsPath)
var keyParts []string
keyParts = append(keyParts, middlewareChain...)
keyParts = append(keyParts, scriptPath)
return strings.Join(keyParts, "|")
}
// getMiddlewareChain returns middleware files that apply to the given filesystem path
func (r *LuaRouter) getMiddlewareChain(fsPath string) []string {
var chain []string
// Collect middleware from root to specific path using filesystem path (includes groups)
pathParts := strings.Split(strings.Trim(fsPath, "/"), "/")
if pathParts[0] == "" {
pathParts = []string{}
}
// Add root middleware
if mw, exists := r.middlewareFiles["/"]; exists {
chain = append(chain, mw...)
}
// Add middleware from each path level (including groups)
currentPath := ""
for _, part := range pathParts {
currentPath += "/" + part
if mw, exists := r.middlewareFiles[currentPath]; exists {
chain = append(chain, mw...)
}
}
return chain
}

View File

@ -1,25 +0,0 @@
package router
import "errors"
var (
// ErrRoutesCompilationErrors indicates that some routes failed to compile
// but the router is still operational
ErrRoutesCompilationErrors = errors.New("some routes failed to compile")
)
// RouteError represents an error with a specific route
type RouteError struct {
Path string // The URL path
Method string // HTTP method
ScriptPath string // Path to the Lua script
Err error // The actual error
}
// Error returns the error message
func (re *RouteError) Error() string {
if re.Err == nil {
return "unknown route error"
}
return re.Err.Error()
}

View File

@ -1,187 +0,0 @@
package router
import (
"os"
"strings"
)
// Match finds a handler for the given method and path (URL path, excludes groups)
func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
params.Count = 0
r.mu.RLock()
root, exists := r.routes[method]
r.mu.RUnlock()
if !exists {
return nil, false
}
segments := strings.Split(strings.Trim(path, "/"), "/")
return r.matchPath(root, segments, params, 0)
}
// matchPath recursively matches a path against the routing tree
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
// Filter empty segments
filteredSegments := segments[:0]
for _, segment := range segments {
if segment != "" {
filteredSegments = append(filteredSegments, segment)
}
}
segments = filteredSegments
if len(segments) == 0 {
if current.handler != "" {
return current, true
}
if current.indexFile != "" {
return current, true
}
return nil, false
}
segment := segments[0]
remaining := segments[1:]
// Try static child first
if child, exists := current.staticChild[segment]; exists {
if node, found := r.matchPath(child, remaining, params, depth+1); found {
return node, true
}
}
// Try parameter child
if current.paramChild != nil {
if params.Count < maxParams {
params.Keys[params.Count] = current.paramChild.paramName
params.Values[params.Count] = segment
params.Count++
}
if node, found := r.matchPath(current.paramChild, remaining, params, depth+1); found {
return node, true
}
params.Count--
}
// Fall back to index.lua
if current.indexFile != "" {
return current, true
}
return nil, false
}
// GetRouteInfo returns bytecode, script path, error, params, and found status
func (r *LuaRouter) GetRouteInfo(method, path []byte) ([]byte, string, error, *Params, bool) {
// Convert to string for internal processing
methodStr := string(method)
pathStr := string(path)
routeCacheKey := getCacheKey(methodStr, pathStr)
routeCacheData := r.routeCache.Get(nil, routeCacheKey)
// Fast path: found in cache
if len(routeCacheData) > 0 {
handlerPath := string(routeCacheData[8:])
bytecodeKey := routeCacheData[:8]
bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
n, exists := r.nodeForHandler(handlerPath)
if !exists {
r.routeCache.Del(routeCacheKey)
return nil, "", nil, nil, false
}
// Check if recompilation needed
if len(bytecode) > 0 {
// For cached routes, we need to re-match to get params
params := &Params{}
r.Match(methodStr, pathStr, params)
return bytecode, handlerPath, n.err, params, true
}
// Recompile if needed
fileInfo, err := os.Stat(handlerPath)
if err != nil || fileInfo.ModTime().After(n.modTime) {
scriptPath := n.handler
if scriptPath == "" {
scriptPath = n.indexFile
}
fsPath := n.fsPath
if fsPath == "" {
fsPath = "/"
}
if err := r.compileWithMiddleware(n, fsPath, scriptPath); err != nil {
params := &Params{}
r.Match(methodStr, pathStr, params)
return nil, handlerPath, n.err, params, true
}
newBytecodeKey := getBytecodeKey(handlerPath)
bytecode = r.bytecodeCache.Get(nil, newBytecodeKey)
newCacheData := make([]byte, 8+len(handlerPath))
copy(newCacheData[:8], newBytecodeKey)
copy(newCacheData[8:], handlerPath)
r.routeCache.Set(routeCacheKey, newCacheData)
params := &Params{}
r.Match(methodStr, pathStr, params)
return bytecode, handlerPath, n.err, params, true
}
params := &Params{}
r.Match(methodStr, pathStr, params)
return bytecode, handlerPath, n.err, params, true
}
// Slow path: lookup and compile
params := &Params{}
node, found := r.Match(methodStr, pathStr, params)
if !found {
return nil, "", nil, nil, false
}
scriptPath := node.handler
if scriptPath == "" && node.indexFile != "" {
scriptPath = node.indexFile
}
if scriptPath == "" {
return nil, "", nil, nil, false
}
bytecodeKey := getBytecodeKey(scriptPath)
bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
if len(bytecode) == 0 {
fsPath := node.fsPath
if fsPath == "" {
fsPath = "/"
}
if err := r.compileWithMiddleware(node, fsPath, scriptPath); err != nil {
return nil, scriptPath, node.err, params, true
}
bytecode = r.bytecodeCache.Get(nil, bytecodeKey)
}
// Cache the route
cacheData := make([]byte, 8+len(scriptPath))
copy(cacheData[:8], bytecodeKey)
copy(cacheData[8:], scriptPath)
r.routeCache.Set(routeCacheKey, cacheData)
return bytecode, scriptPath, node.err, params, true
}
// GetRouteInfoString is a convenience method that accepts strings
func (r *LuaRouter) GetRouteInfoString(method, path string) ([]byte, string, error, *Params, bool) {
return r.GetRouteInfo([]byte(method), []byte(path))
}

View File

@ -1,190 +0,0 @@
package router
import (
"path/filepath"
"strings"
"time"
)
// node represents a node in the radix trie
type node struct {
// Static children mapped by path segment
staticChild map[string]*node
// Parameter child for dynamic segments (e.g., :id)
paramChild *node
paramName string
// Handler information
handler string // Path to the handler file
indexFile string // Path to index.lua if exists
modTime time.Time // Modification time of the handler
fsPath string // Filesystem path (includes groups)
// Compilation error if any
err error
}
// pathInfo holds both URL path and filesystem path
type pathInfo struct {
urlPath string // URL path without groups (e.g., /users)
fsPath string // Filesystem path with groups (e.g., /(admin)/users)
}
// parsePathWithGroups parses a filesystem path, extracting groups
func parsePathWithGroups(fsPath string) *pathInfo {
segments := strings.Split(strings.Trim(fsPath, "/"), "/")
var urlSegments []string
for _, segment := range segments {
if segment == "" {
continue
}
// Skip group segments (enclosed in parentheses)
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
continue
}
urlSegments = append(urlSegments, segment)
}
urlPath := "/"
if len(urlSegments) > 0 {
urlPath = "/" + strings.Join(urlSegments, "/")
}
return &pathInfo{
urlPath: urlPath,
fsPath: fsPath,
}
}
// findOrCreateNode finds or creates a node at the given URL path
func (r *LuaRouter) findOrCreateNode(root *node, urlPath string) *node {
segments := strings.Split(strings.Trim(urlPath, "/"), "/")
if len(segments) == 1 && segments[0] == "" {
return root
}
current := root
for _, segment := range segments {
if segment == "" {
continue
}
// Check if it's a parameter
if strings.HasPrefix(segment, ":") {
paramName := segment[1:]
if current.paramChild == nil {
current.paramChild = &node{
staticChild: make(map[string]*node),
paramName: paramName,
}
}
current = current.paramChild
} else {
// Static segment
if _, exists := current.staticChild[segment]; !exists {
current.staticChild[segment] = &node{
staticChild: make(map[string]*node),
}
}
current = current.staticChild[segment]
}
}
return current
}
// addRoute adds a route to the tree
func (r *LuaRouter) addRoute(root *node, pathInfo *pathInfo, handlerPath string, modTime time.Time) {
node := r.findOrCreateNode(root, pathInfo.urlPath)
node.handler = handlerPath
node.modTime = modTime
node.fsPath = pathInfo.fsPath
// Compile the route with middleware
r.compileWithMiddleware(node, pathInfo.fsPath, handlerPath)
// Track failed routes
if node.err != nil {
key := filepath.Base(handlerPath) + ":" + pathInfo.urlPath
r.failedRoutes[key] = &RouteError{
Path: pathInfo.urlPath,
Method: strings.ToUpper(strings.TrimSuffix(filepath.Base(handlerPath), ".lua")),
ScriptPath: handlerPath,
Err: node.err,
}
}
}
// nodeForHandler finds a node by its handler path
func (r *LuaRouter) nodeForHandler(handlerPath string) (*node, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, root := range r.routes {
if node := findNodeByHandler(root, handlerPath); node != nil {
return node, true
}
}
return nil, false
}
// findNodeByHandler recursively searches for a node with the given handler path
func findNodeByHandler(n *node, handlerPath string) *node {
if n.handler == handlerPath || n.indexFile == handlerPath {
return n
}
// Search static children
for _, child := range n.staticChild {
if found := findNodeByHandler(child, handlerPath); found != nil {
return found
}
}
// Search param child
if n.paramChild != nil {
if found := findNodeByHandler(n.paramChild, handlerPath); found != nil {
return found
}
}
return nil
}
// countNodesAndBytecode counts nodes and bytecode size in the tree
func countNodesAndBytecode(n *node) (int, int64) {
if n == nil {
return 0, 0
}
count := 0
bytes := int64(0)
// Count this node if it has a handler
if n.handler != "" || n.indexFile != "" {
count = 1
// Estimate bytecode size (would need actual bytecode cache lookup for accuracy)
bytes = 1024 // Placeholder
}
// Count static children
for _, child := range n.staticChild {
c, b := countNodesAndBytecode(child)
count += c
bytes += b
}
// Count param child
if n.paramChild != nil {
c, b := countNodesAndBytecode(n.paramChild)
count += c
bytes += b
}
return count, bytes
}

View File

@ -1,44 +0,0 @@
package router
// Maximum number of URL parameters per route
const maxParams = 20
// Params holds URL parameters with fixed-size arrays to avoid allocations
type Params struct {
Keys [maxParams]string
Values [maxParams]string
Count int
}
// Get returns a parameter value by name
func (p *Params) Get(name string) string {
for i := range p.Count {
if p.Keys[i] == name {
return p.Values[i]
}
}
return ""
}
// Reset clears all parameters
func (p *Params) Reset() {
p.Count = 0
}
// Set adds or updates a parameter
func (p *Params) Set(name, value string) {
// Try to update existing
for i := range p.Count {
if p.Keys[i] == name {
p.Values[i] = value
return
}
}
// Add new if space available
if p.Count < maxParams {
p.Keys[p.Count] = name
p.Values[p.Count] = value
p.Count++
}
}

View File

@ -1,156 +0,0 @@
package router
import (
"errors"
"os"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/VictoriaMetrics/fastcache"
)
// Default cache sizes
const (
defaultBytecodeMaxBytes = 32 * 1024 * 1024 // 32MB for bytecode cache
defaultRouteMaxBytes = 8 * 1024 * 1024 // 8MB for route match cache
)
// LuaRouter is a filesystem-based HTTP router for Lua files
type LuaRouter struct {
routesDir string // Root directory containing route files
routes map[string]*node // Method -> route tree
failedRoutes map[string]*RouteError // Track failed routes
mu sync.RWMutex // Lock for concurrent access to routes
routeCache *fastcache.Cache // Cache for route lookups
bytecodeCache *fastcache.Cache // Cache for compiled bytecode
// Middleware tracking for path hierarchy
middlewareFiles map[string][]string // filesystem path -> middleware file paths
// Caching fields
middlewareCache map[string][]byte // path -> content
sourceCache map[string][]byte // combined source cache key -> compiled bytecode
sourceMtimes map[string]time.Time // track modification times
// Shared Lua state for compilation
compileState *luajit.State
compileStateMu sync.Mutex // Protect concurrent access to Lua state
}
// NewLuaRouter creates a new LuaRouter instance
func NewLuaRouter(routesDir string) (*LuaRouter, error) {
info, err := os.Stat(routesDir)
if err != nil {
return nil, err
}
if !info.IsDir() {
return nil, errors.New("routes path is not a directory")
}
// Create shared Lua state
compileState := luajit.New()
if compileState == nil {
return nil, errors.New("failed to create Lua compile state")
}
r := &LuaRouter{
routesDir: routesDir,
routes: make(map[string]*node),
failedRoutes: make(map[string]*RouteError),
middlewareFiles: make(map[string][]string),
routeCache: fastcache.New(defaultRouteMaxBytes),
bytecodeCache: fastcache.New(defaultBytecodeMaxBytes),
middlewareCache: make(map[string][]byte),
sourceCache: make(map[string][]byte),
sourceMtimes: make(map[string]time.Time),
compileState: compileState,
}
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
for _, method := range methods {
r.routes[method] = &node{
staticChild: make(map[string]*node),
}
}
err = r.buildRoutes()
if len(r.failedRoutes) > 0 {
return r, ErrRoutesCompilationErrors
}
return r, err
}
// Refresh rebuilds the router by rescanning the routes directory
func (r *LuaRouter) Refresh() error {
r.mu.Lock()
defer r.mu.Unlock()
for method := range r.routes {
r.routes[method] = &node{
staticChild: make(map[string]*node),
}
}
r.failedRoutes = make(map[string]*RouteError)
r.middlewareFiles = make(map[string][]string)
r.middlewareCache = make(map[string][]byte)
r.sourceCache = make(map[string][]byte)
r.sourceMtimes = make(map[string]time.Time)
err := r.buildRoutes()
if len(r.failedRoutes) > 0 {
return ErrRoutesCompilationErrors
}
return err
}
// ReportFailedRoutes returns a list of routes that failed to compile
func (r *LuaRouter) ReportFailedRoutes() []*RouteError {
r.mu.RLock()
defer r.mu.RUnlock()
result := make([]*RouteError, 0, len(r.failedRoutes))
for _, re := range r.failedRoutes {
result = append(result, re)
}
return result
}
// Close cleans up the router and its resources
func (r *LuaRouter) Close() {
r.compileStateMu.Lock()
if r.compileState != nil {
r.compileState.Close()
r.compileState = nil
}
r.compileStateMu.Unlock()
}
// GetRouteStats returns statistics about the router
func (r *LuaRouter) GetRouteStats() (int, int64) {
r.mu.RLock()
defer r.mu.RUnlock()
routeCount := 0
bytecodeBytes := int64(0)
for _, root := range r.routes {
count, bytes := countNodesAndBytecode(root)
routeCount += count
bytecodeBytes += bytes
}
return routeCount, bytecodeBytes
}
type NodeWithError struct {
ScriptPath string
Error error
}

View File

@ -1,117 +0,0 @@
package runner
import (
"sync"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
// Context represents execution context for a Lua script
type Context struct {
// Values stores any context values (route params, HTTP request info, etc.)
Values map[string]any
// FastHTTP context if this was created from an HTTP request
RequestCtx *fasthttp.RequestCtx
// Buffer for efficient string operations
buffer *bytebufferpool.ByteBuffer
}
// Context pool to reduce allocations
var contextPool = sync.Pool{
New: func() any {
return &Context{
Values: make(map[string]any, 32),
}
},
}
// NewContext creates a new context, potentially reusing one from the pool
func NewContext() *Context {
ctx := contextPool.Get().(*Context)
return ctx
}
// NewHTTPContext creates a new context from a fasthttp RequestCtx
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
ctx := NewContext()
ctx.RequestCtx = requestCtx
// Extract common HTTP values that Lua might need
if requestCtx != nil {
ctx.Values["_request_method"] = string(requestCtx.Method())
ctx.Values["_request_path"] = string(requestCtx.Path())
ctx.Values["_request_url"] = string(requestCtx.RequestURI())
// Extract cookies
cookies := make(map[string]any)
requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
cookies[string(key)] = string(value)
})
ctx.Values["_request_cookies"] = cookies
// Extract query params
query := make(map[string]any)
requestCtx.QueryArgs().VisitAll(func(key, value []byte) {
query[string(key)] = string(value)
})
ctx.Values["_request_query"] = query
// Extract form data if present
if requestCtx.IsPost() || requestCtx.IsPut() {
form := make(map[string]any)
requestCtx.PostArgs().VisitAll(func(key, value []byte) {
form[string(key)] = string(value)
})
ctx.Values["_request_form"] = form
}
// Extract headers
headers := make(map[string]any)
requestCtx.Request.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
ctx.Values["_request_headers"] = headers
}
return ctx
}
// Release returns the context to the pool after clearing its values
func (c *Context) Release() {
// Clear all values to prevent data leakage
for k := range c.Values {
delete(c.Values, k)
}
// Reset request context
c.RequestCtx = nil
// Return buffer to pool if we have one
if c.buffer != nil {
bytebufferpool.Put(c.buffer)
c.buffer = nil
}
contextPool.Put(c)
}
// GetBuffer returns a byte buffer for efficient string operations
func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
if c.buffer == nil {
c.buffer = bytebufferpool.Get()
}
return c.buffer
}
// Set adds a value to the context
func (c *Context) Set(key string, value any) {
c.Values[key] = value
}
// Get retrieves a value from the context
func (c *Context) Get(key string) any {
return c.Values[key]
}

View File

@ -1,404 +0,0 @@
package runner
import (
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/binary"
"encoding/hex"
"fmt"
"hash"
"math"
mrand "math/rand/v2"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
var (
// Map to store state-specific RNGs
stateRngs = make(map[*luajit.State]*mrand.PCG)
stateRngsMu sync.Mutex
)
// RegisterCryptoFunctions registers all crypto functions with the Lua state
func RegisterCryptoFunctions(state *luajit.State) error {
// Create a state-specific RNG
stateRngsMu.Lock()
stateRngs[state] = mrand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()>>32))
stateRngsMu.Unlock()
// Register hash functions
if err := state.RegisterGoFunction("__crypto_hash", cryptoHash); err != nil {
return err
}
// Register HMAC functions
if err := state.RegisterGoFunction("__crypto_hmac", cryptoHmac); err != nil {
return err
}
// Register UUID generation
if err := state.RegisterGoFunction("__crypto_uuid", cryptoUuid); err != nil {
return err
}
// Register random functions
if err := state.RegisterGoFunction("__crypto_random", cryptoRandom); err != nil {
return err
}
if err := state.RegisterGoFunction("__crypto_random_bytes", cryptoRandomBytes); err != nil {
return err
}
if err := state.RegisterGoFunction("__crypto_random_int", cryptoRandomInt); err != nil {
return err
}
if err := state.RegisterGoFunction("__crypto_random_seed", cryptoRandomSeed); err != nil {
return err
}
// Override Lua's math.random
if err := OverrideLuaRandom(state); err != nil {
return err
}
return nil
}
// CleanupCrypto cleans up resources when a state is closed
func CleanupCrypto(state *luajit.State) {
stateRngsMu.Lock()
delete(stateRngs, state)
stateRngsMu.Unlock()
}
// cryptoHash generates hash digests using various algorithms
func cryptoHash(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) {
state.PushString("hash: expected (string data, string algorithm)")
return 1
}
data := state.ToString(1)
algorithm := state.ToString(2)
var h hash.Hash
switch algorithm {
case "md5":
h = md5.New()
case "sha1":
h = sha1.New()
case "sha256":
h = sha256.New()
case "sha512":
h = sha512.New()
default:
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
return 1
}
h.Write([]byte(data))
hashBytes := h.Sum(nil)
// Output format
outputFormat := "hex"
if state.GetTop() >= 3 && state.IsString(3) {
outputFormat = state.ToString(3)
}
switch outputFormat {
case "hex":
state.PushString(hex.EncodeToString(hashBytes))
case "binary":
state.PushString(string(hashBytes))
default:
state.PushString(hex.EncodeToString(hashBytes))
}
return 1
}
// cryptoHmac generates HMAC using various hash algorithms
func cryptoHmac(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) || !state.IsString(3) {
state.PushString("hmac: expected (string data, string key, string algorithm)")
return 1
}
data := state.ToString(1)
key := state.ToString(2)
algorithm := state.ToString(3)
var h func() hash.Hash
switch algorithm {
case "md5":
h = md5.New
case "sha1":
h = sha1.New
case "sha256":
h = sha256.New
case "sha512":
h = sha512.New
default:
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
return 1
}
mac := hmac.New(h, []byte(key))
mac.Write([]byte(data))
macBytes := mac.Sum(nil)
// Output format
outputFormat := "hex"
if state.GetTop() >= 4 && state.IsString(4) {
outputFormat = state.ToString(4)
}
switch outputFormat {
case "hex":
state.PushString(hex.EncodeToString(macBytes))
case "binary":
state.PushString(string(macBytes))
default:
state.PushString(hex.EncodeToString(macBytes))
}
return 1
}
// cryptoUuid generates a random UUID v4
func cryptoUuid(state *luajit.State) int {
uuid := make([]byte, 16)
_, err := rand.Read(uuid)
if err != nil {
state.PushString(fmt.Sprintf("uuid: generation error: %v", err))
return 1
}
// Set version (4) and variant (RFC 4122)
uuid[6] = (uuid[6] & 0x0F) | 0x40
uuid[8] = (uuid[8] & 0x3F) | 0x80
uuidStr := fmt.Sprintf("%x-%x-%x-%x-%x",
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
state.PushString(uuidStr)
return 1
}
// cryptoRandomBytes generates random bytes
func cryptoRandomBytes(state *luajit.State) int {
if !state.IsNumber(1) {
state.PushString("random_bytes: expected (number length)")
return 1
}
length := int(state.ToNumber(1))
if length <= 0 {
state.PushString("random_bytes: length must be positive")
return 1
}
// Check if secure
secure := true
if state.GetTop() >= 2 && state.IsBoolean(2) {
secure = state.ToBoolean(2)
}
bytes := make([]byte, length)
if secure {
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random_bytes: error: %v", err))
return 1
}
} else {
stateRngsMu.Lock()
stateRng, ok := stateRngs[state]
stateRngsMu.Unlock()
if !ok {
state.PushString("random_bytes: RNG not initialized")
return 1
}
for i := range bytes {
bytes[i] = byte(stateRng.Uint64() & 0xFF)
}
}
// Output format
outputFormat := "binary"
if state.GetTop() >= 3 && state.IsString(3) {
outputFormat = state.ToString(3)
}
switch outputFormat {
case "binary":
state.PushString(string(bytes))
case "hex":
state.PushString(hex.EncodeToString(bytes))
default:
state.PushString(string(bytes))
}
return 1
}
// cryptoRandomInt generates a random integer in range [min, max]
func cryptoRandomInt(state *luajit.State) int {
if !state.IsNumber(1) || !state.IsNumber(2) {
state.PushString("random_int: expected (number min, number max)")
return 1
}
min := int64(state.ToNumber(1))
max := int64(state.ToNumber(2))
if max <= min {
state.PushString("random_int: max must be greater than min")
return 1
}
// Check if secure
secure := true
if state.GetTop() >= 3 && state.IsBoolean(3) {
secure = state.ToBoolean(3)
}
range_size := max - min + 1
var result int64
if secure {
bytes := make([]byte, 8)
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random_int: error: %v", err))
return 1
}
val := binary.BigEndian.Uint64(bytes)
result = min + int64(val%uint64(range_size))
} else {
stateRngsMu.Lock()
stateRng, ok := stateRngs[state]
stateRngsMu.Unlock()
if !ok {
state.PushString("random_int: RNG not initialized")
return 1
}
result = min + int64(stateRng.Uint64()%uint64(range_size))
}
state.PushNumber(float64(result))
return 1
}
// cryptoRandom implements math.random functionality
func cryptoRandom(state *luajit.State) int {
numArgs := state.GetTop()
// Check if secure
secure := false
// math.random() - return [0,1)
if numArgs == 0 {
if secure {
bytes := make([]byte, 8)
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random: error: %v", err))
return 1
}
val := binary.BigEndian.Uint64(bytes)
state.PushNumber(float64(val) / float64(math.MaxUint64))
} else {
stateRngsMu.Lock()
stateRng, ok := stateRngs[state]
stateRngsMu.Unlock()
if !ok {
state.PushString("random: RNG not initialized")
return 1
}
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
}
return 1
}
// math.random(n) - return integer [1,n]
if numArgs == 1 && state.IsNumber(1) {
n := int64(state.ToNumber(1))
if n < 1 {
state.PushString("random: upper bound must be >= 1")
return 1
}
state.PushNumber(1) // min
state.PushNumber(float64(n)) // max
state.PushBoolean(secure) // secure flag
return cryptoRandomInt(state)
}
// math.random(m, n) - return integer [m,n]
if numArgs >= 2 && state.IsNumber(1) && state.IsNumber(2) {
state.PushBoolean(secure) // secure flag
return cryptoRandomInt(state)
}
state.PushString("random: invalid arguments")
return 1
}
// cryptoRandomSeed sets seed for non-secure RNG
func cryptoRandomSeed(state *luajit.State) int {
if !state.IsNumber(1) {
state.PushString("randomseed: expected (number seed)")
return 1
}
seed := uint64(state.ToNumber(1))
stateRngsMu.Lock()
stateRngs[state] = mrand.NewPCG(seed, seed>>32)
stateRngsMu.Unlock()
return 0
}
// OverrideLuaRandom replaces Lua's math.random with Go implementation
func OverrideLuaRandom(state *luajit.State) error {
if err := state.RegisterGoFunction("go_math_random", cryptoRandom); err != nil {
return err
}
if err := state.RegisterGoFunction("go_math_randomseed", cryptoRandomSeed); err != nil {
return err
}
// Replace original functions
return state.DoString(`
-- Save original functions
_G._original_math_random = math.random
_G._original_math_randomseed = math.randomseed
-- Replace with Go implementations
math.random = go_math_random
math.randomseed = go_math_randomseed
-- Clean up global namespace
go_math_random = nil
go_math_randomseed = nil
`)
}

View File

@ -1,166 +0,0 @@
package runner
import (
_ "embed"
"sync"
"sync/atomic"
"Moonshark/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
//go:embed lua/sandbox.lua
var sandboxLuaCode string
//go:embed lua/json.lua
var jsonLuaCode string
//go:embed lua/sqlite.lua
var sqliteLuaCode string
//go:embed lua/fs.lua
var fsLuaCode string
//go:embed lua/util.lua
var utilLuaCode string
//go:embed lua/string.lua
var stringLuaCode string
//go:embed lua/table.lua
var tableLuaCode string
//go:embed lua/crypto.lua
var cryptoLuaCode string
//go:embed lua/time.lua
var timeLuaCode string
//go:embed lua/math.lua
var mathLuaCode string
//go:embed lua/env.lua
var envLuaCode string
// ModuleInfo holds information about an embeddable Lua module
type ModuleInfo struct {
Name string // Module name
Code string // Module source code
Bytecode atomic.Pointer[[]byte] // Cached bytecode
Once sync.Once // For one-time compilation
DefinesGlobal bool // Whether module defines globals directly
}
var (
sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode}
modules = []ModuleInfo{
{Name: "json", Code: jsonLuaCode, DefinesGlobal: true},
{Name: "sqlite", Code: sqliteLuaCode},
{Name: "fs", Code: fsLuaCode, DefinesGlobal: true},
{Name: "util", Code: utilLuaCode, DefinesGlobal: true},
{Name: "string", Code: stringLuaCode},
{Name: "table", Code: tableLuaCode},
{Name: "crypto", Code: cryptoLuaCode, DefinesGlobal: true},
{Name: "time", Code: timeLuaCode},
{Name: "math", Code: mathLuaCode},
{Name: "env", Code: envLuaCode, DefinesGlobal: true},
}
)
// precompileModule compiles a module's code to bytecode once
func precompileModule(m *ModuleInfo) {
m.Once.Do(func() {
tempState := luajit.New()
if tempState == nil {
logger.Fatal("Failed to create temp Lua state for %s module compilation", m.Name)
return
}
defer tempState.Close()
defer tempState.Cleanup()
code, err := tempState.CompileBytecode(m.Code, m.Name+".lua")
if err != nil {
logger.Error("Failed to compile %s module: %v", m.Name, err)
return
}
bytecode := make([]byte, len(code))
copy(bytecode, code)
m.Bytecode.Store(&bytecode)
logger.Debug("Successfully precompiled %s.lua to bytecode (%d bytes)", m.Name, len(code))
})
}
// loadModule loads a module into a Lua state
func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error {
// Ensure bytecode is compiled
precompileModule(m)
// Attempt to load from bytecode
bytecode := m.Bytecode.Load()
if bytecode != nil && len(*bytecode) > 0 {
if verbose {
logger.Debug("Loading %s.lua from precompiled bytecode", m.Name)
}
if err := state.LoadBytecode(*bytecode, m.Name+".lua"); err != nil {
return err
}
if m.DefinesGlobal {
// Module defines its own globals, just run it
if err := state.RunBytecode(); err != nil {
return err
}
} else {
// Module returns a table, capture and set as global
if err := state.RunBytecodeWithResults(1); err != nil {
return err
}
state.SetGlobal(m.Name)
}
} else {
// Fallback to interpreting the source
if verbose {
logger.Warning("Using non-precompiled %s.lua", m.Name)
}
if err := state.DoString(m.Code); err != nil {
return err
}
}
return nil
}
// loadSandboxIntoState loads all modules and sandbox into a Lua state
func loadSandboxIntoState(state *luajit.State, verbose bool) error {
// Load all modules first
for i := range modules {
if err := loadModule(state, &modules[i], verbose); err != nil {
return err
}
}
// Initialize active connections tracking (specific to SQLite)
if err := state.DoString(`__active_sqlite_connections = {}`); err != nil {
return err
}
// Load the sandbox last
precompileModule(&sandbox)
bytecode := sandbox.Bytecode.Load()
if bytecode != nil && len(*bytecode) > 0 {
if verbose {
logger.Debug("Loading sandbox.lua from precompiled bytecode")
}
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
}
if verbose {
logger.Warning("Using non-precompiled sandbox.lua")
}
return state.DoString(sandboxLuaCode)
}

View File

@ -1,278 +0,0 @@
package runner
import (
"bufio"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"Moonshark/utils/color"
"Moonshark/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// EnvManager handles loading, storing, and saving environment variables
type EnvManager struct {
envPath string // Path to .env file
vars map[string]any // Environment variables in memory
mu sync.RWMutex // Thread-safe access
}
// Global environment manager instance
var globalEnvManager *EnvManager
// InitEnv initializes the environment manager with the given data directory
func InitEnv(dataDir string) error {
if dataDir == "" {
return fmt.Errorf("data directory cannot be empty")
}
// Create data directory if it doesn't exist
if err := os.MkdirAll(dataDir, 0755); err != nil {
return fmt.Errorf("failed to create data directory: %w", err)
}
envPath := filepath.Join(dataDir, ".env")
globalEnvManager = &EnvManager{
envPath: envPath,
vars: make(map[string]any),
}
// Load existing .env file if it exists
if err := globalEnvManager.load(); err != nil {
logger.Warning("Failed to load .env file: %v", err)
}
count := len(globalEnvManager.vars)
if count > 0 {
logger.Info("Environment loaded: %s vars from %s",
color.Apply(fmt.Sprintf("%d", count), color.Yellow),
color.Apply(envPath, color.Yellow))
} else {
logger.Info("Environment initialized: %s", color.Apply(envPath, color.Yellow))
}
return nil
}
// GetGlobalEnvManager returns the global environment manager instance
func GetGlobalEnvManager() *EnvManager {
return globalEnvManager
}
// load reads the .env file and populates the vars map
func (e *EnvManager) load() error {
file, err := os.Open(e.envPath)
if os.IsNotExist(err) {
// File doesn't exist, start with empty env
return nil
}
if err != nil {
return fmt.Errorf("failed to open .env file: %w", err)
}
defer file.Close()
e.mu.Lock()
defer e.mu.Unlock()
scanner := bufio.NewScanner(file)
lineNum := 0
for scanner.Scan() {
lineNum++
line := strings.TrimSpace(scanner.Text())
// Skip empty lines and comments
if line == "" || strings.HasPrefix(line, "#") {
continue
}
// Parse key=value
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
logger.Warning("Invalid .env line %d: %s", lineNum, line)
continue
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
// Remove quotes if present
if len(value) >= 2 {
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
value = value[1 : len(value)-1]
}
}
e.vars[key] = value
}
return scanner.Err()
}
// Save writes the current environment variables to the .env file
func (e *EnvManager) Save() error {
if e == nil {
return nil // No env manager initialized
}
e.mu.RLock()
defer e.mu.RUnlock()
file, err := os.Create(e.envPath)
if err != nil {
return fmt.Errorf("failed to create .env file: %w", err)
}
defer file.Close()
// Sort keys for consistent output
keys := make([]string, 0, len(e.vars))
for key := range e.vars {
keys = append(keys, key)
}
sort.Strings(keys)
// Write header comment
fmt.Fprintln(file, "# Environment variables for Moonshark")
fmt.Fprintln(file, "# Generated automatically - you can edit this file")
fmt.Fprintln(file)
// Write each variable
for _, key := range keys {
value := e.vars[key]
// Convert value to string
var strValue string
switch v := value.(type) {
case string:
strValue = v
case nil:
continue // Skip nil values
default:
strValue = fmt.Sprintf("%v", v)
}
// Quote values that contain spaces or special characters
if strings.ContainsAny(strValue, " \t\n\r\"'\\") {
strValue = fmt.Sprintf("\"%s\"", strings.ReplaceAll(strValue, "\"", "\\\""))
}
fmt.Fprintf(file, "%s=%s\n", key, strValue)
}
logger.Debug("Environment saved: %d vars to %s", len(e.vars), e.envPath)
return nil
}
// Get retrieves an environment variable
func (e *EnvManager) Get(key string) (any, bool) {
if e == nil {
return nil, false
}
e.mu.RLock()
defer e.mu.RUnlock()
value, exists := e.vars[key]
return value, exists
}
// Set stores an environment variable
func (e *EnvManager) Set(key string, value any) {
if e == nil {
return
}
e.mu.Lock()
defer e.mu.Unlock()
e.vars[key] = value
}
// GetAll returns a copy of all environment variables
func (e *EnvManager) GetAll() map[string]any {
if e == nil {
return make(map[string]any)
}
e.mu.RLock()
defer e.mu.RUnlock()
result := make(map[string]any, len(e.vars))
for k, v := range e.vars {
result[k] = v
}
return result
}
// CleanupEnv saves the environment and cleans up resources
func CleanupEnv() error {
if globalEnvManager != nil {
return globalEnvManager.Save()
}
return nil
}
// envGet Lua function to get an environment variable
func envGet(state *luajit.State) int {
if !state.IsString(1) {
state.PushNil()
return 1
}
key := state.ToString(1)
if value, exists := globalEnvManager.Get(key); exists {
if err := state.PushValue(value); err != nil {
state.PushNil()
}
} else {
state.PushNil()
}
return 1
}
// envSet Lua function to set an environment variable
func envSet(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) {
state.PushBoolean(false)
return 1
}
key := state.ToString(1)
value := state.ToString(2)
globalEnvManager.Set(key, value)
state.PushBoolean(true)
return 1
}
// envGetAll Lua function to get all environment variables
func envGetAll(state *luajit.State) int {
vars := globalEnvManager.GetAll()
if err := state.PushTable(vars); err != nil {
state.PushNil()
}
return 1
}
// RegisterEnvFunctions registers environment functions with the Lua state
func RegisterEnvFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__env_get", envGet); err != nil {
return err
}
if err := state.RegisterGoFunction("__env_set", envSet); err != nil {
return err
}
if err := state.RegisterGoFunction("__env_get_all", envGetAll); err != nil {
return err
}
return nil
}

View File

@ -1,579 +0,0 @@
package runner
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"Moonshark/utils/color"
"Moonshark/utils/logger"
lru "git.sharkk.net/Go/LRU"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/golang/snappy"
)
// Global filesystem path (set during initialization)
var fsBasePath string
// Global file cache with compressed data
var fileCache *lru.LRUCache
// Cache entry info for statistics/debugging
type cacheStats struct {
hits int64
misses int64
}
var stats cacheStats
// InitFS initializes the filesystem with the given base path
func InitFS(basePath string) error {
if basePath == "" {
return errors.New("filesystem base path cannot be empty")
}
// Create the directory if it doesn't exist
if err := os.MkdirAll(basePath, 0755); err != nil {
return fmt.Errorf("failed to create filesystem directory: %w", err)
}
// Store the absolute path
absPath, err := filepath.Abs(basePath)
if err != nil {
return fmt.Errorf("failed to get absolute path: %w", err)
}
fsBasePath = absPath
// Initialize file cache with 2000 entries (reasonable for most use cases)
fileCache = lru.NewLRUCache(2000)
logger.Info("Filesystem is g2g! %s", color.Apply(fsBasePath, color.Yellow))
return nil
}
// CleanupFS performs any necessary cleanup
func CleanupFS() {
if fileCache != nil {
fileCache.Clear()
logger.Info(
"File cache cleared - %s hits, %s misses",
color.Apply(fmt.Sprintf("%d", stats.hits), color.Yellow),
color.Apply(fmt.Sprintf("%d", stats.misses), color.Red),
)
}
}
// ResolvePath resolves a given path relative to the filesystem base
// Returns the actual path and an error if the path tries to escape the sandbox
func ResolvePath(path string) (string, error) {
if fsBasePath == "" {
return "", errors.New("filesystem not initialized")
}
// Clean the path to remove any .. or . components
cleanPath := filepath.Clean(path)
// Replace backslashes with forward slashes for consistent handling
cleanPath = strings.ReplaceAll(cleanPath, "\\", "/")
// Remove any leading / or drive letter to make it relative
cleanPath = strings.TrimPrefix(cleanPath, "/")
// Remove drive letter on Windows (e.g. C:)
if len(cleanPath) >= 2 && cleanPath[1] == ':' {
cleanPath = cleanPath[2:]
}
// Ensure the path doesn't contain .. to prevent escaping
if strings.Contains(cleanPath, "..") {
return "", errors.New("path cannot contain .. components")
}
// Join with the base path
fullPath := filepath.Join(fsBasePath, cleanPath)
// Verify the path is still within the base directory
if !strings.HasPrefix(fullPath, fsBasePath) {
return "", errors.New("path escapes the filesystem sandbox")
}
return fullPath, nil
}
// getCacheKey creates a cache key from path and modification time
func getCacheKey(fullPath string, modTime time.Time) string {
return fmt.Sprintf("%s:%d", fullPath, modTime.Unix())
}
// fsReadFile reads a file and returns its contents
func fsReadFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.read_file: path must be a string")
return -1
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.read_file: " + err.Error())
return -1
}
// Get file info for cache key and validation
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.read_file: " + err.Error())
return -1
}
// Create cache key with path and modification time
cacheKey := getCacheKey(fullPath, info.ModTime())
// Try to get from cache first
if fileCache != nil {
if cachedData, exists := fileCache.Get(cacheKey); exists {
if compressedData, ok := cachedData.([]byte); ok {
// Decompress cached data
data, err := snappy.Decode(nil, compressedData)
if err == nil {
stats.hits++
state.PushString(string(data))
return 1
}
// Cache corruption - continue to disk read
}
}
}
// Cache miss or error - read from disk
stats.misses++
data, err := os.ReadFile(fullPath)
if err != nil {
state.PushString("fs.read_file: " + err.Error())
return -1
}
// Compress and cache the data
if fileCache != nil {
compressedData := snappy.Encode(nil, data)
fileCache.Put(cacheKey, compressedData)
}
state.PushString(string(data))
return 1
}
// fsWriteFile writes data to a file
func fsWriteFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.write_file: path must be a string")
return -1
}
path := state.ToString(1)
if !state.IsString(2) {
state.PushString("fs.write_file: content must be a string")
return -1
}
content := state.ToString(2)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.write_file: " + err.Error())
return -1
}
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
state.PushString("fs.write_file: failed to create directory: " + err.Error())
return -1
}
err = os.WriteFile(fullPath, []byte(content), 0644)
if err != nil {
state.PushString("fs.write_file: " + err.Error())
return -1
}
// Invalidate cache entries for this file path
if fileCache != nil {
// We can't easily iterate through cache keys, so we'll let the cache
// naturally expire old entries when the file is read again
}
state.PushBoolean(true)
return 1
}
// fsAppendFile appends data to a file
func fsAppendFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.append_file: path must be a string")
return -1
}
path := state.ToString(1)
if !state.IsString(2) {
state.PushString("fs.append_file: content must be a string")
return -1
}
content := state.ToString(2)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.append_file: " + err.Error())
return -1
}
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
state.PushString("fs.append_file: failed to create directory: " + err.Error())
return -1
}
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
state.PushString("fs.append_file: " + err.Error())
return -1
}
defer file.Close()
_, err = file.Write([]byte(content))
if err != nil {
state.PushString("fs.append_file: " + err.Error())
return -1
}
state.PushBoolean(true)
return 1
}
// fsExists checks if a file or directory exists
func fsExists(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.exists: path must be a string")
return -1
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.exists: " + err.Error())
return -1
}
_, err = os.Stat(fullPath)
state.PushBoolean(err == nil)
return 1
}
// fsRemoveFile removes a file
func fsRemoveFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.remove_file: path must be a string")
return -1
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.remove_file: " + err.Error())
return -1
}
// Check if it's a directory
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.remove_file: " + err.Error())
return -1
}
if info.IsDir() {
state.PushString("fs.remove_file: cannot remove directory, use remove_dir instead")
return -1
}
err = os.Remove(fullPath)
if err != nil {
state.PushString("fs.remove_file: " + err.Error())
return -1
}
state.PushBoolean(true)
return 1
}
// fsGetInfo gets information about a file
func fsGetInfo(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.get_info: path must be a string")
return -1
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.get_info: " + err.Error())
return -1
}
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.get_info: " + err.Error())
return -1
}
state.NewTable()
state.PushString(info.Name())
state.SetField(-2, "name")
state.PushNumber(float64(info.Size()))
state.SetField(-2, "size")
state.PushNumber(float64(info.Mode()))
state.SetField(-2, "mode")
state.PushNumber(float64(info.ModTime().Unix()))
state.SetField(-2, "mod_time")
state.PushBoolean(info.IsDir())
state.SetField(-2, "is_dir")
return 1
}
// fsMakeDir creates a directory
func fsMakeDir(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.make_dir: path must be a string")
return -1
}
path := state.ToString(1)
perm := os.FileMode(0755)
if state.GetTop() >= 2 && state.IsNumber(2) {
perm = os.FileMode(state.ToNumber(2))
}
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.make_dir: " + err.Error())
return -1
}
err = os.MkdirAll(fullPath, perm)
if err != nil {
state.PushString("fs.make_dir: " + err.Error())
return -1
}
state.PushBoolean(true)
return 1
}
// fsListDir lists the contents of a directory
func fsListDir(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.list_dir: path must be a string")
return -1
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.list_dir: " + err.Error())
return -1
}
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.list_dir: " + err.Error())
return -1
}
if !info.IsDir() {
state.PushString("fs.list_dir: not a directory")
return -1
}
files, err := os.ReadDir(fullPath)
if err != nil {
state.PushString("fs.list_dir: " + err.Error())
return -1
}
state.NewTable()
for i, file := range files {
state.PushNumber(float64(i + 1))
state.PushString(file.Name())
state.SetTable(-3)
}
return 1
}
// fsRemoveDir removes a directory
func fsRemoveDir(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.remove_dir: path must be a string")
return -1
}
path := state.ToString(1)
recursive := false
if state.GetTop() >= 2 {
recursive = state.ToBoolean(2)
}
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.remove_dir: " + err.Error())
return -1
}
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.remove_dir: " + err.Error())
return -1
}
if !info.IsDir() {
state.PushString("fs.remove_dir: not a directory")
return -1
}
if recursive {
err = os.RemoveAll(fullPath)
} else {
err = os.Remove(fullPath)
}
if err != nil {
state.PushString("fs.remove_dir: " + err.Error())
return -1
}
state.PushBoolean(true)
return 1
}
// fsJoinPaths joins path components
func fsJoinPaths(state *luajit.State) int {
nargs := state.GetTop()
if nargs < 1 {
state.PushString("fs.join_paths: at least one path component required")
return -1
}
components := make([]string, nargs)
for i := 1; i <= nargs; i++ {
if !state.IsString(i) {
state.PushString("fs.join_paths: all arguments must be strings")
return -1
}
components[i-1] = state.ToString(i)
}
result := filepath.Join(components...)
result = strings.ReplaceAll(result, "\\", "/")
state.PushString(result)
return 1
}
// fsDirName returns the directory portion of a path
func fsDirName(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.dir_name: path must be a string")
return -1
}
path := state.ToString(1)
dir := filepath.Dir(path)
dir = strings.ReplaceAll(dir, "\\", "/")
state.PushString(dir)
return 1
}
// fsBaseName returns the file name portion of a path
func fsBaseName(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.base_name: path must be a string")
return -1
}
path := state.ToString(1)
base := filepath.Base(path)
state.PushString(base)
return 1
}
// fsExtension returns the file extension
func fsExtension(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.extension: path must be a string")
return -1
}
path := state.ToString(1)
ext := filepath.Ext(path)
state.PushString(ext)
return 1
}
// RegisterFSFunctions registers filesystem functions with the Lua state
func RegisterFSFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__fs_read_file", fsReadFile); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_write_file", fsWriteFile); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_append_file", fsAppendFile); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_exists", fsExists); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_remove_file", fsRemoveFile); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_get_info", fsGetInfo); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_make_dir", fsMakeDir); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_list_dir", fsListDir); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_remove_dir", fsRemoveDir); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_join_paths", fsJoinPaths); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_dir_name", fsDirName); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_base_name", fsBaseName); err != nil {
return err
}
if err := state.RegisterGoFunction("__fs_extension", fsExtension); err != nil {
return err
}
return nil
}

View File

@ -1,334 +0,0 @@
package runner
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"Moonshark/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Default HTTP client with sensible timeout
var defaultFastClient = fasthttp.Client{
MaxConnsPerHost: 1024,
MaxIdleConnDuration: time.Minute,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
DisableHeaderNamesNormalizing: true,
}
// HTTPClientConfig contains client settings
type HTTPClientConfig struct {
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
DefaultTimeout time.Duration // Default request timeout
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
AllowRemote bool // Whether to allow remote connections
}
// DefaultHTTPClientConfig provides sensible defaults
var DefaultHTTPClientConfig = HTTPClientConfig{
MaxTimeout: 60 * time.Second,
DefaultTimeout: 30 * time.Second,
MaxResponseSize: 10 * 1024 * 1024, // 10MB
AllowRemote: true,
}
// ApplyResponse applies a Response to a fasthttp.RequestCtx
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
// Set status code
ctx.SetStatusCode(resp.Status)
// Set headers
for name, value := range resp.Headers {
ctx.Response.Header.Set(name, value)
}
// Set cookies
for _, cookie := range resp.Cookies {
ctx.Response.Header.SetCookie(cookie)
}
// Process the body based on its type
if resp.Body == nil {
return
}
// Get a buffer from the pool
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Set body based on type
switch body := resp.Body.(type) {
case string:
ctx.SetBodyString(body)
case []byte:
ctx.SetBody(body)
case map[string]any, []any, []float64, []string, []int:
// Marshal JSON
if err := json.NewEncoder(buf).Encode(body); err == nil {
// Set content type if not already set
if len(ctx.Response.Header.ContentType()) == 0 {
ctx.Response.Header.SetContentType("application/json")
}
ctx.SetBody(buf.Bytes())
} else {
// Fallback
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
default:
// Default to string representation
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
}
// httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int {
// Get method (required)
if !state.IsString(1) {
state.PushString("http.client.request: method must be a string")
return -1
}
method := strings.ToUpper(state.ToString(1))
// Get URL (required)
if !state.IsString(2) {
state.PushString("http.client.request: url must be a string")
return -1
}
urlStr := state.ToString(2)
// Parse URL to check if it's valid
parsedURL, err := url.Parse(urlStr)
if err != nil {
state.PushString("Invalid URL: " + err.Error())
return -1
}
// Get client configuration
config := DefaultHTTPClientConfig
// Check if remote connections are allowed
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
state.PushString("Remote connections are not allowed")
return -1
}
// Use bytebufferpool for request and response
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// Set up request
req.Header.SetMethod(method)
req.SetRequestURI(urlStr)
req.Header.Set("User-Agent", "Moonshark/1.0")
// Get body (optional)
if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsString(3) {
// String body
req.SetBodyString(state.ToString(3))
} else if state.IsTable(3) {
// Table body - convert to JSON
luaTable, err := state.ToTable(3)
if err != nil {
state.PushString("Failed to parse body table: " + err.Error())
return -1
}
// Use bytebufferpool for JSON serialization
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
state.PushString("Failed to convert body to JSON: " + err.Error())
return -1
}
req.SetBody(buf.Bytes())
req.Header.SetContentType("application/json")
} else {
state.PushString("Body must be a string or table")
return -1
}
}
// Process options (headers, timeout, etc.)
timeout := config.DefaultTimeout
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
// Process headers
state.GetField(4, "headers")
if state.IsTable(-1) {
// Iterate through headers
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
headerName := state.ToString(-2)
headerValue := state.ToString(-1)
req.Header.Set(headerName, headerValue)
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop headers table
// Get timeout
state.GetField(4, "timeout")
if state.IsNumber(-1) {
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
// Apply max timeout if configured
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
timeout = config.MaxTimeout
} else {
timeout = requestTimeout
}
}
state.Pop(1) // Pop timeout
// Process query parameters
state.GetField(4, "query")
if state.IsTable(-1) {
// Create URL args
args := req.URI().QueryArgs()
// Iterate through query params
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) {
paramName := state.ToString(-2)
// Handle different value types
if state.IsString(-1) {
args.Add(paramName, state.ToString(-1))
} else if state.IsNumber(-1) {
args.Add(paramName, strings.TrimRight(strings.TrimRight(
state.ToString(-1), "0"), "."))
} else if state.IsBoolean(-1) {
if state.ToBoolean(-1) {
args.Add(paramName, "true")
} else {
args.Add(paramName, "false")
}
}
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop query table
}
// Create context with timeout
_, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Execute request
err = defaultFastClient.DoTimeout(req, resp, timeout)
if err != nil {
errStr := "Request failed: " + err.Error()
if errors.Is(err, fasthttp.ErrTimeout) {
errStr = "Request timed out after " + timeout.String()
}
state.PushString(errStr)
return -1
}
// Create response table
state.NewTable()
// Set status code
state.PushNumber(float64(resp.StatusCode()))
state.SetField(-2, "status")
// Set status text
statusText := fasthttp.StatusMessage(resp.StatusCode())
state.PushString(statusText)
state.SetField(-2, "status_text")
// Set body
var respBody []byte
// Apply size limits to response
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
// Make a limited copy
respBody = make([]byte, config.MaxResponseSize)
copy(respBody, resp.Body())
} else {
respBody = resp.Body()
}
state.PushString(string(respBody))
state.SetField(-2, "body")
// Parse body as JSON if content type is application/json
contentType := string(resp.Header.ContentType())
if strings.Contains(contentType, "application/json") {
var jsonData any
if err := json.Unmarshal(respBody, &jsonData); err == nil {
if err := state.PushValue(jsonData); err == nil {
state.SetField(-2, "json")
}
}
}
// Set headers
state.NewTable()
resp.Header.VisitAll(func(key, value []byte) {
state.PushString(string(value))
state.SetField(-2, string(key))
})
state.SetField(-2, "headers")
// Create ok field (true if status code is 2xx)
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
state.SetField(-2, "ok")
return 1
}
// generateToken creates a cryptographically secure random token
func generateToken(state *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
length := 32
if state.GetTop() >= 1 && state.IsNumber(1) {
length = int(state.ToNumber(1))
}
// Enforce minimum length for security
if length < 16 {
length = 16
}
// Generate secure random bytes
tokenBytes := make([]byte, length)
if _, err := rand.Read(tokenBytes); err != nil {
logger.Error("Failed to generate secure token: %v", err)
state.PushString("")
return 1 // Return empty string on error
}
// Encode as base64
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
// Trim to requested length (base64 might be longer)
if len(token) > length {
token = token[:length]
}
// Push the token to the Lua stack
state.PushString(token)
return 1 // One return value
}

View File

@ -1,142 +0,0 @@
--[[
crypto.lua - Cryptographic functions powered by Go
]]--
-- ======================================================================
-- HASHING FUNCTIONS
-- ======================================================================
-- Generate hash digest using various algorithms
-- Algorithms: md5, sha1, sha256, sha512
-- Formats: hex (default), binary
function hash(data, algorithm, format)
if type(data) ~= "string" then
error("hash: data must be a string", 2)
end
algorithm = algorithm or "sha256"
format = format or "hex"
return __crypto_hash(data, algorithm, format)
end
function md5(data, format)
return hash(data, "md5", format)
end
function sha1(data, format)
return hash(data, "sha1", format)
end
function sha256(data, format)
return hash(data, "sha256", format)
end
function sha512(data, format)
return hash(data, "sha512", format)
end
-- ======================================================================
-- HMAC FUNCTIONS
-- ======================================================================
-- Generate HMAC using various algorithms
-- Algorithms: md5, sha1, sha256, sha512
-- Formats: hex (default), binary
function hmac(data, key, algorithm, format)
if type(data) ~= "string" then
error("hmac: data must be a string", 2)
end
if type(key) ~= "string" then
error("hmac: key must be a string", 2)
end
algorithm = algorithm or "sha256"
format = format or "hex"
return __crypto_hmac(data, key, algorithm, format)
end
function hmac_md5(data, key, format)
return hmac(data, key, "md5", format)
end
function hmac_sha1(data, key, format)
return hmac(data, key, "sha1", format)
end
function hmac_sha256(data, key, format)
return hmac(data, key, "sha256", format)
end
function hmac_sha512(data, key, format)
return hmac(data, key, "sha512", format)
end
-- ======================================================================
-- RANDOM FUNCTIONS
-- ======================================================================
-- Generate random bytes
-- Formats: binary (default), hex
function random_bytes(length, secure, format)
if type(length) ~= "number" or length <= 0 then
error("random_bytes: length must be positive", 2)
end
secure = secure ~= false -- Default to secure
format = format or "binary"
return __crypto_random_bytes(length, secure, format)
end
-- Generate random integer in range [min, max]
function random_int(min, max, secure)
if type(min) ~= "number" or type(max) ~= "number" then
error("random_int: min and max must be numbers", 2)
end
if max <= min then
error("random_int: max must be greater than min", 2)
end
secure = secure ~= false -- Default to secure
return __crypto_random_int(min, max, secure)
end
-- Generate random string of specified length
function random_string(length, charset, secure)
if type(length) ~= "number" or length <= 0 then
error("random_string: length must be positive", 2)
end
secure = secure ~= false -- Default to secure
-- Default character set: alphanumeric
charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
if type(charset) ~= "string" or #charset == 0 then
error("random_string: charset must be non-empty", 2)
end
local result = ""
local charset_length = #charset
for i = 1, length do
local index = random_int(1, charset_length, secure)
result = result .. charset:sub(index, index)
end
return result
end
-- ======================================================================
-- UUID FUNCTIONS
-- ======================================================================
-- Generate random UUID (v4)
function uuid()
return __crypto_uuid()
end

View File

@ -1,93 +0,0 @@
-- Environment variable module for Moonshark
-- Provides access to persistent environment variables stored in .env file
-- Get an environment variable with a default value
-- Returns the value if it exists, default_value otherwise
function env_get(key, default_value)
if type(key) ~= "string" then
error("env_get: key must be a string")
end
-- First check context for environment variables (no Go call needed)
if _env and _env[key] ~= nil then
return _env[key]
end
return default_value
end
-- Set an environment variable
-- Returns true on success, false on failure
function env_set(key, value)
if type(key) ~= "string" then
error("env_set: key must be a string")
end
-- Update context immediately for future reads
if not _env then
_env = {}
end
_env[key] = value
-- Persist to Go backend
return __env_set(key, value)
end
-- Get all environment variables as a table
-- Returns a table with all key-value pairs
function env_get_all()
-- Return context table directly if available
if _env then
local copy = {}
for k, v in pairs(_env) do
copy[k] = v
end
return copy
end
-- Fallback to Go call
return __env_get_all()
end
-- Check if an environment variable exists
-- Returns true if the variable exists, false otherwise
function env_exists(key)
if type(key) ~= "string" then
error("env_exists: key must be a string")
end
-- Check context first
if _env then
return _env[key] ~= nil
end
return false
end
-- Set multiple environment variables from a table
-- Returns true on success, false if any setting failed
function env_set_many(vars)
if type(vars) ~= "table" then
error("env_set_many: vars must be a table")
end
if not _env then
_env = {}
end
local success = true
for key, value in pairs(vars) do
if type(key) == "string" and type(value) == "string" then
-- Update context
_env[key] = value
-- Persist to Go
if not __env_set(key, value) then
success = false
end
else
error("env_set_many: all keys and values must be strings")
end
end
return success
end

View File

@ -1,134 +0,0 @@
function fs_read(path)
if type(path) ~= "string" then
error("fs_read: path must be a string", 2)
end
return __fs_read_file(path)
end
function fs_write(path, content)
if type(path) ~= "string" then
error("fs_write: path must be a string", 2)
end
if type(content) ~= "string" then
error("fs_write: content must be a string", 2)
end
return __fs_write_file(path, content)
end
function fs_append(path, content)
if type(path) ~= "string" then
error("fs_append: path must be a string", 2)
end
if type(content) ~= "string" then
error("fs_append: content must be a string", 2)
end
return __fs_append_file(path, content)
end
function fs_exists(path)
if type(path) ~= "string" then
error("fs_exists: path must be a string", 2)
end
return __fs_exists(path)
end
function fs_remove(path)
if type(path) ~= "string" then
error("fs_remove: path must be a string", 2)
end
return __fs_remove_file(path)
end
function fs_info(path)
if type(path) ~= "string" then
error("fs_info: path must be a string", 2)
end
local info = __fs_get_info(path)
-- Convert the Unix timestamp to a readable date
if info and info.mod_time then
info.mod_time_str = os.date("%Y-%m-%d %H:%M:%S", info.mod_time)
end
return info
end
-- Directory Operations
function fs_mkdir(path, mode)
if type(path) ~= "string" then
error("fs_mkdir: path must be a string", 2)
end
mode = mode or 0755
return __fs_make_dir(path, mode)
end
function fs_ls(path)
if type(path) ~= "string" then
error("fs_ls: path must be a string", 2)
end
return __fs_list_dir(path)
end
function fs_rmdir(path, recursive)
if type(path) ~= "string" then
error("fs_rmdir: path must be a string", 2)
end
recursive = recursive or false
return __fs_remove_dir(path, recursive)
end
-- Path Operations
function fs_join_paths(...)
return __fs_join_paths(...)
end
function fs_dir_name(path)
if type(path) ~= "string" then
error("fs_dir_name: path must be a string", 2)
end
return __fs_dir_name(path)
end
function fs_base_name(path)
if type(path) ~= "string" then
error("fs_base_name: path must be a string", 2)
end
return __fs_base_name(path)
end
function fs_extension(path)
if type(path) ~= "string" then
error("fs_extension: path must be a string", 2)
end
return __fs_extension(path)
end
-- Utility Functions
function fs_read_json(path)
local content = fs_read(path)
if not content then
return nil, "Could not read file"
end
local ok, result = pcall(json.decode, content)
if not ok then
return nil, "Invalid JSON: " .. tostring(result)
end
return result
end
function fs_write_json(path, data, pretty)
if type(data) ~= "table" then
error("fs_write_json: data must be a table", 2)
end
local content
if pretty then
content = json.pretty_print(data)
else
content = json.encode(data)
end
return fs_write(path, content)
end

View File

@ -1,422 +0,0 @@
-- json.lua: High-performance JSON module for Moonshark
-- Pre-computed escape sequences to avoid recreating table
local escape_chars = {
['"'] = '\\"', ['\\'] = '\\\\',
['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t'
}
function json_go_encode(value)
return __json_marshal(value)
end
function json_go_decode(str)
if type(str) ~= "string" then
error("json_decode: expected string, got " .. type(str), 2)
end
return __json_unmarshal(str)
end
function json_encode(data)
local t = type(data)
if t == "nil" then return "null" end
if t == "boolean" then return data and "true" or "false" end
if t == "number" then return tostring(data) end
if t == "string" then
return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"'
end
if t == "table" then
local isArray = true
local count = 0
-- Check if it's an array in one pass
for k, _ in pairs(data) do
count = count + 1
if type(k) ~= "number" or k ~= count or k < 1 then
isArray = false
break
end
end
if isArray then
local result = {}
for i = 1, count do
result[i] = json_encode(data[i])
end
return "[" .. table.concat(result, ",") .. "]"
else
local result = {}
local index = 1
for k, v in pairs(data) do
if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then
result[index] = json_encode(k) .. ":" .. json_encode(v)
index = index + 1
end
end
return "{" .. table.concat(result, ",") .. "}"
end
end
return "null" -- Unsupported type
end
function json_decode(data)
local pos = 1
local len = #data
-- Pre-compute byte values
local b_space = string.byte(' ')
local b_tab = string.byte('\t')
local b_cr = string.byte('\r')
local b_lf = string.byte('\n')
local b_quote = string.byte('"')
local b_backslash = string.byte('\\')
local b_slash = string.byte('/')
local b_lcurly = string.byte('{')
local b_rcurly = string.byte('}')
local b_lbracket = string.byte('[')
local b_rbracket = string.byte(']')
local b_colon = string.byte(':')
local b_comma = string.byte(',')
local b_0 = string.byte('0')
local b_9 = string.byte('9')
local b_minus = string.byte('-')
local b_plus = string.byte('+')
local b_dot = string.byte('.')
local b_e = string.byte('e')
local b_E = string.byte('E')
-- Skip whitespace more efficiently
local function skip()
local b
while pos <= len do
b = data:byte(pos)
if b > b_space or (b ~= b_space and b ~= b_tab and b ~= b_cr and b ~= b_lf) then
break
end
pos = pos + 1
end
end
-- Forward declarations
local parse_value, parse_string, parse_number, parse_object, parse_array
-- Parse a string more efficiently
parse_string = function()
pos = pos + 1 -- Skip opening quote
if pos > len then
error("Unterminated string")
end
-- Use a table to build the string
local result = {}
local result_pos = 1
local start = pos
local c, b
while pos <= len do
b = data:byte(pos)
if b == b_backslash then
-- Add the chunk before the escape character
if pos > start then
result[result_pos] = data:sub(start, pos - 1)
result_pos = result_pos + 1
end
pos = pos + 1
if pos > len then
error("Unterminated string escape")
end
c = data:byte(pos)
if c == b_quote then
result[result_pos] = '"'
elseif c == b_backslash then
result[result_pos] = '\\'
elseif c == b_slash then
result[result_pos] = '/'
elseif c == string.byte('b') then
result[result_pos] = '\b'
elseif c == string.byte('f') then
result[result_pos] = '\f'
elseif c == string.byte('n') then
result[result_pos] = '\n'
elseif c == string.byte('r') then
result[result_pos] = '\r'
elseif c == string.byte('t') then
result[result_pos] = '\t'
else
result[result_pos] = data:sub(pos, pos)
end
result_pos = result_pos + 1
pos = pos + 1
start = pos
elseif b == b_quote then
-- Add the final chunk
if pos > start then
result[result_pos] = data:sub(start, pos - 1)
result_pos = result_pos + 1
end
pos = pos + 1
return table.concat(result)
else
pos = pos + 1
end
end
error("Unterminated string")
end
-- Parse a number more efficiently
parse_number = function()
local start = pos
local b = data:byte(pos)
-- Skip any sign
if b == b_minus then
pos = pos + 1
if pos > len then
error("Malformed number")
end
b = data:byte(pos)
end
-- Integer part
if b < b_0 or b > b_9 then
error("Malformed number")
end
repeat
pos = pos + 1
if pos > len then break end
b = data:byte(pos)
until b < b_0 or b > b_9
-- Fractional part
if pos <= len and b == b_dot then
pos = pos + 1
if pos > len or data:byte(pos) < b_0 or data:byte(pos) > b_9 then
error("Malformed number")
end
repeat
pos = pos + 1
if pos > len then break end
b = data:byte(pos)
until b < b_0 or b > b_9
end
-- Exponent
if pos <= len and (b == b_e or b == b_E) then
pos = pos + 1
if pos > len then
error("Malformed number")
end
b = data:byte(pos)
if b == b_plus or b == b_minus then
pos = pos + 1
if pos > len then
error("Malformed number")
end
b = data:byte(pos)
end
if b < b_0 or b > b_9 then
error("Malformed number")
end
repeat
pos = pos + 1
if pos > len then break end
b = data:byte(pos)
until b < b_0 or b > b_9
end
return tonumber(data:sub(start, pos - 1))
end
-- Parse an object more efficiently
parse_object = function()
pos = pos + 1 -- Skip opening brace
local obj = {}
skip()
if pos <= len and data:byte(pos) == b_rcurly then
pos = pos + 1
return obj
end
while pos <= len do
skip()
if data:byte(pos) ~= b_quote then
error("Expected string key")
end
local key = parse_string()
skip()
if data:byte(pos) ~= b_colon then
error("Expected colon")
end
pos = pos + 1
obj[key] = parse_value()
skip()
local b = data:byte(pos)
if b == b_rcurly then
pos = pos + 1
return obj
end
if b ~= b_comma then
error("Expected comma or closing brace")
end
pos = pos + 1
end
error("Unterminated object")
end
-- Parse an array more efficiently
parse_array = function()
pos = pos + 1 -- Skip opening bracket
local arr = {}
local index = 1
skip()
if pos <= len and data:byte(pos) == b_rbracket then
pos = pos + 1
return arr
end
while pos <= len do
arr[index] = parse_value()
index = index + 1
skip()
local b = data:byte(pos)
if b == b_rbracket then
pos = pos + 1
return arr
end
if b ~= b_comma then
error("Expected comma or closing bracket")
end
pos = pos + 1
end
error("Unterminated array")
end
-- Parse a value more efficiently
parse_value = function()
skip()
if pos > len then
error("Unexpected end of input")
end
local b = data:byte(pos)
if b == b_quote then
return parse_string()
elseif b == b_lcurly then
return parse_object()
elseif b == b_lbracket then
return parse_array()
elseif b == string.byte('n') and pos + 3 <= len and data:sub(pos, pos + 3) == "null" then
pos = pos + 4
return nil
elseif b == string.byte('t') and pos + 3 <= len and data:sub(pos, pos + 3) == "true" then
pos = pos + 4
return true
elseif b == string.byte('f') and pos + 4 <= len and data:sub(pos, pos + 4) == "false" then
pos = pos + 5
return false
elseif b == b_minus or (b >= b_0 and b <= b_9) then
return parse_number()
else
error("Unexpected character: " .. string.char(b))
end
end
skip()
local result = parse_value()
skip()
if pos <= len then
error("Unexpected trailing characters")
end
return result
end
function json_is_valid(str)
if type(str) ~= "string" then return false end
local status, _ = pcall(json_decode, str)
return status
end
function json_pretty_print(value)
if type(value) == "string" then
value = json_decode(value)
end
local function stringify(val, indent, visited)
visited = visited or {}
indent = indent or 0
local spaces = string.rep(" ", indent)
if type(val) == "table" then
if visited[val] then return "{...}" end
visited[val] = true
local isArray = true
local i = 1
for k in pairs(val) do
if type(k) ~= "number" or k ~= i then
isArray = false
break
end
i = i + 1
end
local result = isArray and "[\n" or "{\n"
local first = true
if isArray then
for i, v in ipairs(val) do
if not first then result = result .. ",\n" end
first = false
result = result .. spaces .. " " .. stringify(v, indent + 1, visited)
end
else
for k, v in pairs(val) do
if not first then result = result .. ",\n" end
first = false
result = result .. spaces .. " \"" .. tostring(k) .. "\": " .. stringify(v, indent + 1, visited)
end
end
return result .. "\n" .. spaces .. (isArray and "]" or "}")
elseif type(val) == "string" then
return "\"" .. val:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n') .. "\""
else
return tostring(val)
end
end
return stringify(value)
end

View File

@ -1,802 +0,0 @@
--[[
math.lua - High-performance math library
]]--
local math_ext = {}
-- Import standard math functions
for name, func in pairs(_G.math) do
math_ext[name] = func
end
-- ======================================================================
-- CONSTANTS (higher precision)
-- ======================================================================
math_ext.pi = 3.14159265358979323846
math_ext.tau = 6.28318530717958647693 -- 2*pi
math_ext.e = 2.71828182845904523536
math_ext.phi = 1.61803398874989484820 -- Golden ratio
math_ext.sqrt2 = 1.41421356237309504880
math_ext.sqrt3 = 1.73205080756887729353
math_ext.ln2 = 0.69314718055994530942
math_ext.ln10 = 2.30258509299404568402
math_ext.infinity = 1/0
math_ext.nan = 0/0
-- ======================================================================
-- EXTENDED FUNCTIONS
-- ======================================================================
-- Cube root (handles negative numbers correctly)
function math_ext.cbrt(x)
return x < 0 and -(-x)^(1/3) or x^(1/3)
end
-- Hypotenuse of right-angled triangle
function math_ext.hypot(x, y)
return math.sqrt(x * x + y * y)
end
-- Check if value is NaN
function math_ext.isnan(x)
return x ~= x
end
-- Check if value is finite
function math_ext.isfinite(x)
return x > -math_ext.infinity and x < math_ext.infinity
end
-- Sign function (-1, 0, 1)
function math_ext.sign(x)
return x > 0 and 1 or (x < 0 and -1 or 0)
end
-- Clamp value between min and max
function math_ext.clamp(x, min, max)
return x < min and min or (x > max and max or x)
end
-- Linear interpolation
function math_ext.lerp(a, b, t)
return a + (b - a) * t
end
-- Smooth step interpolation
function math_ext.smoothstep(a, b, t)
t = math_ext.clamp((t - a) / (b - a), 0, 1)
return t * t * (3 - 2 * t)
end
-- Map value from one range to another
function math_ext.map(x, in_min, in_max, out_min, out_max)
return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min
end
-- Round to nearest integer
function math_ext.round(x)
return x >= 0 and math.floor(x + 0.5) or math.ceil(x - 0.5)
end
-- Round to specified decimal places
function math_ext.roundto(x, decimals)
local mult = 10 ^ (decimals or 0)
return math.floor(x * mult + 0.5) / mult
end
-- Normalize angle to [-π, π]
function math_ext.normalize_angle(angle)
return angle - 2 * math_ext.pi * math.floor((angle + math_ext.pi) / (2 * math_ext.pi))
end
-- Distance between points
function math_ext.distance(x1, y1, x2, y2)
local dx, dy = x2 - x1, y2 - y1
return math.sqrt(dx * dx + dy * dy)
end
-- ======================================================================
-- RANDOM NUMBER FUNCTIONS
-- ======================================================================
-- Random float in range [min, max)
function math_ext.randomf(min, max)
if not min and not max then
return math.random()
elseif not max then
max = min
min = 0
end
return min + math.random() * (max - min)
end
-- Random integer in range [min, max]
function math_ext.randint(min, max)
if not max then
max = min
min = 1
end
return math.floor(math.random() * (max - min + 1) + min)
end
-- Random boolean with probability p (default 0.5)
function math_ext.randboolean(p)
p = p or 0.5
return math.random() < p
end
-- ======================================================================
-- STATISTICS FUNCTIONS
-- ======================================================================
-- Sum of values
function math_ext.sum(t)
if type(t) ~= "table" then return 0 end
local sum = 0
for i=1, #t do
if type(t[i]) == "number" then
sum = sum + t[i]
end
end
return sum
end
-- Mean (average) of values
function math_ext.mean(t)
if type(t) ~= "table" or #t == 0 then return 0 end
local sum = 0
local count = 0
for i=1, #t do
if type(t[i]) == "number" then
sum = sum + t[i]
count = count + 1
end
end
return count > 0 and sum / count or 0
end
-- Median of values
function math_ext.median(t)
if type(t) ~= "table" or #t == 0 then return 0 end
local nums = {}
local count = 0
for i=1, #t do
if type(t[i]) == "number" then
count = count + 1
nums[count] = t[i]
end
end
if count == 0 then return 0 end
table.sort(nums)
if count % 2 == 0 then
return (nums[count/2] + nums[count/2 + 1]) / 2
else
return nums[math.ceil(count/2)]
end
end
-- Variance of values
function math_ext.variance(t)
if type(t) ~= "table" then return 0 end
local count = 0
local m = math_ext.mean(t)
local sum = 0
for i=1, #t do
if type(t[i]) == "number" then
local dev = t[i] - m
sum = sum + dev * dev
count = count + 1
end
end
return count > 1 and sum / count or 0
end
-- Standard deviation
function math_ext.stdev(t)
return math.sqrt(math_ext.variance(t))
end
-- Population variance
function math_ext.pvariance(t)
if type(t) ~= "table" then return 0 end
local count = 0
local m = math_ext.mean(t)
local sum = 0
for i=1, #t do
if type(t[i]) == "number" then
local dev = t[i] - m
sum = sum + dev * dev
count = count + 1
end
end
return count > 0 and sum / count or 0
end
-- Population standard deviation
function math_ext.pstdev(t)
return math.sqrt(math_ext.pvariance(t))
end
-- Mode (most common value)
function math_ext.mode(t)
if type(t) ~= "table" or #t == 0 then return nil end
local counts = {}
local most_frequent = nil
local max_count = 0
for i=1, #t do
local v = t[i]
counts[v] = (counts[v] or 0) + 1
if counts[v] > max_count then
max_count = counts[v]
most_frequent = v
end
end
return most_frequent
end
-- Min and max simultaneously (faster than calling both separately)
function math_ext.minmax(t)
if type(t) ~= "table" or #t == 0 then return nil, nil end
local min, max
for i=1, #t do
if type(t[i]) == "number" then
min = t[i]
max = t[i]
break
end
end
if min == nil then return nil, nil end
for i=1, #t do
if type(t[i]) == "number" then
if t[i] < min then min = t[i] end
if t[i] > max then max = t[i] end
end
end
return min, max
end
-- ======================================================================
-- VECTOR OPERATIONS (2D/3D vectors)
-- ======================================================================
-- 2D Vector operations
math_ext.vec2 = {
new = function(x, y)
return {x = x or 0, y = y or 0}
end,
copy = function(v)
return {x = v.x, y = v.y}
end,
add = function(a, b)
return {x = a.x + b.x, y = a.y + b.y}
end,
sub = function(a, b)
return {x = a.x - b.x, y = a.y - b.y}
end,
mul = function(a, b)
if type(b) == "number" then
return {x = a.x * b, y = a.y * b}
end
return {x = a.x * b.x, y = a.y * b.y}
end,
div = function(a, b)
if type(b) == "number" then
local inv = 1 / b
return {x = a.x * inv, y = a.y * inv}
end
return {x = a.x / b.x, y = a.y / b.y}
end,
dot = function(a, b)
return a.x * b.x + a.y * b.y
end,
length = function(v)
return math.sqrt(v.x * v.x + v.y * v.y)
end,
length_squared = function(v)
return v.x * v.x + v.y * v.y
end,
distance = function(a, b)
local dx, dy = b.x - a.x, b.y - a.y
return math.sqrt(dx * dx + dy * dy)
end,
distance_squared = function(a, b)
local dx, dy = b.x - a.x, b.y - a.y
return dx * dx + dy * dy
end,
normalize = function(v)
local len = math.sqrt(v.x * v.x + v.y * v.y)
if len > 1e-10 then
local inv_len = 1 / len
return {x = v.x * inv_len, y = v.y * inv_len}
end
return {x = 0, y = 0}
end,
rotate = function(v, angle)
local c, s = math.cos(angle), math.sin(angle)
return {
x = v.x * c - v.y * s,
y = v.x * s + v.y * c
}
end,
angle = function(v)
return math.atan2(v.y, v.x)
end,
lerp = function(a, b, t)
t = math_ext.clamp(t, 0, 1)
return {
x = a.x + (b.x - a.x) * t,
y = a.y + (b.y - a.y) * t
}
end,
reflect = function(v, normal)
local dot = v.x * normal.x + v.y * normal.y
return {
x = v.x - 2 * dot * normal.x,
y = v.y - 2 * dot * normal.y
}
end
}
-- 3D Vector operations
math_ext.vec3 = {
new = function(x, y, z)
return {x = x or 0, y = y or 0, z = z or 0}
end,
copy = function(v)
return {x = v.x, y = v.y, z = v.z}
end,
add = function(a, b)
return {x = a.x + b.x, y = a.y + b.y, z = a.z + b.z}
end,
sub = function(a, b)
return {x = a.x - b.x, y = a.y - b.y, z = a.z - b.z}
end,
mul = function(a, b)
if type(b) == "number" then
return {x = a.x * b, y = a.y * b, z = a.z * b}
end
return {x = a.x * b.x, y = a.y * b.y, z = a.z * b.z}
end,
div = function(a, b)
if type(b) == "number" then
local inv = 1 / b
return {x = a.x * inv, y = a.y * inv, z = a.z * inv}
end
return {x = a.x / b.x, y = a.y / b.y, z = a.z / b.z}
end,
dot = function(a, b)
return a.x * b.x + a.y * b.y + a.z * b.z
end,
cross = function(a, b)
return {
x = a.y * b.z - a.z * b.y,
y = a.z * b.x - a.x * b.z,
z = a.x * b.y - a.y * b.x
}
end,
length = function(v)
return math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
end,
length_squared = function(v)
return v.x * v.x + v.y * v.y + v.z * v.z
end,
distance = function(a, b)
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
return math.sqrt(dx * dx + dy * dy + dz * dz)
end,
distance_squared = function(a, b)
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
return dx * dx + dy * dy + dz * dz
end,
normalize = function(v)
local len = math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
if len > 1e-10 then
local inv_len = 1 / len
return {x = v.x * inv_len, y = v.y * inv_len, z = v.z * inv_len}
end
return {x = 0, y = 0, z = 0}
end,
lerp = function(a, b, t)
t = math_ext.clamp(t, 0, 1)
return {
x = a.x + (b.x - a.x) * t,
y = a.y + (b.y - a.y) * t,
z = a.z + (b.z - a.z) * t
}
end,
reflect = function(v, normal)
local dot = v.x * normal.x + v.y * normal.y + v.z * normal.z
return {
x = v.x - 2 * dot * normal.x,
y = v.y - 2 * dot * normal.y,
z = v.z - 2 * dot * normal.z
}
end
}
-- ======================================================================
-- MATRIX OPERATIONS (2x2 and 3x3 matrices)
-- ======================================================================
math_ext.mat2 = {
-- Create a new 2x2 matrix
new = function(a, b, c, d)
return {
{a or 1, b or 0},
{c or 0, d or 1}
}
end,
-- Create identity matrix
identity = function()
return {{1, 0}, {0, 1}}
end,
-- Matrix multiplication
mul = function(a, b)
return {
{
a[1][1] * b[1][1] + a[1][2] * b[2][1],
a[1][1] * b[1][2] + a[1][2] * b[2][2]
},
{
a[2][1] * b[1][1] + a[2][2] * b[2][1],
a[2][1] * b[1][2] + a[2][2] * b[2][2]
}
}
end,
-- Determinant
det = function(m)
return m[1][1] * m[2][2] - m[1][2] * m[2][1]
end,
-- Inverse matrix
inverse = function(m)
local det = m[1][1] * m[2][2] - m[1][2] * m[2][1]
if math.abs(det) < 1e-10 then
return nil -- Matrix is not invertible
end
local inv_det = 1 / det
return {
{m[2][2] * inv_det, -m[1][2] * inv_det},
{-m[2][1] * inv_det, m[1][1] * inv_det}
}
end,
-- Rotation matrix
rotation = function(angle)
local cos, sin = math.cos(angle), math.sin(angle)
return {
{cos, -sin},
{sin, cos}
}
end,
-- Apply matrix to vector
transform = function(m, v)
return {
x = m[1][1] * v.x + m[1][2] * v.y,
y = m[2][1] * v.x + m[2][2] * v.y
}
end,
-- Scale matrix
scale = function(sx, sy)
sy = sy or sx
return {
{sx, 0},
{0, sy}
}
end
}
math_ext.mat3 = {
-- Create identity matrix 3x3
identity = function()
return {
{1, 0, 0},
{0, 1, 0},
{0, 0, 1}
}
end,
-- Create a 2D transformation matrix (translation, rotation, scale)
transform = function(x, y, angle, sx, sy)
sx = sx or 1
sy = sy or sx
local cos, sin = math.cos(angle), math.sin(angle)
return {
{cos * sx, -sin * sy, x},
{sin * sx, cos * sy, y},
{0, 0, 1}
}
end,
-- Matrix multiplication
mul = function(a, b)
local result = {
{0, 0, 0},
{0, 0, 0},
{0, 0, 0}
}
for i = 1, 3 do
for j = 1, 3 do
for k = 1, 3 do
result[i][j] = result[i][j] + a[i][k] * b[k][j]
end
end
end
return result
end,
-- Apply matrix to point (homogeneous coordinates)
transform_point = function(m, v)
local x = m[1][1] * v.x + m[1][2] * v.y + m[1][3]
local y = m[2][1] * v.x + m[2][2] * v.y + m[2][3]
local w = m[3][1] * v.x + m[3][2] * v.y + m[3][3]
if math.abs(w) < 1e-10 then
return {x = 0, y = 0}
end
return {x = x / w, y = y / w}
end,
-- Translation matrix
translation = function(x, y)
return {
{1, 0, x},
{0, 1, y},
{0, 0, 1}
}
end,
-- Rotation matrix
rotation = function(angle)
local cos, sin = math.cos(angle), math.sin(angle)
return {
{cos, -sin, 0},
{sin, cos, 0},
{0, 0, 1}
}
end,
-- Scale matrix
scale = function(sx, sy)
sy = sy or sx
return {
{sx, 0, 0},
{0, sy, 0},
{0, 0, 1}
}
end,
-- Determinant
det = function(m)
return m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2]) -
m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1]) +
m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1])
end
}
-- ======================================================================
-- GEOMETRY FUNCTIONS
-- ======================================================================
math_ext.geometry = {
-- Distance from point to line
point_line_distance = function(px, py, x1, y1, x2, y2)
local dx, dy = x2 - x1, y2 - y1
local len_sq = dx * dx + dy * dy
if len_sq < 1e-10 then
return math_ext.distance(px, py, x1, y1)
end
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
t = math_ext.clamp(t, 0, 1)
local nearestX = x1 + t * dx
local nearestY = y1 + t * dy
return math_ext.distance(px, py, nearestX, nearestY)
end,
-- Check if point is inside polygon
point_in_polygon = function(px, py, vertices)
local inside = false
local n = #vertices / 2
for i = 1, n do
local x1, y1 = vertices[i*2-1], vertices[i*2]
local x2, y2
if i == n then
x2, y2 = vertices[1], vertices[2]
else
x2, y2 = vertices[i*2+1], vertices[i*2+2]
end
if ((y1 > py) ~= (y2 > py)) and
(px < (x2 - x1) * (py - y1) / (y2 - y1) + x1) then
inside = not inside
end
end
return inside
end,
-- Area of a triangle
triangle_area = function(x1, y1, x2, y2, x3, y3)
return math.abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2)
end,
-- Check if point is inside triangle
point_in_triangle = function(px, py, x1, y1, x2, y2, x3, y3)
local area = math_ext.geometry.triangle_area(x1, y1, x2, y2, x3, y3)
local area1 = math_ext.geometry.triangle_area(px, py, x2, y2, x3, y3)
local area2 = math_ext.geometry.triangle_area(x1, y1, px, py, x3, y3)
local area3 = math_ext.geometry.triangle_area(x1, y1, x2, y2, px, py)
return math.abs(area - (area1 + area2 + area3)) < 1e-10
end,
-- Check if two line segments intersect
line_intersect = function(x1, y1, x2, y2, x3, y3, x4, y4)
local d = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
if math.abs(d) < 1e-10 then
return false, nil, nil -- Lines are parallel
end
local ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / d
local ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / d
if ua >= 0 and ua <= 1 and ub >= 0 and ub <= 1 then
local x = x1 + ua * (x2 - x1)
local y = y1 + ua * (y2 - y1)
return true, x, y
end
return false, nil, nil
end,
-- Closest point on line segment to point
closest_point_on_segment = function(px, py, x1, y1, x2, y2)
local dx, dy = x2 - x1, y2 - y1
local len_sq = dx * dx + dy * dy
if len_sq < 1e-10 then
return x1, y1
end
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
t = math_ext.clamp(t, 0, 1)
return x1 + t * dx, y1 + t * dy
end
}
-- ======================================================================
-- INTERPOLATION FUNCTIONS
-- ======================================================================
math_ext.interpolation = {
-- Cubic Bezier interpolation
bezier = function(t, p0, p1, p2, p3)
t = math_ext.clamp(t, 0, 1)
local t2 = t * t
local t3 = t2 * t
local mt = 1 - t
local mt2 = mt * mt
local mt3 = mt2 * mt
return p0 * mt3 + 3 * p1 * mt2 * t + 3 * p2 * mt * t2 + p3 * t3
end,
-- Catmull-Rom spline interpolation
catmull_rom = function(t, p0, p1, p2, p3)
t = math_ext.clamp(t, 0, 1)
local t2 = t * t
local t3 = t2 * t
return 0.5 * (
(2 * p1) +
(-p0 + p2) * t +
(2 * p0 - 5 * p1 + 4 * p2 - p3) * t2 +
(-p0 + 3 * p1 - 3 * p2 + p3) * t3
)
end,
-- Hermite interpolation
hermite = function(t, p0, p1, m0, m1)
t = math_ext.clamp(t, 0, 1)
local t2 = t * t
local t3 = t2 * t
local h00 = 2 * t3 - 3 * t2 + 1
local h10 = t3 - 2 * t2 + t
local h01 = -2 * t3 + 3 * t2
local h11 = t3 - t2
return h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
end,
-- Quadratic Bezier interpolation
quadratic_bezier = function(t, p0, p1, p2)
t = math_ext.clamp(t, 0, 1)
local mt = 1 - t
return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2
end,
-- Step interpolation
step = function(t, edge, x)
return t < edge and 0 or x
end,
-- Smoothstep interpolation
smoothstep = function(edge0, edge1, x)
local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1)
return t * t * (3 - 2 * t)
end,
-- Smootherstep interpolation (Ken Perlin)
smootherstep = function(edge0, edge1, x)
local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1)
return t * t * t * (t * (t * 6 - 15) + 10)
end
}
return math_ext

View File

@ -1,667 +0,0 @@
--[[
sandbox.lua
]]--
__http_response = {}
__module_paths = {}
__module_bytecode = {}
__ready_modules = {}
__EXIT_SENTINEL = {} -- Unique object for exit identification
-- ======================================================================
-- CORE SANDBOX FUNCTIONALITY
-- ======================================================================
function exit()
error(__EXIT_SENTINEL)
end
-- Create environment inheriting from _G
function __create_env(ctx)
local env = setmetatable({}, {__index = _G})
if ctx then
env.ctx = ctx
if ctx._env then
env._env = ctx._env
end
end
if __setup_require then
__setup_require(env)
end
return env
end
-- Execute script with clean environment
function __execute_script(fn, ctx)
__http_response = nil
local env = __create_env(ctx)
env.exit = exit
setfenv(fn, env)
local ok, result = pcall(fn)
if not ok then
if result == __EXIT_SENTINEL then
return
end
error(result, 0)
end
return result
end
-- Ensure __http_response exists, then return it
function __ensure_response()
if not __http_response then
__http_response = {}
end
return __http_response
end
-- ======================================================================
-- HTTP FUNCTIONS
-- ======================================================================
-- Set HTTP status code
function http_set_status(code)
if type(code) ~= "number" then
error("http_set_status: status code must be a number", 2)
end
local resp = __ensure_response()
resp.status = code
end
-- Set HTTP header
function http_set_header(name, value)
if type(name) ~= "string" or type(value) ~= "string" then
error("http_set_header: name and value must be strings", 2)
end
local resp = __ensure_response()
resp.headers = resp.headers or {}
resp.headers[name] = value
end
-- Set content type; http_set_header helper
function http_set_content_type(content_type)
http_set_header("Content-Type", content_type)
end
-- Set metadata (arbitrary data to be returned with response)
function http_set_metadata(key, value)
if type(key) ~= "string" then
error("http_set_metadata: key must be a string", 2)
end
local resp = __ensure_response()
resp.metadata = resp.metadata or {}
resp.metadata[key] = value
end
-- Generic HTTP request function
function http_request(method, url, body, options)
if type(method) ~= "string" then
error("http_request: method must be a string", 2)
end
if type(url) ~= "string" then
error("http_request: url must be a string", 2)
end
-- Call native implementation
local result = __http_request(method, url, body, options)
return result
end
-- Shorthand function to directly get JSON
function http_get_json(url, options)
options = options or {}
local response = http_get(url, options)
if response.ok and response.json then
return response.json
end
return nil, response
end
-- Utility to build a URL with query parameters
function http_build_url(base_url, params)
if not params or type(params) ~= "table" then
return base_url
end
local query = {}
for k, v in pairs(params) do
if type(v) == "table" then
for _, item in ipairs(v) do
table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(item)))
end
else
table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(v)))
end
end
if #query > 0 then
if string.contains(base_url, "?") then
return base_url .. "&" .. table.concat(query, "&")
else
return base_url .. "?" .. table.concat(query, "&")
end
end
return base_url
end
local function make_method(method, needs_body)
return function(url, body_or_options, options)
if needs_body then
options = options or {}
return http_request(method, url, body_or_options, options)
else
body_or_options = body_or_options or {}
return http_request(method, url, nil, body_or_options)
end
end
end
http_get = make_method("GET", false)
http_delete = make_method("DELETE", false)
http_head = make_method("HEAD", false)
http_options = make_method("OPTIONS", false)
http_post = make_method("POST", true)
http_put = make_method("PUT", true)
http_patch = make_method("PATCH", true)
function http_redirect(url, status)
if type(url) ~= "string" then
error("http_redirect: url must be a string", 2)
end
status = status or 302 -- Default to temporary redirect
local resp = __ensure_response()
resp.status = status
resp.headers = resp.headers or {}
resp.headers["Location"] = url
exit()
end
-- ======================================================================
-- COOKIE FUNCTIONS
-- ======================================================================
-- Set a cookie
function cookie_set(name, value, options)
if type(name) ~= "string" then
error("cookie_set: name must be a string", 2)
end
local resp = __ensure_response()
resp.cookies = resp.cookies or {}
local opts = options or {}
local cookie = {
name = name,
value = value or "",
path = opts.path or "/",
domain = opts.domain
}
if opts.expires then
if type(opts.expires) == "number" then
if opts.expires > 0 then
cookie.max_age = opts.expires
local now = os.time()
cookie.expires = now + opts.expires
elseif opts.expires < 0 then
cookie.expires = 1
cookie.max_age = 0
end
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
end
end
cookie.secure = (opts.secure ~= false)
cookie.http_only = (opts.http_only ~= false)
if opts.same_site then
local same_site = string.trim(opts.same_site):lower()
local valid_values = {none = true, lax = true, strict = true}
if not valid_values[same_site] then
error("cookie_set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
end
-- If SameSite=None, the cookie must be secure
if same_site == "none" and not cookie.secure then
cookie.secure = true
end
cookie.same_site = opts.same_site
else
cookie.same_site = "Lax"
end
table.insert(resp.cookies, cookie)
return true
end
-- Get a cookie value
function cookie_get(name)
if type(name) ~= "string" then
error("cookie_get: name must be a string", 2)
end
local env = getfenv(2)
if env.ctx and env.ctx.cookies then
return env.ctx.cookies[name]
end
if env.ctx and env.ctx._request_cookies then
return env.ctx._request_cookies[name]
end
return nil
end
-- Remove a cookie
function cookie_remove(name, path, domain)
if type(name) ~= "string" then
error("cookie_remove: name must be a string", 2)
end
return cookie_set(name, "", {expires = 0, path = path or "/", domain = domain})
end
-- ======================================================================
-- SESSION FUNCTIONS
-- ======================================================================
function session_get(key)
if type(key) ~= "string" then
error("session_get: key must be a string", 2)
end
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
return env.ctx.session.data[key]
end
return nil
end
function session_set(key, value)
if type(key) ~= "string" then
error("session_set: key must be a string", 2)
end
if type(value) == nil then
error("session_set: value cannot be nil", 2)
end
local resp = __ensure_response()
resp.session = resp.session or {}
resp.session[key] = value
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
env.ctx.session.data[key] = value
end
end
function session_id()
local env = getfenv(2)
if env.ctx and env.ctx.session then
return env.ctx.session.id
end
return nil
end
function session_get_all()
local env = getfenv(2)
if env.ctx and env.ctx.session then
return env.ctx.session.data
end
return nil
end
function session_delete(key)
if type(key) ~= "string" then
error("session_delete: key must be a string", 2)
end
local resp = __ensure_response()
resp.session = resp.session or {}
resp.session[key] = "__SESSION_DELETE_MARKER__"
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
env.ctx.session.data[key] = nil
end
end
function session_clear()
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
for k, _ in pairs(env.ctx.session.data) do
env.ctx.session.data[k] = nil
end
end
local resp = __ensure_response()
resp.session = {}
resp.session["__clear_all"] = true
end
-- ======================================================================
-- CSRF FUNCTIONS
-- ======================================================================
function csrf_generate()
local token = generate_token(32)
session_set("_csrf_token", token)
return token
end
function csrf_field()
local token = session_get("_csrf_token")
if not token then
token = csrf_generate()
end
return string.format('<input type="hidden" name="_csrf_token" value="%s" />',
html_special_chars(token))
end
function csrf_validate()
local env = getfenv(2)
local token = false
if env.ctx and env.ctx.session and env.ctx.session.data then
token = env.ctx.session.data["_csrf_token"]
end
if not token then
http_set_status(403)
__http_response.body = "CSRF validation failed"
exit()
end
local request_token = nil
if env.ctx and env.ctx.form then
request_token = env.ctx.form._csrf_token
end
if not request_token and env.ctx and env.ctx._request_headers then
request_token = env.ctx._request_headers["x-csrf-token"] or
env.ctx._request_headers["csrf-token"]
end
if not request_token or request_token ~= token then
http_set_status(403)
__http_response.body = "CSRF validation failed"
exit()
end
return true
end
-- ======================================================================
-- TEMPLATE RENDER FUNCTIONS
-- ======================================================================
-- Template processing with code execution
_G.render = function(template_str, env)
local function get_line(s, ln)
for line in s:gmatch("([^\n]*)\n?") do
if ln == 1 then return line end
ln = ln - 1
end
end
local function pos_to_line(s, pos)
local line = 1
for _ in s:sub(1, pos):gmatch("\n") do line = line + 1 end
return line
end
local pos, chunks = 1, {}
while pos <= #template_str do
local unescaped_start = template_str:find("{{{", pos, true)
local escaped_start = template_str:find("{{", pos, true)
local start, tag_type, open_len
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
start, tag_type, open_len = unescaped_start, "-", 3
elseif escaped_start then
start, tag_type, open_len = escaped_start, "=", 2
else
table.insert(chunks, template_str:sub(pos))
break
end
if start > pos then
table.insert(chunks, template_str:sub(pos, start-1))
end
pos = start + open_len
local close_tag = tag_type == "-" and "}}}" or "}}"
local close_start, close_stop = template_str:find(close_tag, pos, true)
if not close_start then
error("Failed to find closing tag at position " .. pos)
end
local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$")
-- Check if it's a simple variable name for escaped output
local is_simple_var = tag_type == "=" and code:match("^[%w_]+$")
table.insert(chunks, {tag_type, code, pos, is_simple_var})
pos = close_stop + 1
end
local buffer = {"local _tostring, _escape, _b, _b_i = ...\n"}
for _, chunk in ipairs(chunks) do
local t = type(chunk)
if t == "string" then
table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n")
else
t = chunk[1]
if t == "=" then
if chunk[4] then -- is_simple_var
table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _escape(_tostring(" .. chunk[2] .. "))\n")
else
table.insert(buffer, "--[[" .. chunk[3] .. "]] " .. chunk[2] .. "\n")
end
elseif t == "-" then
table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _tostring(" .. chunk[2] .. ")\n")
end
end
end
table.insert(buffer, "return _b")
local fn, err = loadstring(table.concat(buffer))
if not fn then error(err) end
env = env or {}
local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end})
setfenv(fn, runtime_env)
local output_buffer = {}
fn(tostring, html_special_chars, output_buffer, 0)
return table.concat(output_buffer)
end
-- Named placeholder processing
_G.parse = function(template_str, env)
local pos, output = 1, {}
env = env or {}
while pos <= #template_str do
local unescaped_start, unescaped_end, unescaped_name = template_str:find("{{{%s*([%w_]+)%s*}}}", pos)
local escaped_start, escaped_end, escaped_name = template_str:find("{{%s*([%w_]+)%s*}}", pos)
local next_pos, placeholder_end, name, escaped
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
next_pos, placeholder_end, name, escaped = unescaped_start, unescaped_end, unescaped_name, false
elseif escaped_start then
next_pos, placeholder_end, name, escaped = escaped_start, escaped_end, escaped_name, true
else
local text = template_str:sub(pos)
if text and #text > 0 then
table.insert(output, text)
end
break
end
local text = template_str:sub(pos, next_pos - 1)
if text and #text > 0 then
table.insert(output, text)
end
local value = env[name]
local str = tostring(value or "")
if escaped then
str = html_special_chars(str)
end
table.insert(output, str)
pos = placeholder_end + 1
end
return table.concat(output)
end
-- Indexed placeholder processing
_G.iparse = function(template_str, values)
local pos, output, value_index = 1, {}, 1
values = values or {}
while pos <= #template_str do
local unescaped_start, unescaped_end = template_str:find("{{{}}}", pos, true)
local escaped_start, escaped_end = template_str:find("{{}}", pos, true)
local next_pos, placeholder_end, escaped
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
next_pos, placeholder_end, escaped = unescaped_start, unescaped_end, false
elseif escaped_start then
next_pos, placeholder_end, escaped = escaped_start, escaped_end, true
else
local text = template_str:sub(pos)
if text and #text > 0 then
table.insert(output, text)
end
break
end
local text = template_str:sub(pos, next_pos - 1)
if text and #text > 0 then
table.insert(output, text)
end
local value = values[value_index]
local str = tostring(value or "")
if escaped then
str = html_special_chars(str)
end
table.insert(output, str)
pos = placeholder_end + 1
value_index = value_index + 1
end
return table.concat(output)
end
-- ======================================================================
-- PASSWORD FUNCTIONS
-- ======================================================================
-- Hash a password using Argon2id
-- Options:
-- memory: Amount of memory to use in KB (default: 128MB)
-- iterations: Number of iterations (default: 4)
-- parallelism: Number of threads (default: 4)
-- salt_length: Length of salt in bytes (default: 16)
-- key_length: Length of the derived key in bytes (default: 32)
function password_hash(plain_password, options)
if type(plain_password) ~= "string" then
error("password_hash: expected string password", 2)
end
return __password_hash(plain_password, options)
end
-- Verify a password against a hash
function password_verify(plain_password, hash_string)
if type(plain_password) ~= "string" then
error("password_verify: expected string password", 2)
end
if type(hash_string) ~= "string" then
error("password_verify: expected string hash", 2)
end
return __password_verify(plain_password, hash_string)
end
-- ======================================================================
-- SEND FUNCTIONS
-- ======================================================================
function send_html(content)
http_set_content_type("text/html")
return content
end
function send_json(content)
http_set_content_type("application/json")
return content
end
function send_text(content)
http_set_content_type("text/plain")
return content
end
function send_xml(content)
http_set_content_type("application/xml")
return content
end
function send_javascript(content)
http_set_content_type("application/javascript")
return content
end
function send_css(content)
http_set_content_type("text/css")
return content
end
function send_svg(content)
http_set_content_type("image/svg+xml")
return content
end
function send_csv(content)
http_set_content_type("text/csv")
return content
end
function send_binary(content, mime_type)
http_set_content_type(mime_type or "application/octet-stream")
return content
end

View File

@ -1,297 +0,0 @@
local function normalize_params(params, ...)
if type(params) == "table" then return params end
local args = {...}
if #args > 0 or params ~= nil then
table.insert(args, 1, params)
return args
end
return nil
end
local connection_mt = {
__index = {
query = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:query: query must be a string", 2)
end
local normalized_params = normalize_params(params, ...)
return __sqlite_query(self.db_name, query, normalized_params)
end,
exec = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:exec: query must be a string", 2)
end
local normalized_params = normalize_params(params, ...)
return __sqlite_exec(self.db_name, query, normalized_params)
end,
get_one = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
local normalized_params = normalize_params(params, ...)
return __sqlite_get_one(self.db_name, query, normalized_params)
end,
insert = function(self, table_name, data, columns)
if type(data) ~= "table" then
error("connection:insert: data must be a table", 2)
end
-- Single object: {col1=val1, col2=val2}
if data[1] == nil and next(data) ~= nil then
local cols = table.keys(data)
local placeholders = table.map(cols, function(_, i) return ":p" .. i end)
local params = {}
for i, col in ipairs(cols) do
params["p" .. i] = data[col]
end
local query = string.format(
"INSERT INTO %s (%s) VALUES (%s)",
table_name,
table.concat(cols, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
end
-- Array data with columns
if columns and type(columns) == "table" then
if #data > 0 and type(data[1]) == "table" then
-- Multiple rows
local value_groups = {}
local params = {}
local param_idx = 1
for _, row in ipairs(data) do
local row_placeholders = {}
for j = 1, #columns do
local param_name = "p" .. param_idx
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[j]
param_idx = param_idx + 1
end
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(columns, ", "),
table.concat(value_groups, ", ")
)
return self:exec(query, params)
else
-- Single row array
local placeholders = table.map(columns, function(_, i) return ":p" .. i end)
local params = {}
for i = 1, #columns do
params["p" .. i] = data[i]
end
local query = string.format(
"INSERT INTO %s (%s) VALUES (%s)",
table_name,
table.concat(columns, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
end
end
-- Array of objects
if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
local cols = table.keys(data[1])
local value_groups = {}
local params = {}
local param_idx = 1
for _, row in ipairs(data) do
local row_placeholders = {}
for _, col in ipairs(cols) do
local param_name = "p" .. param_idx
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[col]
param_idx = param_idx + 1
end
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(cols, ", "),
table.concat(value_groups, ", ")
)
return self:exec(query, params)
end
error("connection:insert: invalid data format", 2)
end,
update = function(self, table_name, data, where, where_params, ...)
if type(data) ~= "table" or next(data) == nil then
return 0
end
local sets = {}
local params = {}
local param_idx = 1
for col, val in pairs(data) do
local param_name = "p" .. param_idx
table.insert(sets, col .. " = :" .. param_name)
params[param_name] = val
param_idx = param_idx + 1
end
local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
if where then
query = query .. " WHERE " .. where
if where_params then
local normalized = normalize_params(where_params, ...)
if type(normalized) == "table" then
for k, v in pairs(normalized) do
if type(k) == "string" then
params[k] = v
else
params["w" .. param_idx] = v
param_idx = param_idx + 1
end
end
end
end
end
return self:exec(query, params)
end,
create_table = function(self, table_name, ...)
local column_definitions = {}
local index_definitions = {}
for _, def_string in ipairs({...}) do
if type(def_string) == "string" then
local is_unique = false
local index_def = def_string
if string.starts_with(def_string, "UNIQUE INDEX:") then
is_unique = true
index_def = string.trim(def_string:sub(14))
elseif string.starts_with(def_string, "INDEX:") then
index_def = string.trim(def_string:sub(7))
else
table.insert(column_definitions, def_string)
goto continue
end
local paren_pos = index_def:find("%(")
if not paren_pos then goto continue end
local index_name = string.trim(index_def:sub(1, paren_pos - 1))
local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$")
if not columns_part then goto continue end
local columns = table.map(string.split(columns_part, ","), string.trim)
if #columns > 0 then
table.insert(index_definitions, {
name = index_name,
columns = columns,
unique = is_unique
})
end
end
::continue::
end
if #column_definitions == 0 then
error("connection:create_table: no column definitions specified for table " .. table_name, 2)
end
local statements = {}
table.insert(statements, string.format(
"CREATE TABLE IF NOT EXISTS %s (%s)",
table_name,
table.concat(column_definitions, ", ")
))
for _, idx in ipairs(index_definitions) do
local unique_prefix = idx.unique and "UNIQUE " or ""
table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique_prefix,
idx.name,
table_name,
table.concat(idx.columns, ", ")
))
end
return self:exec(table.concat(statements, ";\n"))
end,
delete = function(self, table_name, where, params, ...)
local query = "DELETE FROM " .. table_name
if where then
query = query .. " WHERE " .. where
end
return self:exec(query, normalize_params(params, ...))
end,
exists = function(self, table_name, where, params, ...)
if type(table_name) ~= "string" then
error("connection:exists: table_name must be a string", 2)
end
local query = "SELECT 1 FROM " .. table_name
if where then
query = query .. " WHERE " .. where
end
query = query .. " LIMIT 1"
local results = self:query(query, normalize_params(params, ...))
return #results > 0
end,
begin = function(self)
return self:exec("BEGIN TRANSACTION")
end,
commit = function(self)
return self:exec("COMMIT")
end,
rollback = function(self)
return self:exec("ROLLBACK")
end,
transaction = function(self, callback)
self:begin()
local success, result = pcall(callback, self)
if success then
self:commit()
return result
else
self:rollback()
error(result, 2)
end
end
}
}
return function(db_name)
if type(db_name) ~= "string" then
error("sqlite: database name must be a string", 2)
end
return setmetatable({
db_name = db_name
}, connection_mt)
end

View File

@ -1,197 +0,0 @@
--[[
string.lua - Extended string library functions
]]--
local string_ext = {}
-- ======================================================================
-- STRING UTILITY FUNCTIONS
-- ======================================================================
-- Trim whitespace from both ends
function string_ext.trim(s)
if type(s) ~= "string" then return s end
return s:match("^%s*(.-)%s*$")
end
-- Split string by delimiter
function string_ext.split(s, delimiter)
if type(s) ~= "string" then return {} end
delimiter = delimiter or ","
local result = {}
for match in (s..delimiter):gmatch("(.-)"..delimiter) do
table.insert(result, match)
end
return result
end
-- Check if string starts with prefix
function string_ext.starts_with(s, prefix)
if type(s) ~= "string" or type(prefix) ~= "string" then return false end
return s:sub(1, #prefix) == prefix
end
-- Check if string ends with suffix
function string_ext.ends_with(s, suffix)
if type(s) ~= "string" or type(suffix) ~= "string" then return false end
return suffix == "" or s:sub(-#suffix) == suffix
end
-- Left pad a string
function string_ext.pad_left(s, len, char)
if type(s) ~= "string" or type(len) ~= "number" then return s end
char = char or " "
if #s >= len then return s end
return string.rep(char:sub(1,1), len - #s) .. s
end
-- Right pad a string
function string_ext.pad_right(s, len, char)
if type(s) ~= "string" or type(len) ~= "number" then return s end
char = char or " "
if #s >= len then return s end
return s .. string.rep(char:sub(1,1), len - #s)
end
-- Center a string
function string_ext.center(s, width, char)
if type(s) ~= "string" or width <= #s then return s end
char = char or " "
local pad_len = width - #s
local left_pad = math.floor(pad_len / 2)
local right_pad = pad_len - left_pad
return string.rep(char:sub(1,1), left_pad) .. s .. string.rep(char:sub(1,1), right_pad)
end
-- Count occurrences of substring
function string_ext.count(s, substr)
if type(s) ~= "string" or type(substr) ~= "string" or #substr == 0 then return 0 end
local count, pos = 0, 1
while true do
pos = s:find(substr, pos, true)
if not pos then break end
count = count + 1
pos = pos + 1
end
return count
end
-- Capitalize first letter
function string_ext.capitalize(s)
if type(s) ~= "string" or #s == 0 then return s end
return s:sub(1,1):upper() .. s:sub(2)
end
-- Capitalize all words
function string_ext.title(s)
if type(s) ~= "string" then return s end
return s:gsub("(%w)([%w]*)", function(first, rest)
return first:upper() .. rest:lower()
end)
end
-- Insert string at position
function string_ext.insert(s, pos, insert_str)
if type(s) ~= "string" or type(insert_str) ~= "string" then return s end
pos = math.max(1, math.min(pos, #s + 1))
return s:sub(1, pos - 1) .. insert_str .. s:sub(pos)
end
-- Remove substring
function string_ext.remove(s, start, length)
if type(s) ~= "string" then return s end
length = length or 1
if start < 1 or start > #s then return s end
return s:sub(1, start - 1) .. s:sub(start + length)
end
-- Replace substring once
function string_ext.replace(s, old, new, n)
if type(s) ~= "string" or type(old) ~= "string" or #old == 0 then return s end
new = new or ""
n = n or 1
return s:gsub(old:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1"), new, n)
end
-- Check if string contains substring
function string_ext.contains(s, substr)
if type(s) ~= "string" or type(substr) ~= "string" then return false end
return s:find(substr, 1, true) ~= nil
end
-- Escape pattern magic characters
function string_ext.escape_pattern(s)
if type(s) ~= "string" then return s end
return s:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1")
end
-- Wrap text at specified width
function string_ext.wrap(s, width, indent_first, indent_rest)
if type(s) ~= "string" or type(width) ~= "number" then return s end
width = math.max(1, width)
indent_first = indent_first or ""
indent_rest = indent_rest or indent_first
local result = {}
local line_prefix = indent_first
local pos = 1
while pos <= #s do
local line_width = width - #line_prefix
local end_pos = math.min(pos + line_width - 1, #s)
if end_pos < #s then
local last_space = s:sub(pos, end_pos):match(".*%s()")
if last_space then
end_pos = pos + last_space - 2
end
end
table.insert(result, line_prefix .. s:sub(pos, end_pos))
pos = end_pos + 1
-- Skip leading spaces on next line
while s:sub(pos, pos) == " " do
pos = pos + 1
end
line_prefix = indent_rest
end
return table.concat(result, "\n")
end
-- Limit string length with ellipsis
function string_ext.truncate(s, length, ellipsis)
if type(s) ~= "string" then return s end
ellipsis = ellipsis or "..."
if #s <= length then return s end
return s:sub(1, length - #ellipsis) .. ellipsis
end
-- ======================================================================
-- INSTALL EXTENSIONS INTO STRING LIBRARY
-- ======================================================================
for name, func in pairs(string) do
string_ext[name] = func
end
return string_ext

File diff suppressed because it is too large Load Diff

View File

@ -1,130 +0,0 @@
--[[
time.lua - High performance timing functions
]]--
local ffi = require('ffi')
local is_windows = (ffi.os == "Windows")
-- Define C structures and functions based on platform
if is_windows then
ffi.cdef[[
typedef struct {
int64_t QuadPart;
} LARGE_INTEGER;
int QueryPerformanceCounter(LARGE_INTEGER* lpPerformanceCount);
int QueryPerformanceFrequency(LARGE_INTEGER* lpFrequency);
]]
else
ffi.cdef[[
typedef long time_t;
typedef struct timeval {
long tv_sec;
long tv_usec;
} timeval;
int gettimeofday(struct timeval* tv, void* tz);
time_t time(time_t* t);
]]
end
local time = {}
local has_initialized = false
local start_time, timer_freq
-- Initialize timing system based on platform
local function init()
if has_initialized then return end
if ffi.os == "Windows" then
local frequency = ffi.new("LARGE_INTEGER")
ffi.C.QueryPerformanceFrequency(frequency)
timer_freq = tonumber(frequency.QuadPart)
local counter = ffi.new("LARGE_INTEGER")
ffi.C.QueryPerformanceCounter(counter)
start_time = tonumber(counter.QuadPart)
else
-- Nothing special needed for Unix platform init
start_time = ffi.C.time(nil)
end
has_initialized = true
end
-- PHP-compatible microtime implementation
function time.microtime(get_as_float)
init()
if ffi.os == "Windows" then
local counter = ffi.new("LARGE_INTEGER")
ffi.C.QueryPerformanceCounter(counter)
local now = tonumber(counter.QuadPart)
local seconds = math.floor((now - start_time) / timer_freq)
local microseconds = ((now - start_time) % timer_freq) * 1000000 / timer_freq
if get_as_float then
return seconds + microseconds / 1000000
else
return string.format("0.%06d %d", microseconds, seconds)
end
else
local tv = ffi.new("struct timeval")
ffi.C.gettimeofday(tv, nil)
if get_as_float then
return tonumber(tv.tv_sec) + tonumber(tv.tv_usec) / 1000000
else
return string.format("0.%06d %d", tv.tv_usec, tv.tv_sec)
end
end
end
-- High-precision monotonic timer (returns seconds with microsecond precision)
function time.monotonic()
init()
if ffi.os == "Windows" then
local counter = ffi.new("LARGE_INTEGER")
ffi.C.QueryPerformanceCounter(counter)
local now = tonumber(counter.QuadPart)
return (now - start_time) / timer_freq
else
local tv = ffi.new("struct timeval")
ffi.C.gettimeofday(tv, nil)
return tonumber(tv.tv_sec) - start_time + tonumber(tv.tv_usec) / 1000000
end
end
-- Benchmark function that measures execution time
function time.benchmark(func, iterations, warmup)
iterations = iterations or 1000
warmup = warmup or 10
-- Warmup
for i=1, warmup do func() end
local start = time.microtime(true)
for i=1, iterations do
func()
end
local finish = time.microtime(true)
local elapsed = (finish - start) * 1000000 -- Convert to microseconds
return elapsed / iterations
end
-- Simple sleep function using coroutine yielding
function time.sleep(seconds)
if type(seconds) ~= "number" or seconds <= 0 then
return
end
local start = time.monotonic()
while time.monotonic() - start < seconds do
-- Use coroutine.yield to avoid consuming CPU
coroutine.yield()
end
end
_G.microtime = time.microtime
return time

View File

@ -1,293 +0,0 @@
--[[
util.lua - Utility functions for the Lua sandbox
]]--
-- ======================================================================
-- CORE UTILITY FUNCTIONS
-- ======================================================================
-- Generate a random token
function generate_token(length)
return __generate_token(length or 32)
end
-- ======================================================================
-- HTML ENTITY FUNCTIONS
-- ======================================================================
-- Convert special characters to HTML entities (like htmlspecialchars)
function html_special_chars(str)
if type(str) ~= "string" then
return str
end
return __html_special_chars(str)
end
-- Convert all applicable characters to HTML entities (like htmlentities)
function html_entities(str)
if type(str) ~= "string" then
return str
end
return __html_entities(str)
end
-- Convert HTML entities back to characters (simple version)
function html_entity_decode(str)
if type(str) ~= "string" then
return str
end
str = str:gsub("&lt;", "<")
str = str:gsub("&gt;", ">")
str = str:gsub("&quot;", '"')
str = str:gsub("&#39;", "'")
str = str:gsub("&amp;", "&")
return str
end
-- Convert newlines to <br> tags
function nl2br(str)
if type(str) ~= "string" then
return str
end
return str:gsub("\r\n", "<br>"):gsub("\n", "<br>"):gsub("\r", "<br>")
end
-- ======================================================================
-- URL FUNCTIONS
-- ======================================================================
-- URL encode a string
function url_encode(str)
if type(str) ~= "string" then
return str
end
str = str:gsub("\n", "\r\n")
str = str:gsub("([^%w %-%_%.%~])", function(c)
return string.format("%%%02X", string.byte(c))
end)
str = str:gsub(" ", "+")
return str
end
-- URL decode a string
function url_decode(str)
if type(str) ~= "string" then
return str
end
str = str:gsub("+", " ")
str = str:gsub("%%(%x%x)", function(h)
return string.char(tonumber(h, 16))
end)
return str
end
-- ======================================================================
-- VALIDATION FUNCTIONS
-- ======================================================================
-- Email validation
function is_email(str)
if type(str) ~= "string" then
return false
end
-- Simple email validation pattern
local pattern = "^[%w%.%%%+%-]+@[%w%.%%%+%-]+%.%w%w%w?%w?$"
return str:match(pattern) ~= nil
end
-- URL validation
function is_url(str)
if type(str) ~= "string" then
return false
end
-- Simple URL validation
local pattern = "^https?://[%w-_%.%?%.:/%+=&%%]+$"
return str:match(pattern) ~= nil
end
-- IP address validation (IPv4)
function is_ipv4(str)
if type(str) ~= "string" then
return false
end
local pattern = "^(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)$"
local a, b, c, d = str:match(pattern)
if not (a and b and c and d) then
return false
end
a, b, c, d = tonumber(a), tonumber(b), tonumber(c), tonumber(d)
return a <= 255 and b <= 255 and c <= 255 and d <= 255
end
-- Integer validation
function is_int(str)
if type(str) == "number" then
return math.floor(str) == str
elseif type(str) ~= "string" then
return false
end
return str:match("^-?%d+$") ~= nil
end
-- Float validation
function is_float(str)
if type(str) == "number" then
return true
elseif type(str) ~= "string" then
return false
end
return str:match("^-?%d+%.?%d*$") ~= nil
end
-- Boolean validation
function is_bool(value)
if type(value) == "boolean" then
return true
elseif type(value) ~= "string" and type(value) ~= "number" then
return false
end
local v = type(value) == "string" and value:lower() or value
return v == "1" or v == "true" or v == "on" or v == "yes" or
v == "0" or v == "false" or v == "off" or v == "no" or
v == 1 or v == 0
end
-- Convert to boolean
function to_bool(value)
if type(value) == "boolean" then
return value
elseif type(value) ~= "string" and type(value) ~= "number" then
return false
end
local v = type(value) == "string" and value:lower() or value
return v == "1" or v == "true" or v == "on" or v == "yes" or v == 1
end
-- Sanitize string (simple version)
function sanitize_string(str)
if type(str) ~= "string" then
return ""
end
return html_special_chars(str)
end
-- Sanitize to integer
function sanitize_int(value)
if type(value) ~= "string" and type(value) ~= "number" then
return 0
end
value = tostring(value)
local result = value:match("^-?%d+")
return result and tonumber(result) or 0
end
-- Sanitize to float
function sanitize_float(value)
if type(value) ~= "string" and type(value) ~= "number" then
return 0
end
value = tostring(value)
local result = value:match("^-?%d+%.?%d*")
return result and tonumber(result) or 0
end
-- Sanitize URL
function sanitize_url(str)
if type(str) ~= "string" then
return ""
end
-- Basic sanitization by removing control characters
str = str:gsub("[\000-\031]", "")
-- Make sure it's a valid URL
if is_url(str) then
return str
end
-- Try to prepend http:// if it's missing
if not str:match("^https?://") and is_url("http://" .. str) then
return "http://" .. str
end
return ""
end
-- Sanitize email
function sanitize_email(str)
if type(str) ~= "string" then
return ""
end
-- Remove all characters except common email characters
str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "")
-- Return only if it's a valid email
if is_email(str) then
return str
end
return ""
end
-- ======================================================================
-- SECURITY FUNCTIONS
-- ======================================================================
-- Basic XSS prevention
function xss_clean(str)
if type(str) ~= "string" then
return str
end
-- Convert problematic characters to entities
local result = html_special_chars(str)
-- Remove JavaScript event handlers
result = result:gsub("on%w+%s*=", "")
-- Remove JavaScript protocol
result = result:gsub("javascript:", "")
-- Remove CSS expression
result = result:gsub("expression%s*%(", "")
return result
end
-- Base64 encode
function base64_encode(str)
if type(str) ~= "string" then
return str
end
return __base64_encode(str)
end
-- Base64 decode
function base64_decode(str)
if type(str) ~= "string" then
return str
end
return __base64_decode(str)
end

View File

@ -1,430 +0,0 @@
package runner
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"Moonshark/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ModuleConfig holds configuration for Lua's module loading system
type ModuleConfig struct {
ScriptDir string // Base directory for script being executed
LibDirs []string // Additional library directories
}
// ModuleLoader manages module loading and caching
type ModuleLoader struct {
config *ModuleConfig
pathCache map[string]string // Cache module paths for fast lookups
bytecodeCache map[string][]byte // Cache of compiled bytecode
debug bool
mu sync.RWMutex
}
// NewModuleLoader creates a new module loader
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
if config == nil {
config = &ModuleConfig{
ScriptDir: "",
LibDirs: []string{},
}
}
return &ModuleLoader{
config: config,
pathCache: make(map[string]string),
bytecodeCache: make(map[string][]byte),
debug: false,
}
}
// EnableDebug turns on debug logging
func (l *ModuleLoader) EnableDebug() {
l.debug = true
}
// SetScriptDir sets the script directory
func (l *ModuleLoader) SetScriptDir(dir string) {
l.mu.Lock()
defer l.mu.Unlock()
l.config.ScriptDir = dir
}
// debugLog logs a message if debug mode is enabled
func (l *ModuleLoader) debugLog(format string, args ...interface{}) {
if l.debug {
logger.Debug("ModuleLoader "+format, args...)
}
}
// SetupRequire configures the require system in a Lua state
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
l.mu.RLock()
defer l.mu.RUnlock()
// Initialize our module registry in Lua
err := state.DoString(`
-- Initialize global module registry
__module_paths = {}
__module_bytecode = {}
__ready_modules = {}
-- Create module preload table
package.preload = package.preload or {}
`)
if err != nil {
return err
}
// Set up package.path based on search paths
paths := l.getSearchPaths()
pathStr := strings.Join(paths, ";")
escapedPathStr := escapeLuaString(pathStr)
return state.DoString(`package.path = "` + escapedPathStr + `"`)
}
// getSearchPaths returns a list of Lua search paths
func (l *ModuleLoader) getSearchPaths() []string {
absPaths := []string{}
seen := map[string]bool{}
// Add script directory (highest priority)
if l.config.ScriptDir != "" {
absPath, err := filepath.Abs(l.config.ScriptDir)
if err == nil && !seen[absPath] {
absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
seen[absPath] = true
}
}
// Add lib directories
for _, dir := range l.config.LibDirs {
if dir == "" {
continue
}
absPath, err := filepath.Abs(dir)
if err == nil && !seen[absPath] {
absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
seen[absPath] = true
}
}
return absPaths
}
// PreloadModules preloads modules from library directories
func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
l.mu.Lock()
defer l.mu.Unlock()
// Reset caches
l.pathCache = make(map[string]string)
l.bytecodeCache = make(map[string][]byte)
// Reset module registry in Lua
if err := state.DoString(`
-- Reset module registry
__module_paths = {}
__module_bytecode = {}
__ready_modules = {}
-- Clear non-core modules from package.loaded
local core_modules = {
string = true, table = true, math = true, os = true,
package = true, io = true, coroutine = true, debug = true, _G = true
}
for name in pairs(package.loaded) do
if not core_modules[name] then
package.loaded[name] = nil
end
end
-- Reset preload table
package.preload = {}
`); err != nil {
return err
}
// Scan and preload modules from all library directories
for _, dir := range l.config.LibDirs {
if dir == "" {
continue
}
absDir, err := filepath.Abs(dir)
if err != nil {
continue
}
l.debugLog("Scanning directory: %s", absDir)
// Find all Lua files
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
return nil
}
// Get module name from path
relPath, err := filepath.Rel(absDir, path)
if err != nil || strings.HasPrefix(relPath, "..") {
return nil
}
// Convert path to module name
modName := strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
l.debugLog("Found module: %s at %s", modName, path)
// Register in our caches
l.pathCache[modName] = path
// Load file content
content, err := os.ReadFile(path)
if err != nil {
l.debugLog("Failed to read module file: %v", err)
return nil
}
// Compile to bytecode
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
l.debugLog("Failed to compile module: %v", err)
return nil
}
// Cache bytecode
l.bytecodeCache[modName] = bytecode
// Register in Lua - store path info
escapedPath := escapeLuaString(path)
escapedName := escapeLuaString(modName)
if err := state.DoString(`__module_paths["` + escapedName + `"] = "` + escapedPath + `"`); err != nil {
return nil
}
// Load bytecode and register in package.preload properly
if err := state.LoadBytecode(bytecode, path); err != nil {
return nil
}
// Store the function in package.preload - the function is on the stack
state.GetGlobal("package")
state.GetField(-1, "preload")
state.PushString(modName)
state.PushCopy(-4) // Copy the compiled function
state.SetTable(-3) // preload[modName] = function
state.Pop(2) // Pop package and preload tables
// Mark as ready
if err := state.DoString(`__ready_modules["` + escapedName + `"] = true`); err != nil {
state.Pop(1) // Remove the function from stack
return nil
}
state.Pop(1) // Remove the function from stack
return nil
})
if err != nil {
return err
}
}
// Install optimized require implementation
return state.DoString(`
-- Setup environment-aware require function
function __setup_require(env)
-- Create require function specific to this environment
env.require = function(modname)
-- Check if already loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Check preloaded modules
if __ready_modules[modname] then
local loader = package.preload[modname]
if loader then
-- Set environment for loader
setfenv(loader, env)
-- Execute and store result
local result = loader()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- Direct file load as fallback
if __module_paths[modname] then
local path = __module_paths[modname]
local chunk, err = loadfile(path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- Full path search as last resort
local errors = {}
for path in package.path:gmatch("[^;]+") do
local file_path = path:gsub("?", modname:gsub("%.", "/"))
local chunk, err = loadfile(file_path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
table.insert(errors, "\tno file '" .. file_path .. "'")
end
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
end
return env
end
`)
}
// GetModuleByPath finds the module name for a file path
func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
l.mu.RLock()
defer l.mu.RUnlock()
// Convert to absolute path for consistent comparison
absPath, err := filepath.Abs(path)
if err != nil {
absPath = filepath.Clean(path)
}
// Try direct lookup from cache with absolute path
for modName, modPath := range l.pathCache {
if modPath == absPath {
return modName, true
}
}
// Try to construct module name from lib dirs
for _, dir := range l.config.LibDirs {
absDir, err := filepath.Abs(dir)
if err != nil {
continue
}
// Check if the file is under this lib directory
relPath, err := filepath.Rel(absDir, absPath)
if err != nil || strings.HasPrefix(relPath, "..") {
continue
}
if strings.HasSuffix(relPath, ".lua") {
modName := strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
l.debugLog("Found module %s for path %s", modName, path)
return modName, true
}
}
l.debugLog("No module found for path %s", path)
return "", false
}
// RefreshModule recompiles and updates a specific module
func (l *ModuleLoader) RefreshModule(state *luajit.State, moduleName string) error {
l.mu.Lock()
defer l.mu.Unlock()
// Get module path
path, exists := l.pathCache[moduleName]
if !exists {
l.debugLog("Module not found in cache: %s", moduleName)
return fmt.Errorf("module %s not found", moduleName)
}
l.debugLog("Refreshing module: %s at %s", moduleName, path)
// Read updated file content
content, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read module file: %w", err)
}
// Recompile to bytecode
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
return fmt.Errorf("failed to compile module: %w", err)
}
// Update bytecode cache
l.bytecodeCache[moduleName] = bytecode
// Load new bytecode
if err := state.LoadBytecode(bytecode, path); err != nil {
return fmt.Errorf("failed to load bytecode: %w", err)
}
// Update package.preload with new function (function is on stack)
state.GetGlobal("package")
state.GetField(-1, "preload")
state.PushString(moduleName)
state.PushCopy(-4) // Copy the new compiled function
state.SetTable(-3) // preload[moduleName] = new_function
state.Pop(2) // Pop package and preload tables
state.Pop(1) // Pop the function
// Clear from package.loaded so it gets reloaded
escapedName := escapeLuaString(moduleName)
if err := state.DoString(`package.loaded["` + escapedName + `"] = nil`); err != nil {
return fmt.Errorf("failed to clear loaded module: %w", err)
}
l.debugLog("Successfully refreshed module: %s", moduleName)
return nil
}
// RefreshModuleByPath refreshes a module by its file path
func (l *ModuleLoader) RefreshModuleByPath(state *luajit.State, filePath string) error {
moduleName, exists := l.GetModuleByPath(filePath)
if !exists {
return fmt.Errorf("no module found for path: %s", filePath)
}
return l.RefreshModule(state, moduleName)
}
// escapeLuaString escapes special characters in a string for Lua
func escapeLuaString(s string) string {
replacer := strings.NewReplacer(
"\\", "\\\\",
"\"", "\\\"",
"\n", "\\n",
"\r", "\\r",
"\t", "\\t",
)
return replacer.Replace(s)
}

View File

@ -1,98 +0,0 @@
package runner
import (
"fmt"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/alexedwards/argon2id"
)
// RegisterPasswordFunctions registers password-related functions in the Lua state
func RegisterPasswordFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__password_hash", passwordHash); err != nil {
return err
}
if err := state.RegisterGoFunction("__password_verify", passwordVerify); err != nil {
return err
}
return nil
}
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id
func passwordHash(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("password_hash error: expected string password")
return 1
}
password := state.ToString(1)
params := &argon2id.Params{
Memory: 128 * 1024,
Iterations: 4,
Parallelism: 4,
SaltLength: 16,
KeyLength: 32,
}
if state.IsTable(2) {
state.GetField(2, "memory")
if state.IsNumber(-1) {
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
}
state.Pop(1)
state.GetField(2, "iterations")
if state.IsNumber(-1) {
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
}
state.Pop(1)
state.GetField(2, "parallelism")
if state.IsNumber(-1) {
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
}
state.Pop(1)
state.GetField(2, "salt_length")
if state.IsNumber(-1) {
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
}
state.Pop(1)
state.GetField(2, "key_length")
if state.IsNumber(-1) {
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
}
state.Pop(1)
}
hash, err := argon2id.CreateHash(password, params)
if err != nil {
state.PushString(fmt.Sprintf("password_hash error: %v", err))
return 1
}
state.PushString(hash)
return 1
}
// passwordVerify verifies a password against a hash
func passwordVerify(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) {
state.PushBoolean(false)
return 1
}
password := state.ToString(1)
hash := state.ToString(2)
match, err := argon2id.ComparePasswordAndHash(password, hash)
if err != nil {
state.PushBoolean(false)
return 1
}
state.PushBoolean(match)
return 1
}

View File

@ -1,56 +0,0 @@
package runner
import (
"sync"
"github.com/valyala/fasthttp"
)
// Response represents a unified response from script execution
type Response struct {
// Basic properties
Body any // Body content (any type)
Metadata map[string]any // Additional metadata
// HTTP specific properties
Status int // HTTP status code
Headers map[string]string // HTTP headers
Cookies []*fasthttp.Cookie // HTTP cookies
// Session information
SessionData map[string]any
}
// Response pool to reduce allocations
var responsePool = sync.Pool{
New: func() any {
return &Response{
Status: 200,
Headers: make(map[string]string, 8),
Metadata: make(map[string]any, 8),
Cookies: make([]*fasthttp.Cookie, 0, 4),
SessionData: make(map[string]any, 8),
}
},
}
// NewResponse creates a new response object from the pool
func NewResponse() *Response {
return responsePool.Get().(*Response)
}
// Release returns a response to the pool after cleaning it
func ReleaseResponse(resp *Response) {
if resp == nil {
return
}
resp.Body = nil
resp.Status = 200
resp.Headers = make(map[string]string, 8)
resp.Metadata = make(map[string]any, 8)
resp.Cookies = resp.Cookies[:0]
resp.SessionData = make(map[string]any, 8)
responsePool.Put(resp)
}

View File

@ -1,566 +0,0 @@
package runner
import (
"Moonshark/utils/color"
"Moonshark/utils/logger"
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Common errors
var (
ErrRunnerClosed = errors.New("lua runner is closed")
ErrInitFailed = errors.New("initialization failed")
ErrStateNotReady = errors.New("lua state not ready")
ErrTimeout = errors.New("operation timed out")
)
// RunnerOption defines a functional option for configuring the Runner
type RunnerOption func(*Runner)
// State wraps a Lua state with its sandbox
type State struct {
L *luajit.State // The Lua state
sandbox *Sandbox // Associated sandbox
index int // Index for debugging
inUse atomic.Bool // Whether the state is currently in use
}
// Runner runs Lua scripts using a pool of Lua states
type Runner struct {
states []*State // All states managed by this runner
statePool chan int // Pool of available state indexes
poolSize int // Size of the state pool
moduleLoader *ModuleLoader // Module loader
dataDir string // Data directory for SQLite databases
fsDir string // Virtual filesystem directory
isRunning atomic.Bool // Whether the runner is active
mu sync.RWMutex // Mutex for thread safety
scriptDir string // Current script directory
}
// WithPoolSize sets the state pool size
func WithPoolSize(size int) RunnerOption {
return func(r *Runner) {
if size > 0 {
r.poolSize = size
}
}
}
// WithLibDirs sets additional library directories
func WithLibDirs(dirs ...string) RunnerOption {
return func(r *Runner) {
if r.moduleLoader == nil {
r.moduleLoader = NewModuleLoader(&ModuleConfig{
LibDirs: dirs,
})
} else {
r.moduleLoader.config.LibDirs = dirs
}
}
}
// WithDataDir sets the data directory for SQLite databases
func WithDataDir(dataDir string) RunnerOption {
return func(r *Runner) {
if dataDir != "" {
r.dataDir = dataDir
}
}
}
// WithFsDir sets the virtual filesystem directory
func WithFsDir(fsDir string) RunnerOption {
return func(r *Runner) {
if fsDir != "" {
r.fsDir = fsDir
}
}
}
// NewRunner creates a new Runner with a pool of states
func NewRunner(options ...RunnerOption) (*Runner, error) {
// Default configuration
runner := &Runner{
poolSize: runtime.GOMAXPROCS(0),
dataDir: "data",
fsDir: "fs",
}
// Apply options
for _, opt := range options {
opt(runner)
}
// Set up module loader if not already initialized
if runner.moduleLoader == nil {
config := &ModuleConfig{
ScriptDir: "",
LibDirs: []string{},
}
runner.moduleLoader = NewModuleLoader(config)
}
InitSQLite(runner.dataDir)
InitFS(runner.fsDir)
SetSQLitePoolSize(runner.poolSize)
// Initialize states and pool
runner.states = make([]*State, runner.poolSize)
runner.statePool = make(chan int, runner.poolSize)
// Create and initialize all states
if err := runner.initializeStates(); err != nil {
CleanupSQLite()
runner.Close()
return nil, err
}
runner.isRunning.Store(true)
return runner, nil
}
// initializeStates creates and initializes all states in the pool
func (r *Runner) initializeStates() error {
logger.Info("[LuaRunner] Creating %s states...", color.Apply(strconv.Itoa(r.poolSize), color.Yellow))
for i := range r.poolSize {
state, err := r.createState(i)
if err != nil {
return err
}
r.states[i] = state
r.statePool <- i // Add index to the pool
}
return nil
}
// createState initializes a new Lua state
func (r *Runner) createState(index int) (*State, error) {
verbose := index == 0
if verbose {
logger.Debug("Creating Lua state %d", index)
}
L := luajit.New()
if L == nil {
return nil, errors.New("failed to create Lua state")
}
sb := NewSandbox()
// Set up sandbox
if err := sb.Setup(L, verbose); err != nil {
L.Cleanup()
L.Close()
return nil, ErrInitFailed
}
// Set up module loader
if err := r.moduleLoader.SetupRequire(L); err != nil {
L.Cleanup()
L.Close()
return nil, ErrInitFailed
}
// Preload modules
if err := r.moduleLoader.PreloadModules(L); err != nil {
L.Cleanup()
L.Close()
return nil, errors.New("failed to preload modules")
}
if verbose {
logger.Debug("Lua state %d initialized successfully", index)
}
return &State{
L: L,
sandbox: sb,
index: index,
}, nil
}
// Execute runs a script in a sandbox with context
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
if !r.isRunning.Load() {
return nil, ErrRunnerClosed
}
// Set script directory if provided
if scriptPath != "" {
r.mu.Lock()
r.scriptDir = filepath.Dir(scriptPath)
r.moduleLoader.SetScriptDir(r.scriptDir)
r.mu.Unlock()
}
// Get a state from the pool
var stateIndex int
select {
case stateIndex = <-r.statePool:
// Got a state
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(1 * time.Second):
return nil, ErrTimeout
}
state := r.states[stateIndex]
if state == nil {
r.statePool <- stateIndex
return nil, ErrStateNotReady
}
// Use atomic operations
state.inUse.Store(true)
defer func() {
state.inUse.Store(false)
if r.isRunning.Load() {
select {
case r.statePool <- stateIndex:
default:
// Pool is full or closed, state will be cleaned up by Close()
}
}
}()
// Execute in sandbox
response, err := state.sandbox.Execute(state.L, bytecode, execCtx)
if err != nil {
return nil, err
}
return response, nil
}
// Run executes a Lua script with immediate context
func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
return r.Execute(context.Background(), bytecode, execCtx, scriptPath)
}
// Close gracefully shuts down the Runner
func (r *Runner) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if !r.isRunning.Load() {
return ErrRunnerClosed
}
r.isRunning.Store(false)
// Drain all states from the pool
for {
select {
case <-r.statePool:
default:
goto waitForInUse
}
}
waitForInUse:
// Wait for in-use states to finish (with timeout)
timeout := time.Now().Add(10 * time.Second)
for {
allIdle := true
for _, state := range r.states {
if state != nil && state.inUse.Load() {
allIdle = false
break
}
}
if allIdle {
break
}
if time.Now().After(timeout) {
logger.Warning("Timeout waiting for states to finish during shutdown, forcing close")
break
}
time.Sleep(10 * time.Millisecond)
}
// Now safely close all states
for i, state := range r.states {
if state != nil {
if state.inUse.Load() {
logger.Warning("Force closing state %d that is still in use", i)
}
state.L.Cleanup()
state.L.Close()
r.states[i] = nil
}
}
CleanupFS()
CleanupSQLite()
logger.Debug("Runner closed")
return nil
}
// RefreshStates rebuilds all states in the pool
func (r *Runner) RefreshStates() error {
r.mu.Lock()
defer r.mu.Unlock()
if !r.isRunning.Load() {
return ErrRunnerClosed
}
logger.Info("Runner is refreshing all states...")
// Drain all states from the pool
for {
select {
case <-r.statePool:
default:
goto waitForInUse
}
}
waitForInUse:
// Wait for in-use states to finish (with timeout)
timeout := time.Now().Add(10 * time.Second)
for {
allIdle := true
for _, state := range r.states {
if state != nil && state.inUse.Load() {
allIdle = false
break
}
}
if allIdle {
break
}
if time.Now().After(timeout) {
logger.Warning("Timeout waiting for states to finish, forcing refresh")
break
}
time.Sleep(10 * time.Millisecond)
}
// Now safely destroy all states
for i, state := range r.states {
if state != nil {
if state.inUse.Load() {
logger.Warning("Force closing state %d that is still in use", i)
}
state.L.Cleanup()
state.L.Close()
r.states[i] = nil
}
}
// Reinitialize all states
if err := r.initializeStates(); err != nil {
return err
}
logger.Debug("All states refreshed successfully")
return nil
}
// NotifyFileChanged alerts the runner about file changes
func (r *Runner) NotifyFileChanged(filePath string) bool {
logger.Debug("Runner notified of file change: %s", filePath)
module, isModule := r.moduleLoader.GetModuleByPath(filePath)
if isModule {
logger.Debug("Refreshing module: %s", module)
return r.RefreshModule(module)
}
logger.Debug("File change noted but no refresh needed: %s", filePath)
return true
}
// RefreshModule refreshes a specific module across all states
func (r *Runner) RefreshModule(moduleName string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
if !r.isRunning.Load() {
return false
}
logger.Debug("Refreshing module: %s", moduleName)
success := true
for _, state := range r.states {
if state == nil || state.inUse.Load() {
continue
}
// Use the enhanced module loader refresh
if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil {
success = false
logger.Debug("Failed to refresh module %s in state %d: %v", moduleName, state.index, err)
}
}
if success {
logger.Debug("Successfully refreshed module: %s", moduleName)
}
return success
}
// RefreshModuleByPath refreshes a module by its file path
func (r *Runner) RefreshModuleByPath(filePath string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
if !r.isRunning.Load() {
return false
}
logger.Debug("Refreshing module by path: %s", filePath)
success := true
for _, state := range r.states {
if state == nil || state.inUse.Load() {
continue
}
// Use the enhanced module loader refresh by path
if err := r.moduleLoader.RefreshModuleByPath(state.L, filePath); err != nil {
success = false
logger.Debug("Failed to refresh module at %s in state %d: %v", filePath, state.index, err)
}
}
return success
}
// GetStateCount returns the number of initialized states
func (r *Runner) GetStateCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, state := range r.states {
if state != nil {
count++
}
}
return count
}
// GetActiveStateCount returns the number of states currently in use
func (r *Runner) GetActiveStateCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, state := range r.states {
if state != nil && state.inUse.Load() {
count++
}
}
return count
}
// RunScriptFile loads, compiles and executes a Lua script file
func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
if !r.isRunning.Load() {
return nil, ErrRunnerClosed
}
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil, fmt.Errorf("script file not found: %s", filePath)
}
content, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}
absPath, err := filepath.Abs(filePath)
if err != nil {
return nil, fmt.Errorf("failed to get absolute path: %w", err)
}
scriptDir := filepath.Dir(absPath)
r.mu.Lock()
prevScriptDir := r.scriptDir
r.scriptDir = scriptDir
r.moduleLoader.SetScriptDir(scriptDir)
r.mu.Unlock()
defer func() {
r.mu.Lock()
r.scriptDir = prevScriptDir
r.moduleLoader.SetScriptDir(prevScriptDir)
r.mu.Unlock()
}()
var stateIndex int
select {
case stateIndex = <-r.statePool:
// Got a state
case <-time.After(5 * time.Second):
return nil, ErrTimeout
}
state := r.states[stateIndex]
if state == nil {
r.statePool <- stateIndex
return nil, ErrStateNotReady
}
state.inUse.Store(true)
defer func() {
state.inUse.Store(false)
if r.isRunning.Load() {
select {
case r.statePool <- stateIndex:
// State returned to pool
default:
// Pool is full or closed
}
}
}()
bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath))
if err != nil {
return nil, fmt.Errorf("compilation error: %w", err)
}
ctx := NewContext()
defer ctx.Release()
ctx.Set("_script_path", absPath)
ctx.Set("_script_dir", scriptDir)
response, err := state.sandbox.Execute(state.L, bytecode, ctx)
if err != nil {
return nil, fmt.Errorf("execution error: %w", err)
}
return response, nil
}

View File

@ -1,334 +0,0 @@
package runner
import (
"fmt"
"sync"
"github.com/goccy/go-json"
"github.com/valyala/fasthttp"
"Moonshark/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Error represents a simple error string
type Error string
func (e Error) Error() string {
return string(e)
}
// Error types
var (
ErrSandboxNotInitialized = Error("sandbox not initialized")
)
// Sandbox provides a secure execution environment for Lua scripts
type Sandbox struct {
modules map[string]any
debug bool
mu sync.RWMutex
}
// NewSandbox creates a new sandbox environment
func NewSandbox() *Sandbox {
return &Sandbox{
modules: make(map[string]any, 8),
debug: false,
}
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) {
s.mu.Lock()
defer s.mu.Unlock()
s.modules[name] = module
logger.Debug("Added module: %s", name)
}
// Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State, verbose bool) error {
if verbose {
logger.Debug("Setting up sandbox...")
}
if err := loadSandboxIntoState(state, verbose); err != nil {
logger.Error("Failed to load sandbox: %v", err)
return err
}
if err := s.registerCoreFunctions(state); err != nil {
logger.Error("Failed to register core functions: %v", err)
return err
}
s.mu.RLock()
for name, module := range s.modules {
logger.Debug("Registering module: %s", name)
if err := state.PushValue(module); err != nil {
s.mu.RUnlock()
logger.Error("Failed to register module %s: %v", name, err)
return err
}
state.SetGlobal(name)
}
s.mu.RUnlock()
if verbose {
logger.Debug("Sandbox setup complete")
}
return nil
}
// registerCoreFunctions registers all built-in functions in the Lua state
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
return err
}
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
return err
}
if err := state.RegisterGoFunction("__json_marshal", jsonMarshal); err != nil {
return err
}
if err := state.RegisterGoFunction("__json_unmarshal", jsonUnmarshal); err != nil {
return err
}
if err := RegisterSQLiteFunctions(state); err != nil {
return err
}
if err := RegisterFSFunctions(state); err != nil {
return err
}
if err := RegisterPasswordFunctions(state); err != nil {
return err
}
if err := RegisterUtilFunctions(state); err != nil {
return err
}
if err := RegisterCryptoFunctions(state); err != nil {
return err
}
if err := RegisterEnvFunctions(state); err != nil {
return err
}
return nil
}
// Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
state.GetGlobal("__execute_script")
if !state.IsFunction(-1) {
state.Pop(1)
return nil, ErrSandboxNotInitialized
}
if err := state.LoadBytecode(bytecode, "script"); err != nil {
state.Pop(1) // Pop the __execute_script function
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Push context values
if err := state.PushTable(ctx.Values); err != nil {
state.Pop(2) // Pop bytecode and __execute_script
return nil, err
}
// Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil {
return nil, fmt.Errorf("script execution failed: %w", err)
}
body, err := state.ToValue(-1)
state.Pop(1)
response := NewResponse()
if err == nil {
response.Body = body
}
extractHTTPResponseData(state, response)
return response, nil
}
// extractResponseData pulls response info from the Lua state
func extractHTTPResponseData(state *luajit.State, response *Response) {
state.GetGlobal("__http_response")
if !state.IsTable(-1) {
state.Pop(1)
return
}
// Extract status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Extract headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
}
state.Pop(1)
}
}
state.Pop(1)
// Extract cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
extractCookie(state, response)
}
state.Pop(1)
}
}
state.Pop(1)
// Extract metadata
state.GetField(-1, "metadata")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
maps.Copy(response.Metadata, table)
}
}
state.Pop(1)
// Extract session data
state.GetField(-1, "session")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
maps.Copy(response.SessionData, table)
}
}
state.Pop(1)
state.Pop(1) // Pop __http_response
}
// extractCookie pulls cookie data from the current table on the stack
func extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie()
// Get name (required)
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
fasthttp.ReleaseCookie(cookie)
return
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get other parameters
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
response.Cookies = append(response.Cookies, cookie)
}
// jsonMarshal converts a Lua value to a JSON string
func jsonMarshal(state *luajit.State) int {
value, err := state.ToValue(1)
if err != nil {
state.PushString(fmt.Sprintf("json marshal error: %v", err))
return 1
}
bytes, err := json.Marshal(value)
if err != nil {
state.PushString(fmt.Sprintf("json marshal error: %v", err))
return 1
}
state.PushString(string(bytes))
return 1
}
// jsonUnmarshal converts a JSON string to a Lua value
func jsonUnmarshal(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("json unmarshal error: expected string")
return 1
}
jsonStr := state.ToString(1)
var value any
err := json.Unmarshal([]byte(jsonStr), &value)
if err != nil {
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
return 1
}
if err := state.PushValue(value); err != nil {
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
return 1
}
return 1
}

View File

@ -1,427 +0,0 @@
package runner
import (
"context"
"fmt"
"path/filepath"
"strings"
"sync"
"time"
sqlite "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"Moonshark/utils/color"
"Moonshark/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
var (
dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex
dataDir string
poolSize = 8 // Default, will be set to match runner pool size
connTimeout = 5 * time.Second
)
// InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) {
dataDir = dir
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
}
// SetSQLitePoolSize sets the pool size to match the runner pool size
func SetSQLitePoolSize(size int) {
if size > 0 {
poolSize = size
}
}
// CleanupSQLite closes all database connections
func CleanupSQLite() {
poolsMu.Lock()
defer poolsMu.Unlock()
for name, pool := range dbPools {
if err := pool.Close(); err != nil {
logger.Error("Failed to close database %s: %v", name, err)
}
}
dbPools = make(map[string]*sqlitex.Pool)
logger.Debug("SQLite connections closed")
}
// getPool returns a connection pool for the database
func getPool(dbName string) (*sqlitex.Pool, error) {
// Validate database name
dbName = filepath.Base(dbName)
if dbName == "" || dbName[0] == '.' {
return nil, fmt.Errorf("invalid database name")
}
// Check for existing pool
poolsMu.RLock()
pool, exists := dbPools[dbName]
if exists {
poolsMu.RUnlock()
return pool, nil
}
poolsMu.RUnlock()
// Create new pool under write lock
poolsMu.Lock()
defer poolsMu.Unlock()
// Double-check if a pool was created while waiting for lock
if pool, exists = dbPools[dbName]; exists {
return pool, nil
}
// Create new pool with proper size
dbPath := filepath.Join(dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
PoolSize: poolSize,
PrepareConn: func(conn *sqlite.Conn) error {
// Execute PRAGMA statements individually
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA synchronous = NORMAL",
"PRAGMA cache_size = 1000",
"PRAGMA foreign_keys = ON",
"PRAGMA temp_store = MEMORY",
}
for _, pragma := range pragmas {
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
return err
}
}
return nil
},
})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
dbPools[dbName] = pool
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
return pool, nil
}
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.query: requires database name and query")
return -1
}
dbName := state.ToString(1)
query := state.ToString(2)
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
rows := make([]map[string]any, 0, 16)
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
}
// Set up result function
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
row := make(map[string]any)
colCount := stmt.ColumnCount()
for i := range colCount {
colName := stmt.ColumnName(i)
switch stmt.ColumnType(i) {
case sqlite.TypeInteger:
row[colName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat:
row[colName] = stmt.ColumnFloat(i)
case sqlite.TypeText:
row[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i)
if blobSize > 0 {
buf := make([]byte, blobSize)
row[colName] = stmt.ColumnBytes(i, buf)
} else {
row[colName] = []byte{}
}
case sqlite.TypeNull:
row[colName] = nil
}
}
rows = append(rows, row)
return nil
}
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Create result table
state.NewTable()
for i, row := range rows {
state.PushNumber(float64(i + 1))
if err := state.PushTable(row); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
state.SetTable(-3)
}
return 1
}
// sqlExec executes a SQL statement without returning results
func sqlExec(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.exec: requires database name and query")
return -1
}
dbName := state.ToString(1)
query := state.ToString(2)
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
// Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
}
// Fast path for simple queries with no parameters
if !hasParams {
if err := sqlitex.Execute(conn, query, nil); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
}
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Return affected rows
state.PushNumber(float64(conn.Changes()))
return 1
}
// setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) {
params, err := state.ToTable(paramIndex)
if err != nil {
return fmt.Errorf("invalid parameters: %w", err)
}
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters from stack
count := state.GetTop() - 2
args := make([]any, count)
for i := range count {
idx := i + 3
val, err := state.ToValue(idx)
if err != nil {
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
}
args[i] = val
}
execOpts.Args = args
}
return nil
}
// sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.get_one: requires database name and query")
return -1
}
dbName := state.ToString(1)
query := state.ToString(2)
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
var result map[string]any
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
}
// Set up result function to get only first row
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
if result != nil {
return nil // Already got first row
}
result = make(map[string]any)
colCount := stmt.ColumnCount()
for i := range colCount {
colName := stmt.ColumnName(i)
switch stmt.ColumnType(i) {
case sqlite.TypeInteger:
result[colName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat:
result[colName] = stmt.ColumnFloat(i)
case sqlite.TypeText:
result[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i)
if blobSize > 0 {
buf := make([]byte, blobSize)
result[colName] = stmt.ColumnBytes(i, buf)
} else {
result[colName] = []byte{}
}
case sqlite.TypeNull:
result[colName] = nil
}
}
return nil
}
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
// Return result or nil if no rows
if result == nil {
state.PushNil()
} else {
if err := state.PushTable(result); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
}
return 1
}
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err
}
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
return err
}
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
return err
}
return nil
}

View File

@ -1,116 +0,0 @@
package runner
import (
"encoding/base64"
"html"
"strings"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// RegisterUtilFunctions registers utility functions with the Lua state
func RegisterUtilFunctions(state *luajit.State) error {
// HTML special chars
if err := state.RegisterGoFunction("__html_special_chars", htmlSpecialChars); err != nil {
return err
}
// HTML entities
if err := state.RegisterGoFunction("__html_entities", htmlEntities); err != nil {
return err
}
// Base64 encode
if err := state.RegisterGoFunction("__base64_encode", base64Encode); err != nil {
return err
}
// Base64 decode
if err := state.RegisterGoFunction("__base64_decode", base64Decode); err != nil {
return err
}
return nil
}
// htmlSpecialChars converts special characters to HTML entities
func htmlSpecialChars(state *luajit.State) int {
if !state.IsString(1) {
state.PushNil()
return 1
}
input := state.ToString(1)
result := html.EscapeString(input)
state.PushString(result)
return 1
}
// htmlEntities is a more comprehensive version of htmlSpecialChars
func htmlEntities(state *luajit.State) int {
if !state.IsString(1) {
state.PushNil()
return 1
}
input := state.ToString(1)
// First use HTML escape for standard entities
result := html.EscapeString(input)
// Additional entities beyond what html.EscapeString handles
replacements := map[string]string{
"©": "&copy;",
"®": "&reg;",
"™": "&trade;",
"€": "&euro;",
"£": "&pound;",
"¥": "&yen;",
"—": "&mdash;",
"": "&ndash;",
"…": "&hellip;",
"•": "&bull;",
"°": "&deg;",
"±": "&plusmn;",
"¼": "&frac14;",
"½": "&frac12;",
"¾": "&frac34;",
}
for char, entity := range replacements {
result = strings.ReplaceAll(result, char, entity)
}
state.PushString(result)
return 1
}
// base64Encode encodes a string to base64
func base64Encode(state *luajit.State) int {
if !state.IsString(1) {
state.PushNil()
return 1
}
input := state.ToString(1)
result := base64.StdEncoding.EncodeToString([]byte(input))
state.PushString(result)
return 1
}
// base64Decode decodes a base64 string
func base64Decode(state *luajit.State) int {
if !state.IsString(1) {
state.PushNil()
return 1
}
input := state.ToString(1)
result, err := base64.StdEncoding.DecodeString(input)
if err != nil {
state.PushNil()
return 1
}
state.PushString(string(result))
return 1
}

View File

@ -1,212 +0,0 @@
package sessions
import (
"sync"
"time"
"github.com/VictoriaMetrics/fastcache"
gonanoid "github.com/matoous/go-nanoid/v2"
"github.com/valyala/fasthttp"
)
const (
DefaultMaxSessions = 10000
DefaultCookieName = "MoonsharkSID"
DefaultCookiePath = "/"
DefaultMaxAge = 86400 // 1 day in seconds
CleanupInterval = 5 * time.Minute
)
// SessionManager handles multiple sessions
type SessionManager struct {
cache *fastcache.Cache
cookieName string
cookiePath string
cookieDomain string
cookieSecure bool
cookieHTTPOnly bool
cookieMaxAge int
cookieMu sync.RWMutex
cleanupTicker *time.Ticker
cleanupDone chan struct{}
}
// NewSessionManager creates a new session manager
func NewSessionManager(maxSessions int) *SessionManager {
if maxSessions <= 0 {
maxSessions = DefaultMaxSessions
}
sm := &SessionManager{
cache: fastcache.New(maxSessions * 4096),
cookieName: DefaultCookieName,
cookiePath: DefaultCookiePath,
cookieHTTPOnly: true,
cookieMaxAge: DefaultMaxAge,
cleanupDone: make(chan struct{}),
}
// Pre-populate session pool
for i := 0; i < 100; i++ {
s := NewSession("", 0)
s.Release()
}
sm.cleanupTicker = time.NewTicker(CleanupInterval)
go sm.cleanupRoutine()
return sm
}
// Stop shuts down the session manager's cleanup routine
func (sm *SessionManager) Stop() {
close(sm.cleanupDone)
}
func (sm *SessionManager) cleanupRoutine() {
for {
select {
case <-sm.cleanupTicker.C:
sm.CleanupExpired()
case <-sm.cleanupDone:
sm.cleanupTicker.Stop()
return
}
}
}
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
func (sm *SessionManager) GetSession(id string) *Session {
if id != "" {
if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 {
if s, err := Unmarshal(data); err == nil && !s.IsExpired() {
s.UpdateLastUsed()
s.ResetDirty()
return s
}
sm.cache.Del([]byte(id))
}
}
return sm.CreateSession()
}
// CreateSession generates a new session with a unique ID
func (sm *SessionManager) CreateSession() *Session {
id, _ := gonanoid.New()
// Ensure uniqueness (max 3 attempts)
for i := 0; i < 3 && sm.cache.Has([]byte(id)); i++ {
id, _ = gonanoid.New()
}
s := NewSession(id, sm.cookieMaxAge)
if data, err := s.Marshal(); err == nil {
sm.cache.Set([]byte(id), data)
}
s.ResetDirty()
return s
}
// DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) {
if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 {
if s, err := Unmarshal(data); err == nil {
s.Release()
}
}
sm.cache.Del([]byte(id))
}
// CleanupExpired removes all expired sessions
func (sm *SessionManager) CleanupExpired() int {
// fastcache doesn't support iteration
return 0
}
// SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.cookieMu.Lock()
sm.cookieName = name
sm.cookiePath = path
sm.cookieDomain = domain
sm.cookieSecure = secure
sm.cookieHTTPOnly = httpOnly
sm.cookieMaxAge = maxAge
sm.cookieMu.Unlock()
}
// GetSessionFromRequest extracts the session from a request
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
sm.cookieMu.RLock()
name := sm.cookieName
sm.cookieMu.RUnlock()
if cookie := ctx.Request.Header.Cookie(name); len(cookie) > 0 {
return sm.GetSession(string(cookie))
}
return sm.CreateSession()
}
// ApplySessionCookie adds the session cookie to the response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
if session.IsDirty() {
if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(session.ID), data)
}
session.ResetDirty()
}
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
sm.cookieMu.RLock()
cookie.SetKey(sm.cookieName)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
}
cookie.SetSecure(sm.cookieSecure)
sm.cookieMu.RUnlock()
cookie.SetValue(session.ID)
ctx.Response.Header.SetCookie(cookie)
}
// CookieOptions returns the cookie options for this session manager
func (sm *SessionManager) CookieOptions() map[string]any {
sm.cookieMu.RLock()
defer sm.cookieMu.RUnlock()
return map[string]any{
"name": sm.cookieName,
"path": sm.cookiePath,
"domain": sm.cookieDomain,
"secure": sm.cookieSecure,
"http_only": sm.cookieHTTPOnly,
"max_age": sm.cookieMaxAge,
}
}
// GetCacheStats returns statistics about the session cache
func (sm *SessionManager) GetCacheStats() map[string]uint64 {
if sm == nil || sm.cache == nil {
return map[string]uint64{}
}
var stats fastcache.Stats
sm.cache.UpdateStats(&stats)
return map[string]uint64{
"entries": stats.EntriesCount,
"bytes": stats.BytesSize,
"max_bytes": stats.MaxBytesSize,
"gets": stats.GetCalls,
"sets": stats.SetCalls,
"misses": stats.Misses,
}
}
// GlobalSessionManager is the default session manager instance
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)

View File

@ -1,434 +0,0 @@
package sessions
import (
"fmt"
"sync"
"time"
"github.com/deneonet/benc"
bstd "github.com/deneonet/benc/std"
)
// Session stores data for a single user session
type Session struct {
ID string
Data map[string]any
CreatedAt time.Time
UpdatedAt time.Time
LastUsed time.Time
Expiry time.Time
dirty bool
}
var (
sessionPool = sync.Pool{
New: func() any {
return &Session{Data: make(map[string]any, 8)}
},
}
bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
)
// NewSession creates a new session with the given ID
func NewSession(id string, maxAge int) *Session {
s := sessionPool.Get().(*Session)
now := time.Now()
*s = Session{
ID: id,
Data: s.Data, // Reuse map
CreatedAt: now,
UpdatedAt: now,
LastUsed: now,
Expiry: now.Add(time.Duration(maxAge) * time.Second),
}
return s
}
// Release returns the session to the pool
func (s *Session) Release() {
for k := range s.Data {
delete(s.Data, k)
}
sessionPool.Put(s)
}
// Get returns a deep copy of a value
func (s *Session) Get(key string) any {
if v, ok := s.Data[key]; ok {
return deepCopy(v)
}
return nil
}
// GetTable returns a value as a table
func (s *Session) GetTable(key string) map[string]any {
if v := s.Get(key); v != nil {
if t, ok := v.(map[string]any); ok {
return t
}
}
return nil
}
// GetAll returns a deep copy of all session data
func (s *Session) GetAll() map[string]any {
copy := make(map[string]any, len(s.Data))
for k, v := range s.Data {
copy[k] = deepCopy(v)
}
return copy
}
// Set stores a value in the session
func (s *Session) Set(key string, value any) {
if existing, ok := s.Data[key]; ok && deepEqual(existing, value) {
return // No change
}
s.Data[key] = value
s.UpdatedAt = time.Now()
s.dirty = true
}
// SetSafe stores a value with validation
func (s *Session) SetSafe(key string, value any) error {
if err := validate(value); err != nil {
return fmt.Errorf("session.SetSafe: %w", err)
}
s.Set(key, value)
return nil
}
// SetTable is a convenience method for setting table data
func (s *Session) SetTable(key string, table map[string]any) error {
return s.SetSafe(key, table)
}
// Delete removes a value from the session
func (s *Session) Delete(key string) {
delete(s.Data, key)
s.UpdatedAt = time.Now()
s.dirty = true
}
// Clear removes all data from the session
func (s *Session) Clear() {
s.Data = make(map[string]any, 8)
s.UpdatedAt = time.Now()
s.dirty = true
}
// IsExpired checks if the session has expired
func (s *Session) IsExpired() bool {
return time.Now().After(s.Expiry)
}
// UpdateLastUsed updates the last used time
func (s *Session) UpdateLastUsed() {
now := time.Now()
if now.Sub(s.LastUsed) > 5*time.Second {
s.LastUsed = now
}
}
// IsDirty returns if the session has unsaved changes
func (s *Session) IsDirty() bool {
return s.dirty
}
// ResetDirty marks the session as clean after saving
func (s *Session) ResetDirty() {
s.dirty = false
}
// SizePlain calculates the size needed to marshal the session
func (s *Session) SizePlain() int {
return bstd.SizeString(s.ID) +
bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) +
bstd.SizeInt64()*4
}
// MarshalPlain serializes the session to binary
func (s *Session) MarshalPlain(n int, b []byte) int {
n = bstd.MarshalString(n, b, s.ID)
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
return bstd.MarshalInt64(n, b, s.Expiry.Unix())
}
// UnmarshalPlain deserializes the session from binary
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
var err error
n, s.ID, err = bstd.UnmarshalString(n, b)
if err != nil {
return n, err
}
n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
if err != nil {
return n, err
}
var ts int64
for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} {
n, ts, err = bstd.UnmarshalInt64(n, b)
if err != nil {
return n, err
}
*t = time.Unix(ts, 0)
}
return n, nil
}
// Marshal serializes the session using benc
func (s *Session) Marshal() ([]byte, error) {
return bufPool.Marshal(s.SizePlain(), func(b []byte) int {
return s.MarshalPlain(0, b)
})
}
// Unmarshal deserializes a session using benc
func Unmarshal(data []byte) (*Session, error) {
s := sessionPool.Get().(*Session)
if _, err := s.UnmarshalPlain(0, data); err != nil {
s.Release()
return nil, err
}
return s, nil
}
// Type identifiers
const (
typeNull byte = 0
typeString byte = 1
typeInt byte = 2
typeFloat byte = 3
typeBool byte = 4
typeBytes byte = 5
typeTable byte = 6
typeArray byte = 7
)
// sizeAny calculates the size needed for any value
func sizeAny(v any) int {
if v == nil {
return 1
}
size := 1 // type byte
switch v := v.(type) {
case string:
size += bstd.SizeString(v)
case int:
size += bstd.SizeInt64()
case int64:
size += bstd.SizeInt64()
case float64:
size += bstd.SizeFloat64()
case bool:
size += bstd.SizeBool()
case []byte:
size += bstd.SizeBytes(v)
case map[string]any:
size += bstd.SizeMap(v, bstd.SizeString, sizeAny)
case []any:
size += bstd.SizeSlice(v, sizeAny)
default:
size += bstd.SizeString("unknown")
}
return size
}
// marshalAny serializes any value
func marshalAny(n int, b []byte, v any) int {
if v == nil {
b[n] = typeNull
return n + 1
}
switch v := v.(type) {
case string:
b[n] = typeString
return bstd.MarshalString(n+1, b, v)
case int:
b[n] = typeInt
return bstd.MarshalInt64(n+1, b, int64(v))
case int64:
b[n] = typeInt
return bstd.MarshalInt64(n+1, b, v)
case float64:
b[n] = typeFloat
return bstd.MarshalFloat64(n+1, b, v)
case bool:
b[n] = typeBool
return bstd.MarshalBool(n+1, b, v)
case []byte:
b[n] = typeBytes
return bstd.MarshalBytes(n+1, b, v)
case map[string]any:
b[n] = typeTable
return bstd.MarshalMap(n+1, b, v, bstd.MarshalString, marshalAny)
case []any:
b[n] = typeArray
return bstd.MarshalSlice(n+1, b, v, marshalAny)
default:
b[n] = typeString
return bstd.MarshalString(n+1, b, "unknown")
}
}
// unmarshalAny deserializes any value
func unmarshalAny(n int, b []byte) (int, any, error) {
if len(b) <= n {
return n, nil, benc.ErrBufTooSmall
}
switch b[n] {
case typeNull:
return n + 1, nil, nil
case typeString:
return bstd.UnmarshalString(n+1, b)
case typeInt:
n, v, err := bstd.UnmarshalInt64(n+1, b)
return n, v, err
case typeFloat:
return bstd.UnmarshalFloat64(n+1, b)
case typeBool:
return bstd.UnmarshalBool(n+1, b)
case typeBytes:
return bstd.UnmarshalBytesCopied(n+1, b)
case typeTable:
return bstd.UnmarshalMap[string, any](n+1, b, bstd.UnmarshalString, unmarshalAny)
case typeArray:
return bstd.UnmarshalSlice[any](n+1, b, unmarshalAny)
default:
return n + 1, nil, nil
}
}
// deepCopy creates a deep copy of any value
func deepCopy(v any) any {
switch v := v.(type) {
case map[string]any:
cp := make(map[string]any, len(v))
for k, val := range v {
cp[k] = deepCopy(val)
}
return cp
case []any:
cp := make([]any, len(v))
for i, val := range v {
cp[i] = deepCopy(val)
}
return cp
default:
return v
}
}
// validate ensures a value can be safely serialized
func validate(v any) error {
switch v := v.(type) {
case nil, string, int, int64, float64, bool, []byte:
return nil
case map[string]any:
for k, val := range v {
if err := validate(val); err != nil {
return fmt.Errorf("invalid value for key %q: %w", k, err)
}
}
case []any:
for i, val := range v {
if err := validate(val); err != nil {
return fmt.Errorf("invalid value at index %d: %w", i, err)
}
}
default:
return fmt.Errorf("unsupported type: %T", v)
}
return nil
}
// deepEqual efficiently compares two values for deep equality
func deepEqual(a, b any) bool {
if a == b {
return true
}
if a == nil || b == nil {
return false
}
switch va := a.(type) {
case string:
if vb, ok := b.(string); ok {
return va == vb
}
case int:
if vb, ok := b.(int); ok {
return va == vb
}
if vb, ok := b.(int64); ok {
return int64(va) == vb
}
case int64:
if vb, ok := b.(int64); ok {
return va == vb
}
if vb, ok := b.(int); ok {
return va == int64(vb)
}
case float64:
if vb, ok := b.(float64); ok {
return va == vb
}
case bool:
if vb, ok := b.(bool); ok {
return va == vb
}
case []byte:
if vb, ok := b.([]byte); ok {
if len(va) != len(vb) {
return false
}
for i, v := range va {
if v != vb[i] {
return false
}
}
return true
}
case map[string]any:
if vb, ok := b.(map[string]any); ok {
if len(va) != len(vb) {
return false
}
for k, v := range va {
if bv, exists := vb[k]; !exists || !deepEqual(v, bv) {
return false
}
}
return true
}
case []any:
if vb, ok := b.([]any); ok {
if len(va) != len(vb) {
return false
}
for i, v := range va {
if !deepEqual(v, vb[i]) {
return false
}
}
return true
}
}
return false
}
// IsEmpty returns true if the session has no data
func (s *Session) IsEmpty() bool {
return len(s.Data) == 0
}

93
state/bytecode.go Normal file
View File

@ -0,0 +1,93 @@
package state
import (
"crypto/sha256"
"fmt"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// BytecodeEntry holds compiled bytecode with metadata
type BytecodeEntry struct {
Bytecode []byte
Name string
Hash [32]byte
}
// Global bytecode cache
var (
bytecodeCache = struct {
sync.RWMutex
entries map[string]*BytecodeEntry
}{
entries: make(map[string]*BytecodeEntry),
}
)
// CompileAndCache compiles code to bytecode and stores it globally
func CompileAndCache(state *luajit.State, code, name string) (*BytecodeEntry, error) {
hash := sha256.Sum256([]byte(code))
cacheKey := fmt.Sprintf("%x", hash)
// Check cache first
bytecodeCache.RLock()
if entry, exists := bytecodeCache.entries[cacheKey]; exists {
bytecodeCache.RUnlock()
return entry, nil
}
bytecodeCache.RUnlock()
// Compile bytecode
bytecode, err := state.CompileBytecode(code, name)
if err != nil {
return nil, err
}
// Store in cache
entry := &BytecodeEntry{
Bytecode: bytecode,
Name: name,
Hash: hash,
}
bytecodeCache.Lock()
bytecodeCache.entries[cacheKey] = entry
bytecodeCache.Unlock()
return entry, nil
}
// GetCached retrieves bytecode from cache by code hash
func GetCached(code string) (*BytecodeEntry, bool) {
hash := sha256.Sum256([]byte(code))
cacheKey := fmt.Sprintf("%x", hash)
bytecodeCache.RLock()
defer bytecodeCache.RUnlock()
entry, exists := bytecodeCache.entries[cacheKey]
return entry, exists
}
// ClearCache removes all cached bytecode entries
func ClearCache() {
bytecodeCache.Lock()
defer bytecodeCache.Unlock()
bytecodeCache.entries = make(map[string]*BytecodeEntry)
}
// CacheStats returns cache statistics
func CacheStats() (int, int64) {
bytecodeCache.RLock()
defer bytecodeCache.RUnlock()
count := len(bytecodeCache.entries)
var totalSize int64
for _, entry := range bytecodeCache.entries {
totalSize += int64(len(entry.Bytecode))
}
return count, totalSize
}

327
state/state.go Normal file
View File

@ -0,0 +1,327 @@
package state
import (
"fmt"
"os"
"path/filepath"
"Moonshark/modules"
"Moonshark/modules/http"
"Moonshark/modules/kv"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// State wraps luajit.State with enhanced functionality
type State struct {
*luajit.State
initialized bool
isWorker bool
scriptDir string
}
// Config holds state initialization options
type Config struct {
OpenLibs bool
InstallModules bool
ScriptDir string
IsWorker bool
}
// DefaultConfig returns default configuration
func DefaultConfig() Config {
return Config{
OpenLibs: true,
InstallModules: true,
IsWorker: false,
}
}
// New creates a new enhanced Lua state
func New(config ...Config) (*State, error) {
cfg := DefaultConfig()
if len(config) > 0 {
cfg = config[0]
}
// Create base state
baseState := luajit.New(cfg.OpenLibs)
if baseState == nil {
return nil, fmt.Errorf("failed to create Lua state")
}
state := &State{
State: baseState,
isWorker: cfg.IsWorker,
scriptDir: cfg.ScriptDir,
}
// Set worker global flag if this is a worker state
if cfg.IsWorker {
state.PushBoolean(true)
state.SetGlobal("__IS_WORKER")
}
// Install module system if requested
if cfg.InstallModules {
if err := state.initializeModules(); err != nil {
state.Close()
return nil, fmt.Errorf("failed to install modules: %w", err)
}
}
// Set script directory if provided
if cfg.ScriptDir != "" {
if err := state.SetScriptDirectory(cfg.ScriptDir); err != nil {
state.Close()
return nil, fmt.Errorf("failed to set script directory: %w", err)
}
}
state.initialized = true
return state, nil
}
// NewFromScript creates a state configured for a specific script file
func NewFromScript(scriptPath string, config ...Config) (*State, error) {
cfg := DefaultConfig()
if len(config) > 0 {
cfg = config[0]
}
// Set script directory from file path
scriptDir := filepath.Dir(scriptPath)
absScriptDir, err := filepath.Abs(scriptDir)
if err != nil {
return nil, fmt.Errorf("failed to get absolute path: %w", err)
}
cfg.ScriptDir = absScriptDir
return New(cfg)
}
// NewWorker creates a new worker state with __IS_WORKER flag set
func NewWorker(config ...Config) (*State, error) {
cfg := DefaultConfig()
if len(config) > 0 {
cfg = config[0]
}
cfg.IsWorker = true
return New(cfg)
}
// NewWorkerFromScript creates a worker state configured for a specific script
func NewWorkerFromScript(scriptPath string, config ...Config) (*State, error) {
cfg := DefaultConfig()
if len(config) > 0 {
cfg = config[0]
}
cfg.IsWorker = true
scriptDir := filepath.Dir(scriptPath)
absScriptDir, err := filepath.Abs(scriptDir)
if err != nil {
return nil, fmt.Errorf("failed to get absolute path: %w", err)
}
cfg.ScriptDir = absScriptDir
return New(cfg)
}
// Store main state initialization for worker replication
var (
mainStateScriptDir string
mainStateScript string
mainStateScriptName string
)
// initializeModules sets up the module system
func (s *State) initializeModules() error {
// Initialize global registry if needed
if modules.Global == nil {
if err := modules.Initialize(); err != nil {
return err
}
}
// Install modules first
if err := modules.Global.InstallInState(s.State); err != nil {
return err
}
return nil
}
// SetupStateCreator sets up the state creator after script is loaded
func (s *State) SetupStateCreator() {
if s.isWorker {
return
}
http.SetStateCreator(func() (*luajit.State, error) {
cfg := DefaultConfig()
cfg.IsWorker = true
cfg.ScriptDir = mainStateScriptDir
workerState, err := New(cfg)
if err != nil {
return nil, err
}
// Execute the same script as main state to get identical environment
if mainStateScript != "" {
if err := workerState.ExecuteString(mainStateScript, mainStateScriptName); err != nil {
workerState.Close()
return nil, fmt.Errorf("failed to execute script in worker: %w", err)
}
}
return workerState.State, nil
})
}
// SetScriptDirectory adds a directory to Lua's package.path
func (s *State) SetScriptDirectory(dir string) error {
packagePath := filepath.Join(dir, "?.lua")
return s.AddPackagePath(packagePath)
}
// ExecuteFile compiles and runs a Lua script file
func (s *State) ExecuteFile(scriptPath string) error {
// Check if file exists
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
return fmt.Errorf("script file '%s' not found", scriptPath)
}
// Read script content
scriptContent, err := os.ReadFile(scriptPath)
if err != nil {
return fmt.Errorf("failed to read script file '%s': %w", scriptPath, err)
}
// Store for worker replication if this is main state
if !s.isWorker {
mainStateScript = string(scriptContent)
mainStateScriptName = scriptPath
mainStateScriptDir = s.scriptDir
// Set up state creator now that we have the script
s.SetupStateCreator()
}
return s.ExecuteString(string(scriptContent), scriptPath)
}
// ExecuteString compiles and runs Lua code with bytecode caching
func (s *State) ExecuteString(code, name string) error {
entry, err := CompileAndCache(s.State, code, name)
if err != nil {
return fmt.Errorf("compilation error in '%s': %w", name, err)
}
if err := s.LoadAndRunBytecode(entry.Bytecode, name); err != nil {
return fmt.Errorf("execution error in '%s': %w", name, err)
}
return nil
}
// ExecuteFileWithResults executes a script and returns results
func (s *State) ExecuteFileWithResults(scriptPath string) ([]any, error) {
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
return nil, fmt.Errorf("script file '%s' not found", scriptPath)
}
scriptContent, err := os.ReadFile(scriptPath)
if err != nil {
return nil, fmt.Errorf("failed to read script file '%s': %w", scriptPath, err)
}
return s.ExecuteStringWithResults(string(scriptContent), scriptPath)
}
// ExecuteStringWithResults executes code and returns all results with bytecode caching
func (s *State) ExecuteStringWithResults(code, name string) ([]any, error) {
baseTop := s.GetTop()
defer s.SetTop(baseTop)
entry, err := CompileAndCache(s.State, code, name)
if err != nil {
return nil, fmt.Errorf("compilation error in '%s': %w", name, err)
}
if err := s.LoadAndRunBytecodeWithResults(entry.Bytecode, name, -1); err != nil {
return nil, fmt.Errorf("execution error in '%s': %w", name, err)
}
nresults := s.GetTop() - baseTop
if nresults == 0 {
return nil, nil
}
results := make([]any, nresults)
for i := range nresults {
val, err := s.ToValue(baseTop + i + 1)
if err != nil {
results[i] = nil
} else {
results[i] = val
}
}
return results, nil
}
// IsInitialized returns whether the state has been properly initialized
func (s *State) IsInitialized() bool {
return s.initialized
}
// IsWorker returns whether this state is a worker state
func (s *State) IsWorker() bool {
return s.isWorker
}
// Close cleans up the state and releases resources
func (s *State) Close() {
if s.State != nil {
if !s.isWorker {
kv.CloseAllStores()
}
s.Cleanup()
s.State.Close()
s.State = nil
}
s.initialized = false
}
// Quick creates a minimal state for one-off script execution
func Quick() (*State, error) {
return New(Config{
OpenLibs: true,
InstallModules: true,
})
}
// QuickExecute creates a state, executes code, and cleans up
func QuickExecute(code, name string) error {
state, err := Quick()
if err != nil {
return err
}
defer state.Close()
return state.ExecuteString(code, name)
}
// QuickExecuteWithResults executes code and returns results
func QuickExecuteWithResults(code, name string) ([]any, error) {
state, err := Quick()
if err != nil {
return nil, err
}
defer state.Close()
return state.ExecuteStringWithResults(code, name)
}

507
tests/crypto.lua Normal file
View File

@ -0,0 +1,507 @@
require("tests")
local crypto = require("crypto")
-- Test data
local test_data = "Hello, World!"
local test_key = "secret-key-123"
-- ======================================================================
-- ENCODING/DECODING TESTS
-- ======================================================================
test("Base64 Encoding/Decoding", function()
local encoded = crypto.base64_encode(test_data)
assert_equal(type(encoded), "string")
assert(#encoded > 0, "encoded string should not be empty")
local decoded = crypto.base64_decode(encoded)
assert_equal(decoded, test_data)
end)
test("Base64 URL Encoding/Decoding", function()
local encoded = crypto.base64_url_encode(test_data)
assert_equal(type(encoded), "string")
local decoded = crypto.base64_url_decode(encoded)
assert_equal(decoded, test_data)
end)
test("Hex Encoding/Decoding", function()
local encoded = crypto.hex_encode(test_data)
assert_equal(type(encoded), "string")
assert(encoded:match("^[0-9a-f]+$"), "hex should only contain hex characters")
local decoded = crypto.hex_decode(encoded)
assert_equal(decoded, test_data)
end)
test("Encoding Chain", function()
local chain_encoded = crypto.encode_chain(test_data, {"hex", "base64"})
local chain_decoded = crypto.decode_chain(chain_encoded, {"hex", "base64"})
assert_equal(chain_decoded, test_data)
end)
-- ======================================================================
-- HASHING TESTS
-- ======================================================================
test("MD5 Hash", function()
local hash = crypto.md5(test_data)
assert_equal(type(hash), "string")
assert_equal(#hash, 32) -- MD5 is 32 hex characters
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
-- Same input should produce same hash
assert_equal(crypto.md5(test_data), hash)
end)
test("SHA1 Hash", function()
local hash = crypto.sha1(test_data)
assert_equal(type(hash), "string")
assert_equal(#hash, 40) -- SHA1 is 40 hex characters
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
end)
test("SHA256 Hash", function()
local hash = crypto.sha256(test_data)
assert_equal(type(hash), "string")
assert_equal(#hash, 64) -- SHA256 is 64 hex characters
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
end)
test("SHA512 Hash", function()
local hash = crypto.sha512(test_data)
assert_equal(type(hash), "string")
assert_equal(#hash, 128) -- SHA512 is 128 hex characters
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
end)
test("Hash Multiple Inputs", function()
local hash1 = crypto.hash_multiple({"hello", "world"})
local hash2 = crypto.hash_multiple({"hello", "world"})
local hash3 = crypto.hash_multiple({"world", "hello"})
assert_equal(hash1, hash2)
assert(hash1 ~= hash3, "different order should produce different hash")
end)
-- ======================================================================
-- HMAC TESTS
-- ======================================================================
test("HMAC SHA1", function()
local hmac = crypto.hmac_sha1(test_data, test_key)
assert_equal(type(hmac), "string")
assert_equal(#hmac, 40) -- SHA1 HMAC is 40 hex characters
assert(hmac:match("^[0-9a-f]+$"), "hmac should be hex")
end)
test("HMAC SHA256", function()
local hmac = crypto.hmac_sha256(test_data, test_key)
assert_equal(type(hmac), "string")
assert_equal(#hmac, 64) -- SHA256 HMAC is 64 hex characters
assert(hmac:match("^[0-9a-f]+$"), "hmac should be hex")
end)
test("MAC Functions", function()
local mac = crypto.mac(test_data, test_key)
assert_equal(type(mac), "string")
assert(#mac > 0, "mac should not be empty")
local valid = crypto.verify_mac(test_data, test_key, mac)
assert_equal(valid, true)
local invalid = crypto.verify_mac("different data", test_key, mac)
assert_equal(invalid, false)
end)
-- ======================================================================
-- UUID TESTS
-- ======================================================================
test("UUID Generation", function()
local uuid1 = crypto.uuid()
local uuid2 = crypto.uuid_v4()
assert_equal(type(uuid1), "string")
assert_equal(type(uuid2), "string")
assert_equal(#uuid1, 36) -- Standard UUID length
assert_equal(#uuid2, 36)
assert(uuid1 ~= uuid2, "UUIDs should be unique")
assert_equal(crypto.is_uuid(uuid1), true)
assert_equal(crypto.is_uuid(uuid2), true)
assert_equal(crypto.is_uuid("not-a-uuid"), false)
end)
-- ======================================================================
-- RANDOM GENERATION TESTS
-- ======================================================================
test("Random Bytes", function()
local bytes1 = crypto.random_bytes(16)
local bytes2 = crypto.random_bytes(16)
assert_equal(type(bytes1), "string")
assert_equal(#bytes1, 16)
assert(bytes1 ~= bytes2, "random bytes should be different")
end)
test("Random Hex", function()
local hex1 = crypto.random_hex(8)
local hex2 = crypto.random_hex(8)
assert_equal(type(hex1), "string")
assert_equal(#hex1, 16) -- 8 bytes = 16 hex characters
assert(hex1:match("^[0-9a-f]+$"), "should be hex")
assert(hex1 ~= hex2, "random hex should be different")
end)
test("Random String", function()
local str1 = crypto.random_string(10)
local str2 = crypto.random_string(10)
assert_equal(type(str1), "string")
assert_equal(#str1, 10)
assert(str1 ~= str2, "random strings should be different")
local custom = crypto.random_string(5, "abc")
assert_equal(#custom, 5)
assert(custom:match("^[abc]+$"), "should only contain specified characters")
end)
test("Random Alphanumeric", function()
local str = crypto.random_alphanumeric(20)
assert_equal(#str, 20)
assert(str:match("^[a-zA-Z0-9]+$"), "should be alphanumeric")
end)
test("Random Password", function()
local pass1 = crypto.random_password(12)
local pass2 = crypto.random_password(12, true) -- with symbols
assert_equal(#pass1, 12)
assert_equal(#pass2, 12)
assert(pass1 ~= pass2, "passwords should be different")
end)
test("Token Generation", function()
local token1 = crypto.token(32)
local token2 = crypto.token(32)
assert_equal(#token1, 64) -- 32 bytes = 64 hex characters
assert(token1:match("^[0-9a-f]+$"), "token should be hex")
assert(token1 ~= token2, "tokens should be unique")
end)
test("Nonce Generation", function()
local nonce1 = crypto.nonce()
local nonce2 = crypto.nonce(32)
assert_equal(#nonce1, 32) -- default 16 bytes = 32 hex
assert_equal(#nonce2, 64) -- 32 bytes = 64 hex
assert(nonce1 ~= nonce2, "nonces should be unique")
end)
-- ======================================================================
-- UTILITY TESTS
-- ======================================================================
test("Secure Compare", function()
local str1 = "hello"
local str2 = "hello"
local str3 = "world"
assert_equal(crypto.secure_compare(str1, str2), true)
assert_equal(crypto.secure_compare(str1, str3), false)
end)
test("Checksum Functions", function()
local checksum = crypto.checksum(test_data)
assert_equal(type(checksum), "string")
local valid = crypto.verify_checksum(test_data, checksum)
assert_equal(valid, true)
local invalid = crypto.verify_checksum("different", checksum)
assert_equal(invalid, false)
end)
test("XOR Encryption", function()
local key = "mykey"
local encrypted = crypto.xor_encrypt(test_data, key)
local decrypted = crypto.xor_decrypt(encrypted, key)
assert_equal(decrypted, test_data)
assert(encrypted ~= test_data, "encrypted should be different")
end)
test("Hash Chain", function()
local chain1 = crypto.hash_chain(test_data, 100)
local chain2 = crypto.hash_chain(test_data, 100)
local chain3 = crypto.hash_chain(test_data, 101)
assert_equal(chain1, chain2)
assert(chain1 ~= chain3, "different iterations should produce different results")
end)
test("Key Derivation", function()
local derived1, salt1 = crypto.derive_key("password", "salt", 1000)
local derived2, salt2 = crypto.derive_key("password", salt1, 1000)
assert_equal(derived1, derived2)
assert_equal(salt1, salt2)
assert_equal(type(derived1), "string")
assert(#derived1 > 0, "derived key should not be empty")
end)
test("Fingerprint", function()
local data = {name = "test", value = 42}
local fp1 = crypto.fingerprint(data)
local fp2 = crypto.fingerprint(data)
assert_equal(fp1, fp2)
assert_equal(type(fp1), "string")
assert(#fp1 > 0, "fingerprint should not be empty")
end)
test("Integrity Check", function()
local check = crypto.integrity_check(test_data)
assert_equal(check.data, test_data)
assert_equal(type(check.hash), "string")
assert_equal(type(check.timestamp), "number")
assert_equal(type(check.uuid), "string")
local valid = crypto.verify_integrity(check)
assert_equal(valid, true)
-- Tamper with data
check.data = "tampered"
local invalid = crypto.verify_integrity(check)
assert_equal(invalid, false)
end)
-- ======================================================================
-- ERROR HANDLING TESTS
-- ======================================================================
test("Error Handling", function()
-- Invalid base64
local success, err = pcall(crypto.base64_decode, "invalid===base64")
assert_equal(success, false)
-- Invalid hex
local success2, err2 = pcall(crypto.hex_decode, "invalid_hex")
assert_equal(success2, false)
-- Invalid UUID validation (returns boolean, doesn't throw)
assert_equal(crypto.is_uuid("not-a-uuid"), false)
assert_equal(crypto.is_uuid(""), false)
assert_equal(crypto.is_uuid("12345"), false)
end)
-- ======================================================================
-- PASSWORD TESTS
-- ======================================================================
test("Password Hash and Verification", function()
local password = "hubba-ba-loo117!@#"
local hash = crypto.hash_password(password)
local hash_fast = crypto.hash_password_fast(password)
local hash_strong = crypto.hash_password_strong(password)
assert(crypto.verify_password(password, hash))
assert(crypto.verify_password(password, hash_fast))
assert(crypto.verify_password(password, hash_strong))
assert(not crypto.verify_password("failure", hash))
assert(not crypto.verify_password("failure", hash_fast))
assert(not crypto.verify_password("failure", hash_strong))
end)
test("Algorithm-Specific Password Hashing", function()
local password = "test123!@#"
-- Test each algorithm individually
local argon2_hash = crypto.hash_password(password, "argon2id")
local bcrypt_hash = crypto.hash_password(password, "bcrypt")
local scrypt_hash = crypto.hash_password(password, "scrypt")
local pbkdf2_hash = crypto.hash_password(password, "pbkdf2")
assert(crypto.verify_password(password, argon2_hash))
assert(crypto.verify_password(password, bcrypt_hash))
assert(crypto.verify_password(password, scrypt_hash))
assert(crypto.verify_password(password, pbkdf2_hash))
-- Verify wrong passwords fail
assert(not crypto.verify_password("wrong", argon2_hash))
assert(not crypto.verify_password("wrong", bcrypt_hash))
assert(not crypto.verify_password("wrong", scrypt_hash))
assert(not crypto.verify_password("wrong", pbkdf2_hash))
end)
test("Algorithm Detection", function()
local password = "detectme123"
local argon2_hash = crypto.hash_password(password, "argon2id")
local bcrypt_hash = crypto.hash_password(password, "bcrypt")
local scrypt_hash = crypto.hash_password(password, "scrypt")
local pbkdf2_hash = crypto.hash_password(password, "pbkdf2")
assert(crypto.detect_algorithm(argon2_hash) == "argon2id")
assert(crypto.detect_algorithm(bcrypt_hash) == "bcrypt")
assert(crypto.detect_algorithm(scrypt_hash) == "scrypt")
assert(crypto.detect_algorithm(pbkdf2_hash) == "pbkdf2")
assert(crypto.detect_algorithm("invalid$format") == "unknown")
end)
test("Custom Algorithm Options", function()
local password = "custom123"
-- Test custom argon2id options
local custom_argon2 = crypto.hash_password(password, "argon2id", {
time = 2,
memory = 32768,
threads = 2
})
assert(crypto.verify_password(password, custom_argon2))
-- Test custom bcrypt cost
local custom_bcrypt = crypto.hash_password(password, "bcrypt", {cost = 10})
assert(crypto.verify_password(password, custom_bcrypt))
-- Test custom scrypt parameters
local custom_scrypt = crypto.hash_password(password, "scrypt", {
N = 16384,
r = 4,
p = 2
})
assert(crypto.verify_password(password, custom_scrypt))
-- Test custom pbkdf2 iterations
local custom_pbkdf2 = crypto.hash_password(password, "pbkdf2", {
iterations = 50000
})
assert(crypto.verify_password(password, custom_pbkdf2))
end)
test("Direct Algorithm Functions", function()
local password = "direct123"
-- Test direct algorithm calls
local argon2_direct = crypto.argon2_hash(password)
local bcrypt_direct = crypto.bcrypt_hash(password)
local scrypt_direct = crypto.scrypt_hash(password)
local pbkdf2_direct = crypto.pbkdf2_hash(password)
assert(crypto.argon2_verify(password, argon2_direct))
assert(crypto.bcrypt_verify(password, bcrypt_direct))
assert(crypto.scrypt_verify(password, scrypt_direct))
assert(crypto.pbkdf2_verify(password, pbkdf2_direct))
-- Test with custom options
local argon2_custom = crypto.argon2_hash(password, {time = 1, memory = 16384})
local scrypt_custom = crypto.scrypt_hash(password, {N = 8192})
assert(crypto.argon2_verify(password, argon2_custom))
assert(crypto.scrypt_verify(password, scrypt_custom))
end)
test("Security Level Presets", function()
local password = "preset123"
local algorithms = {"argon2id", "bcrypt", "scrypt", "pbkdf2"}
for _, algo in ipairs(algorithms) do
local fast_hash = crypto.hash_password_fast(password, algo)
local strong_hash = crypto.hash_password_strong(password, algo)
assert(crypto.verify_password(password, fast_hash))
assert(crypto.verify_password(password, strong_hash))
-- Verify algorithm detection
assert(crypto.detect_algorithm(fast_hash) == algo)
assert(crypto.detect_algorithm(strong_hash) == algo)
end
end)
test("Edge Cases and Error Handling", function()
-- Test empty password
local empty_hash = crypto.hash_password("")
assert(crypto.verify_password("", empty_hash))
assert(not crypto.verify_password("not-empty", empty_hash))
-- Test long password
local long_password = string.rep("a", 1000)
local long_hash = crypto.hash_password(long_password)
assert(crypto.verify_password(long_password, long_hash))
-- Test unicode password
local unicode_password = "🔐password123🔑"
local unicode_hash = crypto.hash_password(unicode_password)
assert(crypto.verify_password(unicode_password, unicode_hash))
-- Test invalid hash formats
assert(not crypto.verify_password("test", "invalid-hash"))
assert(not crypto.verify_password("test", "$invalid$format$"))
-- Test unsupported algorithm error
local success, err = pcall(crypto.hash_password, "test", "invalid-algo")
assert(not success)
assert(string.find(err, "unsupported algorithm"))
end)
test("Cross-Algorithm Verification", function()
local password = "cross123"
-- Create hashes with different algorithms
local hashes = {
crypto.hash_password(password, "argon2id"),
crypto.hash_password(password, "bcrypt"),
crypto.hash_password(password, "scrypt"),
crypto.hash_password(password, "pbkdf2")
}
-- Each hash should only verify with correct password
for _, hash in ipairs(hashes) do
assert(crypto.verify_password(password, hash))
assert(not crypto.verify_password("wrong", hash))
end
-- Hashes should be different from each other
for i = 1, #hashes do
for j = i + 1, #hashes do
assert(hashes[i] ~= hashes[j])
end
end
end)
-- ======================================================================
-- PERFORMANCE TESTS
-- ======================================================================
test("Performance Test", function()
local large_data = string.rep("test data for performance ", 1000)
local start = os.clock()
local hash = crypto.sha256(large_data)
local hash_time = os.clock() - start
start = os.clock()
local encoded = crypto.base64_encode(large_data)
local encode_time = os.clock() - start
start = os.clock()
local decoded = crypto.base64_decode(encoded)
local decode_time = os.clock() - start
print(string.format(" SHA256 of %d bytes: %.3fs", #large_data, hash_time))
print(string.format(" Base64 encode: %.3fs", encode_time))
print(string.format(" Base64 decode: %.3fs", decode_time))
assert_equal(decoded, large_data)
assert_equal(type(hash), "string")
end)
summary()
test_exit()

456
tests/fs.lua Normal file
View File

@ -0,0 +1,456 @@
require("tests")
local fs = require("fs")
-- Test data
local test_content = "Hello, filesystem!\nThis is a test file.\n"
local test_dir = "test_fs_dir"
local test_file = fs.join(test_dir, "test.txt")
-- Clean up function
local function cleanup()
if fs.exists(test_file) then fs.remove(test_file) end
if fs.exists(test_dir) then fs.rmdir(test_dir) end
end
-- ======================================================================
-- SETUP AND CLEANUP
-- ======================================================================
-- Clean up before tests
cleanup()
-- ======================================================================
-- BASIC FILE OPERATIONS
-- ======================================================================
test("File Write and Read", function()
fs.mkdir(test_dir)
fs.write(test_file, test_content)
assert_equal(fs.exists(test_file), true)
assert_equal(fs.is_file(test_file), true)
assert_equal(fs.is_dir(test_file), false)
local content = fs.read(test_file)
assert_equal(content, test_content)
end)
test("File Size", function()
local size = fs.size(test_file)
assert_equal(size, #test_content)
assert_equal(fs.size("nonexistent.txt"), nil)
end)
test("File Append", function()
local additional = "Appended content.\n"
fs.append(test_file, additional)
local content = fs.read(test_file)
assert_equal(content, test_content .. additional)
end)
test("File Copy", function()
local copy_file = fs.join(test_dir, "copy.txt")
fs.copy(test_file, copy_file)
assert_equal(fs.exists(copy_file), true)
assert_equal(fs.read(copy_file), fs.read(test_file))
fs.remove(copy_file)
end)
test("File Move", function()
local move_file = fs.join(test_dir, "moved.txt")
local original_content = fs.read(test_file)
fs.move(test_file, move_file)
assert_equal(fs.exists(test_file), false)
assert_equal(fs.exists(move_file), true)
assert_equal(fs.read(move_file), original_content)
-- Move back for other tests
fs.move(move_file, test_file)
end)
test("File Lines", function()
local lines = fs.lines(test_file)
assert_equal(type(lines), "table")
assert(#lines >= 2, "should have multiple lines")
assert(string.find(lines[1], "Hello"), "first line should contain 'Hello'")
end)
test("File Touch", function()
local touch_file = fs.join(test_dir, "touched.txt")
fs.touch(touch_file)
assert_equal(fs.exists(touch_file), true)
assert_equal(fs.size(touch_file), 0)
fs.remove(touch_file)
end)
test("File Modification Time", function()
local mtime = fs.mtime(test_file)
assert_equal(type(mtime), "number")
assert(mtime > 0, "mtime should be positive")
-- Touch should update mtime
local old_mtime = mtime
fs.touch(test_file)
local new_mtime = fs.mtime(test_file)
assert(new_mtime >= old_mtime, "mtime should be updated")
end)
-- ======================================================================
-- DIRECTORY OPERATIONS
-- ======================================================================
test("Directory Creation and Removal", function()
local nested_dir = fs.join(test_dir, "nested", "deep")
fs.mkdir(nested_dir)
assert_equal(fs.exists(nested_dir), true)
assert_equal(fs.is_dir(nested_dir), true)
-- Clean up nested directories
fs.rmdir(fs.join(test_dir, "nested"))
end)
test("Directory Listing", function()
-- Create some test files
fs.write(fs.join(test_dir, "file1.txt"), "content1")
fs.write(fs.join(test_dir, "file2.log"), "content2")
fs.mkdir(fs.join(test_dir, "subdir"))
local entries = fs.list(test_dir)
assert_equal(type(entries), "table")
assert(#entries >= 3, "should have at least 3 entries")
-- Check entry structure
local found_file = false
for _, entry in ipairs(entries) do
assert_equal(type(entry.name), "string")
assert_equal(type(entry.is_dir), "boolean")
if entry.name == "file1.txt" then
found_file = true
assert_equal(entry.is_dir, false)
assert_equal(type(entry.size), "number")
end
end
assert_equal(found_file, true)
-- Test filtered listings
local files = fs.list_files(test_dir)
local dirs = fs.list_dirs(test_dir)
local names = fs.list_names(test_dir)
assert_equal(type(files), "table")
assert_equal(type(dirs), "table")
assert_equal(type(names), "table")
assert(#files > 0, "should have files")
assert(#dirs > 0, "should have directories")
-- Clean up
fs.remove(fs.join(test_dir, "file1.txt"))
fs.remove(fs.join(test_dir, "file2.log"))
fs.rmdir(fs.join(test_dir, "subdir"))
end)
-- ======================================================================
-- PATH OPERATIONS
-- ======================================================================
test("Path Join", function()
local path = fs.join("a", "b", "c", "file.txt")
assert(string.find(path, "file.txt"), "should contain filename")
local empty_path = fs.join()
assert_equal(type(empty_path), "string")
end)
test("Path Components", function()
local test_path = fs.join("home", "user", "documents", "file.txt")
local dir = fs.dirname(test_path)
local base = fs.basename(test_path)
local ext = fs.ext(test_path)
assert_equal(base, "file.txt")
assert_equal(ext, ".txt")
assert(string.find(dir, "documents"), "dirname should contain 'documents'")
end)
test("Path Split Extension", function()
local test_path = fs.join("home", "user", "file.tar.gz")
local dir, name, ext = fs.splitext(test_path)
assert_equal(name, "file.tar")
assert_equal(ext, ".gz")
assert(string.find(dir, "user"), "dir should contain 'user'")
end)
test("Path Absolute", function()
local abs_path = fs.abs(".")
assert_equal(type(abs_path), "string")
assert(#abs_path > 1, "absolute path should not be empty")
end)
test("Path Clean", function()
local messy_path = "./test/../test/./file.txt"
local clean_path = fs.clean(messy_path)
assert_equal(type(clean_path), "string")
assert(not string.find(clean_path, "%.%."), "should not contain '..'")
end)
test("Path Split", function()
local test_path = fs.join("home", "user", "file.txt")
local dir, file = fs.split(test_path)
assert_equal(file, "file.txt")
assert(string.find(dir, "user"), "dir should contain 'user'")
end)
-- ======================================================================
-- WORKING DIRECTORY
-- ======================================================================
test("Working Directory", function()
local original_cwd = fs.getcwd()
assert_equal(type(original_cwd), "string")
assert(#original_cwd > 0, "cwd should not be empty")
-- Test directory change
fs.chdir(test_dir)
local new_cwd = fs.getcwd()
assert(string.find(new_cwd, test_dir), "cwd should contain test_dir")
-- Change back
fs.chdir(original_cwd)
assert_equal(fs.getcwd(), original_cwd)
end)
-- ======================================================================
-- TEMPORARY FILES
-- ======================================================================
test("Temporary Files", function()
local temp_file = fs.tempfile("test_")
local temp_dir = fs.tempdir("test_")
assert_equal(type(temp_file), "string")
assert_equal(type(temp_dir), "string")
assert_equal(fs.exists(temp_file), true)
assert_equal(fs.exists(temp_dir), true)
assert_equal(fs.is_dir(temp_dir), true)
-- Clean up
fs.remove(temp_file)
fs.rmdir(temp_dir)
end)
-- ======================================================================
-- PATTERN MATCHING
-- ======================================================================
test("Glob Patterns", function()
-- Create test files for globbing
fs.write(fs.join(test_dir, "test1.txt"), "content")
fs.write(fs.join(test_dir, "test2.txt"), "content")
fs.write(fs.join(test_dir, "other.log"), "content")
local pattern = fs.join(test_dir, "*.txt")
local matches = fs.glob(pattern)
assert_equal(type(matches), "table")
assert(#matches >= 2, "should match txt files")
-- Clean up
fs.remove(fs.join(test_dir, "test1.txt"))
fs.remove(fs.join(test_dir, "test2.txt"))
fs.remove(fs.join(test_dir, "other.log"))
end)
test("Walk Directory", function()
-- Create nested structure
fs.mkdir(fs.join(test_dir, "sub1"))
fs.mkdir(fs.join(test_dir, "sub2"))
fs.write(fs.join(test_dir, "root.txt"), "content")
fs.write(fs.join(test_dir, "sub1", "nested.txt"), "content")
local files = fs.walk(test_dir)
assert_equal(type(files), "table")
assert(#files > 3, "should find multiple files and directories")
-- Clean up
fs.remove(fs.join(test_dir, "root.txt"))
fs.remove(fs.join(test_dir, "sub1", "nested.txt"))
fs.rmdir(fs.join(test_dir, "sub1"))
fs.rmdir(fs.join(test_dir, "sub2"))
end)
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
test("File Extension Functions", function()
local path = "document.pdf"
assert_equal(fs.extension(path), "pdf")
local new_path = fs.change_ext(path, "txt")
assert_equal(new_path, "document.txt")
local new_path2 = fs.change_ext(path, ".docx")
assert_equal(new_path2, "document.docx")
end)
test("Ensure Directory", function()
local ensure_dir = fs.join(test_dir, "ensure_test")
fs.ensure_dir(ensure_dir)
assert_equal(fs.exists(ensure_dir), true)
assert_equal(fs.is_dir(ensure_dir), true)
-- Should not error if already exists
fs.ensure_dir(ensure_dir)
fs.rmdir(ensure_dir)
end)
test("Human Readable Size", function()
local small_file = fs.join(test_dir, "small.txt")
fs.write(small_file, "test")
local size_str = fs.size_human(small_file)
assert_equal(type(size_str), "string")
assert(string.find(size_str, "B"), "should contain byte unit")
fs.remove(small_file)
end)
test("Safe Path Check", function()
assert_equal(fs.is_safe_path("safe/path.txt"), true)
assert_equal(fs.is_safe_path("../dangerous"), false)
assert_equal(fs.is_safe_path("/absolute/path"), false)
assert_equal(fs.is_safe_path("~/home/path"), false)
end)
test("Copy Tree", function()
-- Create source structure
local src_dir = fs.join(test_dir, "src")
local dst_dir = fs.join(test_dir, "dst")
fs.mkdir(src_dir)
fs.mkdir(fs.join(src_dir, "subdir"))
fs.write(fs.join(src_dir, "file1.txt"), "content1")
fs.write(fs.join(src_dir, "subdir", "file2.txt"), "content2")
fs.copytree(src_dir, dst_dir)
assert_equal(fs.exists(dst_dir), true)
assert_equal(fs.exists(fs.join(dst_dir, "file1.txt")), true)
assert_equal(fs.exists(fs.join(dst_dir, "subdir", "file2.txt")), true)
assert_equal(fs.read(fs.join(dst_dir, "file1.txt")), "content1")
-- Clean up
fs.rmdir(src_dir)
fs.rmdir(dst_dir)
end)
test("Find Files", function()
-- Create test files
fs.write(fs.join(test_dir, "find1.txt"), "content")
fs.write(fs.join(test_dir, "find2.txt"), "content")
fs.write(fs.join(test_dir, "other.log"), "content")
fs.mkdir(fs.join(test_dir, "subdir"))
fs.write(fs.join(test_dir, "subdir", "find3.txt"), "content")
local txt_files = fs.find(test_dir, "%.txt$", true)
assert_equal(type(txt_files), "table")
assert(#txt_files >= 3, "should find txt files recursively")
local txt_files_flat = fs.find(test_dir, "%.txt$", false)
assert(#txt_files_flat < #txt_files, "non-recursive should find fewer files")
-- Clean up
fs.remove(fs.join(test_dir, "find1.txt"))
fs.remove(fs.join(test_dir, "find2.txt"))
fs.remove(fs.join(test_dir, "other.log"))
fs.remove(fs.join(test_dir, "subdir", "find3.txt"))
fs.rmdir(fs.join(test_dir, "subdir"))
end)
test("Directory Tree", function()
-- Create test structure
fs.mkdir(fs.join(test_dir, "tree_test"))
fs.write(fs.join(test_dir, "tree_test", "file.txt"), "content")
fs.mkdir(fs.join(test_dir, "tree_test", "subdir"))
local tree = fs.tree(test_dir)
assert_equal(type(tree), "table")
assert_equal(tree.is_dir, true)
assert_equal(type(tree.children), "table")
assert(#tree.children > 0, "should have children")
-- Clean up
fs.remove(fs.join(test_dir, "tree_test", "file.txt"))
fs.rmdir(fs.join(test_dir, "tree_test", "subdir"))
fs.rmdir(fs.join(test_dir, "tree_test"))
end)
-- ======================================================================
-- ERROR HANDLING
-- ======================================================================
test("Error Handling", function()
-- Reading non-existent file
local success, err = pcall(fs.read, "nonexistent.txt")
assert_equal(false, success)
-- Writing to invalid path
local success2, err2 = pcall(fs.write, "/invalid/path/file.txt", "content")
assert_equal(false, success2)
-- Listing non-existent directory
local success3, err3 = pcall(fs.list, "nonexistent_dir")
assert_equal(false, success3)
end)
-- ======================================================================
-- PERFORMANCE TESTS
-- ======================================================================
test("Performance Test", function()
local large_content = string.rep("performance test data\n", 1000)
local perf_file = fs.join(test_dir, "performance.txt")
local start = os.clock()
fs.write(perf_file, large_content)
local write_time = os.clock() - start
start = os.clock()
local read_content = fs.read(perf_file)
local read_time = os.clock() - start
start = os.clock()
local lines = fs.lines(perf_file)
local lines_time = os.clock() - start
print(string.format(" Write %d bytes: %.3fs", #large_content, write_time))
print(string.format(" Read %d bytes: %.3fs", #read_content, read_time))
print(string.format(" Parse %d lines: %.3fs", #lines, lines_time))
assert_equal(read_content, large_content)
assert_equal(#lines, 1000)
fs.remove(perf_file)
end)
-- ======================================================================
-- CLEANUP
-- ======================================================================
cleanup()
summary()
test_exit()

188
tests/json.lua Normal file
View File

@ -0,0 +1,188 @@
require("tests")
--local json = require("json")
-- Test data
local test_data = {
name = "John Doe",
age = 30,
active = true,
scores = {85, 92, 78, 90},
address = {
street = "123 Main St",
city = "Springfield",
zip = "12345"
},
tags = {"developer", "golang", "lua"}
}
-- Test 1: Basic encoding
test("Basic JSON Encoding", function()
local encoded = json.encode(test_data)
assert_equal(type(encoded), "string")
assert(string.find(encoded, "John Doe"), "should contain name")
assert(string.find(encoded, "30"), "should contain age")
end)
-- Test 2: Basic decoding
test("Basic JSON Decoding", function()
local encoded = json.encode(test_data)
local decoded = json.decode(encoded)
assert_equal(decoded.name, "John Doe")
assert_equal(decoded.age, 30)
assert_equal(decoded.active, true)
assert_equal(#decoded.scores, 4)
end)
-- Test 3: Round-trip encoding/decoding
test("Round-trip Encoding/Decoding", function()
local encoded = json.encode(test_data)
local decoded = json.decode(encoded)
local re_encoded = json.encode(decoded)
local re_decoded = json.decode(re_encoded)
assert_equal(re_decoded.name, test_data.name)
assert_equal(re_decoded.address.city, test_data.address.city)
end)
-- Test 4: Pretty printing
test("Pretty Printing", function()
local pretty = json.pretty(test_data)
assert_equal(type(pretty), "string")
assert(string.find(pretty, "\n"), "pretty should contain newlines")
assert(string.find(pretty, " "), "pretty should contain indentation")
-- Should still be valid JSON
local decoded = json.decode(pretty)
assert_equal(decoded.name, test_data.name)
end)
-- Test 5: Object merging
test("Object Merging", function()
local obj1 = {a = 1, b = 2}
local obj2 = {b = 3, c = 4}
local obj3 = {d = 5}
local merged = json.merge(obj1, obj2, obj3)
assert_equal(merged.a, 1)
assert_equal(merged.b, 3) -- later wins
assert_equal(merged.c, 4)
assert_equal(merged.d, 5)
end)
-- Test 6: Data extraction
test("Data Extraction", function()
local name = json.extract(test_data, "name")
assert_equal(name, "John Doe")
local city = json.extract(test_data, "address.city")
assert_equal(city, "Springfield")
local first_score = json.extract(test_data, "scores.[0]")
assert_equal(first_score, 85)
local missing = json.extract(test_data, "nonexistent.field")
assert_equal(missing, nil)
end)
-- Test 7: Schema validation
test("Schema Validation", function()
local schema = {
type = "table",
properties = {
name = {type = "string"},
age = {type = "number"},
active = {type = "boolean"}
},
required = {name = true, age = true}
}
local valid, err = json.validate(test_data, schema)
assert_equal(valid, true)
local invalid_data = {name = "John", age = "not_a_number"}
local invalid, err2 = json.validate(invalid_data, schema)
assert_equal(invalid, false)
assert_equal(type(err2), "string")
end)
-- Test 8: File operations
test("File Save/Load", function()
local filename = "test_output.json"
-- Save to file
json.save_file(filename, test_data, true) -- pretty format
-- Check file exists
assert(file_exists(filename), "file should exist after save")
-- Load from file
local loaded = json.load_file(filename)
assert_equal(loaded.name, test_data.name)
assert_equal(loaded.address.zip, test_data.address.zip)
-- Clean up
os.remove(filename)
end)
-- Test 9: Error handling
test("Error Handling", function()
-- Invalid JSON should throw error
local success, err = pcall(json.decode, '{"invalid": json}')
assert_equal(success, false)
assert_equal(type(err), "string")
-- Missing file should throw error
local success2, err2 = pcall(json.load_file, "nonexistent_file.json")
assert_equal(success2, false)
assert_equal(type(err2), "string")
end)
-- Test 10: Edge cases
test("Edge Cases", function()
-- Empty objects
local empty_obj = {}
local encoded_empty = json.encode(empty_obj)
local decoded_empty = json.decode(encoded_empty)
assert_equal(type(decoded_empty), "table")
-- Special numbers
local special = {
zero = 0,
negative = -42,
decimal = 3.14159
}
local encoded_special = json.encode(special)
local decoded_special = json.decode(encoded_special)
assert_equal(decoded_special.zero, 0)
assert_equal(decoded_special.negative, -42)
assert_close(decoded_special.decimal, 3.14159, 0.00001)
end)
-- Performance test
test("Performance Test", function()
local large_data = {}
for i = 1, 1000 do
large_data[i] = {
id = i,
name = "User " .. i,
data = {x = i * 2, y = i * 3, z = i * 4}
}
end
local start = os.clock()
local encoded = json.encode(large_data)
local encode_time = os.clock() - start
start = os.clock()
local decoded = json.decode(encoded)
local decode_time = os.clock() - start
print(string.format(" Encoded 1000 objects in %.3f seconds", encode_time))
print(string.format(" Decoded 1000 objects in %.3f seconds", decode_time))
assert_equal(#decoded, 1000)
assert_equal(decoded[500].name, "User 500")
end)
summary()
test_exit()

362
tests/kv.lua Normal file
View File

@ -0,0 +1,362 @@
require("tests")
local kv = require("kv")
-- Clean up any existing test files
os.remove("test_store.json")
os.remove("test_store.txt")
os.remove("test_oop.json")
os.remove("test_temp.json")
-- ======================================================================
-- BASIC OPERATIONS
-- ======================================================================
test("Store creation and opening", function()
assert(kv.open("test", "test_store.json"))
assert(kv.open("memory_only", ""))
assert(not kv.open("", "test.json")) -- Empty name should fail
end)
test("Set and get operations", function()
assert(kv.set("test", "key1", "value1"))
assert_equal("value1", kv.get("test", "key1"))
assert_equal("default", kv.get("test", "nonexistent", "default"))
assert_equal(nil, kv.get("test", "nonexistent"))
-- Test with special characters
assert(kv.set("test", "special:key", "value with spaces & symbols!"))
assert_equal("value with spaces & symbols!", kv.get("test", "special:key"))
end)
test("Key existence and deletion", function()
kv.set("test", "temp_key", "temp_value")
assert(kv.has("test", "temp_key"))
assert(not kv.has("test", "missing_key"))
assert(kv.delete("test", "temp_key"))
assert(not kv.has("test", "temp_key"))
assert(not kv.delete("test", "missing_key"))
end)
test("Store size tracking", function()
kv.clear("test")
assert_equal(0, kv.size("test"))
kv.set("test", "k1", "v1")
kv.set("test", "k2", "v2")
kv.set("test", "k3", "v3")
assert_equal(3, kv.size("test"))
kv.delete("test", "k2")
assert_equal(2, kv.size("test"))
end)
test("Keys and values retrieval", function()
kv.clear("test")
kv.set("test", "a", "apple")
kv.set("test", "b", "banana")
kv.set("test", "c", "cherry")
local keys = kv.keys("test")
local values = kv.values("test")
assert_equal(3, #keys)
assert_equal(3, #values)
-- Check all keys exist
local key_set = {}
for _, k in ipairs(keys) do
key_set[k] = true
end
assert(key_set["a"] and key_set["b"] and key_set["c"])
-- Check all values exist
local value_set = {}
for _, v in ipairs(values) do
value_set[v] = true
end
assert(value_set["apple"] and value_set["banana"] and value_set["cherry"])
end)
test("Clear store", function()
kv.set("test", "temp1", "value1")
kv.set("test", "temp2", "value2")
assert(kv.size("test") > 0)
assert(kv.clear("test"))
assert_equal(0, kv.size("test"))
assert(not kv.has("test", "temp1"))
end)
test("Save and close operations", function()
-- Ensure store is properly opened with filename
kv.close("test") -- Close if already open
assert(kv.open("test", "test_store.json"))
assert(kv.set("test", "persistent", "data"))
assert(kv.save("test"))
assert(kv.close("test"))
-- Reopen and verify data persists
assert(kv.open("test", "test_store.json"))
assert_equal("data", kv.get("test", "persistent"))
end)
test("Invalid store operations", function()
assert(not kv.set("nonexistent", "key", "value"))
assert_equal(nil, kv.get("nonexistent", "key"))
assert(not kv.has("nonexistent", "key"))
assert_equal(0, kv.size("nonexistent"))
assert(not kv.delete("nonexistent", "key"))
assert(not kv.clear("nonexistent"))
assert(not kv.save("nonexistent"))
assert(not kv.close("nonexistent"))
end)
-- ======================================================================
-- OBJECT-ORIENTED INTERFACE
-- ======================================================================
test("OOP store creation", function()
local store = kv.create("oop_test", "test_oop.json")
assert_equal("oop_test", store.name)
local memory_store = kv.create("memory_oop")
assert_equal("memory_oop", memory_store.name)
end)
test("OOP basic operations", function()
local store = kv.create("oop_basic")
assert(store:set("foo", "bar"))
assert_equal("bar", store:get("foo"))
assert_equal("default", store:get("missing", "default"))
assert(store:has("foo"))
assert(not store:has("missing"))
assert_equal(1, store:size())
assert(store:delete("foo"))
assert(not store:has("foo"))
assert_equal(0, store:size())
end)
test("OOP collections", function()
local store = kv.create("oop_collections")
store:set("a", "apple")
store:set("b", "banana")
store:set("c", "cherry")
local keys = store:keys()
local values = store:values()
assert_equal(3, #keys)
assert_equal(3, #values)
store:clear()
assert_equal(0, store:size())
store:close()
end)
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
test("Increment operations", function()
kv.open("util_test")
-- Increment non-existent key
assert_equal(1, kv.increment("util_test", "counter"))
assert_equal("1", kv.get("util_test", "counter"))
-- Increment existing key
assert_equal(6, kv.increment("util_test", "counter", 5))
assert_equal("6", kv.get("util_test", "counter"))
-- Decrement
assert_equal(4, kv.increment("util_test", "counter", -2))
kv.close("util_test")
end)
test("Append operations", function()
kv.open("append_test")
-- Append to non-existent key
assert(kv.append("append_test", "list", "first"))
assert_equal("first", kv.get("append_test", "list"))
-- Append with separator
assert(kv.append("append_test", "list", "second", ","))
assert_equal("first,second", kv.get("append_test", "list"))
assert(kv.append("append_test", "list", "third", ","))
assert_equal("first,second,third", kv.get("append_test", "list"))
kv.close("append_test")
end)
test("TTL and expiration", function()
kv.open("ttl_test", "test_temp.json")
kv.set("ttl_test", "temp_key", "temp_value")
assert(kv.expire("ttl_test", "temp_key", 1)) -- 1 second TTL
-- Key should still exist immediately
assert(kv.has("ttl_test", "temp_key"))
-- Wait for expiration
os.execute("sleep 2")
local expired = kv.cleanup_expired("ttl_test")
assert(expired >= 1)
-- Key should be gone
assert(not kv.has("ttl_test", "temp_key"))
kv.close("ttl_test")
end)
-- ======================================================================
-- FILE FORMAT TESTS
-- ======================================================================
test("JSON file format", function()
kv.open("json_test", "test.json")
kv.set("json_test", "key1", "value1")
kv.set("json_test", "key2", "value2")
kv.save("json_test")
kv.close("json_test")
-- Verify file exists and reload
assert(file_exists("test.json"))
kv.open("json_test", "test.json")
assert_equal("value1", kv.get("json_test", "key1"))
assert_equal("value2", kv.get("json_test", "key2"))
kv.close("json_test")
os.remove("test.json")
end)
test("Text file format", function()
kv.open("txt_test", "test.txt")
kv.set("txt_test", "setting1", "value1")
kv.set("txt_test", "setting2", "value2")
kv.save("txt_test")
kv.close("txt_test")
-- Verify file exists and reload
assert(file_exists("test.txt"))
kv.open("txt_test", "test.txt")
assert_equal("value1", kv.get("txt_test", "setting1"))
assert_equal("value2", kv.get("txt_test", "setting2"))
kv.close("txt_test")
os.remove("test.txt")
end)
-- ======================================================================
-- EDGE CASES AND ERROR HANDLING
-- ======================================================================
test("Empty values and keys", function()
kv.open("edge_test")
-- Empty value
assert(kv.set("edge_test", "empty", ""))
assert_equal("", kv.get("edge_test", "empty"))
-- Unicode keys and values
assert(kv.set("edge_test", "ключ", "значение"))
assert_equal("значение", kv.get("edge_test", "ключ"))
kv.close("edge_test")
end)
test("Special characters in data", function()
kv.open("special_test")
local special_value = 'Special chars: "quotes", \'apostrophes\', \n newlines, \t tabs, \\ backslashes'
assert(kv.set("special_test", "special", special_value))
assert_equal(special_value, kv.get("special_test", "special"))
kv.close("special_test")
end)
test("Large data handling", function()
kv.open("large_test")
-- Large value
local large_value = string.rep("x", 10000)
assert(kv.set("large_test", "large", large_value))
assert_equal(large_value, kv.get("large_test", "large"))
-- Many keys
for i = 1, 100 do
kv.set("large_test", "key" .. i, "value" .. i)
end
assert_equal(101, kv.size("large_test")) -- 100 + 1 large value
kv.close("large_test")
end)
-- ======================================================================
-- PERFORMANCE TESTS
-- ======================================================================
test("Performance test", function()
kv.open("perf_test")
local start = os.clock()
-- Bulk insert
for i = 1, 1000 do
kv.set("perf_test", "key" .. i, "value" .. i)
end
local insert_time = os.clock() - start
-- Bulk read
start = os.clock()
for i = 1, 1000 do
local value = kv.get("perf_test", "key" .. i)
assert_equal("value" .. i, value)
end
local read_time = os.clock() - start
-- Check final size
assert_equal(1000, kv.size("perf_test"))
print(string.format(" Insert 1000 items: %.3fs", insert_time))
print(string.format(" Read 1000 items: %.3fs", read_time))
kv.close("perf_test")
end)
-- ======================================================================
-- INTEGRATION TESTS
-- ======================================================================
test("Multiple store integration", function()
local users = kv.create("users_int")
local cache = kv.create("cache_int")
-- Simulate user data
users:set("user:123", "john_doe")
cache:set("user:123:last_seen", tostring(os.time()))
-- Verify data in stores
assert_equal("john_doe", users:get("user:123"))
assert(cache:has("user:123:last_seen"))
-- Clean up
users:close()
cache:close()
end)
-- Clean up test files
os.remove("test_store.json")
os.remove("test_oop.json")
os.remove("test_temp.json")
summary()
test_exit()

393
tests/math.lua Normal file
View File

@ -0,0 +1,393 @@
require("tests")
-- Constants tests
test("math.pi", function()
assert_close(math.pi, 3.14159265358979323846)
assert(math.pi > 3.14 and math.pi < 3.15)
end)
test("math.tau", function()
assert_close(math.tau, 6.28318530717958647693)
assert_close(math.tau, 2 * math.pi)
end)
test("math.e", function()
assert_close(math.e, 2.71828182845904523536)
assert(math.e > 2.7 and math.e < 2.8)
end)
test("math.phi", function()
assert_close(math.phi, 1.61803398874989484820)
assert_close(math.phi, (1 + math.sqrt(5)) / 2)
end)
test("math.infinity", function()
assert_equal(math.infinity, 1/0)
assert(math.infinity > 0)
end)
test("math.nan", function()
assert_equal(math.isnan(math.nan), true)
assert(math.nan ~= math.nan) -- NaN property
end)
-- Extended functions tests
test("math.cbrt", function()
assert_close(math.cbrt(8), 2)
assert_close(math.cbrt(-8), -2)
end)
test("math.hypot", function()
assert_close(math.hypot(3, 4), 5)
assert_close(math.hypot(5, 12), 13)
end)
test("math.isnan", function()
assert_equal(math.isnan(0/0), true)
assert_equal(math.isnan(5), false)
end)
test("math.isfinite", function()
assert_equal(math.isfinite(5), true)
assert_equal(math.isfinite(math.infinity), false)
end)
test("math.sign", function()
assert_equal(math.sign(5), 1)
assert_equal(math.sign(-5), -1)
assert_equal(math.sign(0), 0)
end)
test("math.clamp", function()
assert_equal(math.clamp(5, 0, 3), 3)
assert_equal(math.clamp(-1, 0, 3), 0)
assert_equal(math.clamp(2, 0, 3), 2)
end)
test("math.lerp", function()
assert_close(math.lerp(0, 10, 0.5), 5)
assert_close(math.lerp(2, 8, 0.25), 3.5)
end)
test("math.smoothstep", function()
assert_close(math.smoothstep(0, 1, 0.5), 0.5)
assert_close(math.smoothstep(0, 10, 5), 0.5)
end)
test("math.map", function()
assert_close(math.map(5, 0, 10, 0, 100), 50)
assert_close(math.map(2, 0, 4, 10, 20), 15)
end)
test("math.round", function()
assert_equal(math.round(2.7), 3)
assert_equal(math.round(-2.7), -3)
assert_equal(math.round(2.3), 2)
end)
test("math.roundto", function()
assert_close(math.roundto(3.14159, 2), 3.14)
assert_close(math.roundto(123.456, 1), 123.5)
end)
test("math.normalize_angle", function()
assert_close(math.normalize_angle(math.pi * 2.5), math.pi * 0.5)
assert_close(math.normalize_angle(-math.pi * 2.5), -math.pi * 0.5)
end)
test("math.distance", function()
assert_close(math.distance(0, 0, 3, 4), 5)
assert_close(math.distance(1, 1, 4, 5), 5)
end)
test("math.factorial", function()
assert_equal(math.factorial(5), 120)
assert_equal(math.factorial(0), 1)
assert_equal(math.factorial(-1), nil)
end)
test("math.gcd", function()
assert_equal(math.gcd(48, 18), 6)
assert_equal(math.gcd(100, 25), 25)
end)
test("math.lcm", function()
assert_equal(math.lcm(4, 6), 12)
assert_equal(math.lcm(15, 20), 60)
end)
-- Random functions tests
test("math.randomf", function()
local r1 = math.randomf(0, 1)
local r2 = math.randomf(5, 10)
assert(r1 >= 0 and r1 < 1)
assert(r2 >= 5 and r2 < 10)
end)
test("math.randint", function()
local i1 = math.randint(1, 10)
local i2 = math.randint(50, 60)
assert(i1 >= 1 and i1 <= 10)
assert(i2 >= 50 and i2 <= 60)
end)
test("math.randboolean", function()
local b1 = math.randboolean()
local b2 = math.randboolean(0.8)
assert_equal(type(b1), "boolean")
assert_equal(type(b2), "boolean")
end)
-- Statistics tests
test("math.sum", function()
assert_equal(math.sum({1, 2, 3, 4, 5}), 15)
assert_equal(math.sum({10, 20, 30}), 60)
end)
test("math.mean", function()
assert_equal(math.mean({1, 2, 3, 4, 5}), 3)
assert_equal(math.mean({10, 20, 30}), 20)
end)
test("math.median", function()
assert_equal(math.median({1, 2, 3, 4, 5}), 3)
assert_equal(math.median({1, 2, 3, 4}), 2.5)
end)
test("math.variance", function()
assert_close(math.variance({1, 2, 3, 4, 5}), 2)
assert_close(math.variance({10, 10, 10}), 0)
end)
test("math.stdev", function()
assert_close(math.stdev({1, 2, 3, 4, 5}), math.sqrt(2))
assert_close(math.stdev({10, 10, 10}), 0)
end)
test("math.mode", function()
assert_equal(math.mode({1, 2, 2, 3}), 2)
assert_equal(math.mode({5, 5, 4, 4, 4}), 4)
end)
test("math.minmax", function()
local min1, max1 = math.minmax({1, 2, 3, 4, 5})
local min2, max2 = math.minmax({-5, 0, 10})
assert_equal(min1, 1)
assert_equal(max1, 5)
assert_equal(min2, -5)
assert_equal(max2, 10)
end)
-- 2D Vector tests
test("math.vec2.new", function()
local v1 = math.vec2.new(3, 4)
local v2 = math.vec2.new()
assert_equal(v1.x, 3)
assert_equal(v1.y, 4)
assert_equal(v2.x, 0)
assert_equal(v2.y, 0)
end)
test("math.vec2.add", function()
local v1 = math.vec2.new(3, 4)
local v2 = math.vec2.new(1, 2)
local result = math.vec2.add(v1, v2)
assert_equal(result.x, 4)
assert_equal(result.y, 6)
end)
test("math.vec2.sub", function()
local v1 = math.vec2.new(3, 4)
local v2 = math.vec2.new(1, 2)
local result = math.vec2.sub(v1, v2)
assert_equal(result.x, 2)
assert_equal(result.y, 2)
end)
test("math.vec2.mul", function()
local v1 = math.vec2.new(3, 4)
local result1 = math.vec2.mul(v1, 2)
local result2 = math.vec2.mul(v1, math.vec2.new(2, 3))
assert_equal(result1.x, 6)
assert_equal(result1.y, 8)
assert_equal(result2.x, 6)
assert_equal(result2.y, 12)
end)
test("math.vec2.dot", function()
local v1 = math.vec2.new(3, 4)
local v2 = math.vec2.new(1, 2)
assert_equal(math.vec2.dot(v1, v2), 11)
assert_equal(math.vec2.dot(math.vec2.new(1, 0), math.vec2.new(0, 1)), 0)
end)
test("math.vec2.length", function()
local v1 = math.vec2.new(3, 4)
local v2 = math.vec2.new(0, 5)
assert_close(math.vec2.length(v1), 5)
assert_close(math.vec2.length(v2), 5)
end)
test("math.vec2.distance", function()
local v1 = math.vec2.new(0, 0)
local v2 = math.vec2.new(3, 4)
assert_close(math.vec2.distance(v1, v2), 5)
assert_close(math.vec2.distance(math.vec2.new(1, 1), math.vec2.new(4, 5)), 5)
end)
test("math.vec2.normalize", function()
local v1 = math.vec2.new(3, 4)
local normalized = math.vec2.normalize(v1)
assert_close(math.vec2.length(normalized), 1)
assert_close(normalized.x, 0.6)
assert_close(normalized.y, 0.8)
end)
test("math.vec2.rotate", function()
local v1 = math.vec2.new(1, 0)
local rotated90 = math.vec2.rotate(v1, math.pi/2)
local rotated180 = math.vec2.rotate(v1, math.pi)
assert_close(rotated90.x, 0, 1e-10)
assert_close(rotated90.y, 1)
assert_close(rotated180.x, -1)
assert_close(rotated180.y, 0, 1e-10)
end)
-- 3D Vector tests
test("math.vec3.new", function()
local v1 = math.vec3.new(1, 2, 3)
local v2 = math.vec3.new()
assert_equal(v1.x, 1)
assert_equal(v1.y, 2)
assert_equal(v1.z, 3)
assert_equal(v2.x, 0)
assert_equal(v2.y, 0)
assert_equal(v2.z, 0)
end)
test("math.vec3.cross", function()
local v1 = math.vec3.new(1, 0, 0)
local v2 = math.vec3.new(0, 1, 0)
local cross = math.vec3.cross(v1, v2)
assert_equal(cross.x, 0)
assert_equal(cross.y, 0)
assert_equal(cross.z, 1)
end)
test("math.vec3.length", function()
local v1 = math.vec3.new(1, 2, 3)
local v2 = math.vec3.new(0, 0, 5)
assert_close(math.vec3.length(v1), math.sqrt(14))
assert_close(math.vec3.length(v2), 5)
end)
-- Matrix tests
test("math.mat2.det", function()
local m1 = math.mat2.new(1, 2, 3, 4)
local m2 = math.mat2.new(2, 0, 0, 3)
assert_equal(math.mat2.det(m1), -2)
assert_equal(math.mat2.det(m2), 6)
end)
test("math.mat2.mul", function()
local m1 = math.mat2.new(1, 2, 3, 4)
local m2 = math.mat2.new(2, 0, 1, 3)
local product = math.mat2.mul(m1, m2)
assert_equal(product[1][1], 4)
assert_equal(product[1][2], 6)
assert_equal(product[2][1], 10)
assert_equal(product[2][2], 12)
end)
test("math.mat2.rotation", function()
local rot90 = math.mat2.rotation(math.pi/2)
local rot180 = math.mat2.rotation(math.pi)
assert_close(rot90[1][1], 0, 1e-10)
assert_close(rot90[1][2], -1)
assert_close(rot180[1][1], -1)
assert_close(rot180[2][2], -1)
end)
test("math.mat3.transform_point", function()
local translation = math.mat3.translation(5, 10)
local point1 = {x = 1, y = 2}
local point2 = {x = 0, y = 0}
local t1 = math.mat3.transform_point(translation, point1)
local t2 = math.mat3.transform_point(translation, point2)
assert_equal(t1.x, 6)
assert_equal(t1.y, 12)
assert_equal(t2.x, 5)
assert_equal(t2.y, 10)
end)
test("math.mat3.det", function()
local identity = math.mat3.identity()
local scale = math.mat3.scale(2, 3)
assert_close(math.mat3.det(identity), 1)
assert_close(math.mat3.det(scale), 6)
end)
-- Geometry tests
test("math.geometry.triangle_area", function()
local area1 = math.geometry.triangle_area(0, 0, 4, 0, 0, 3)
local area2 = math.geometry.triangle_area(0, 0, 2, 0, 1, 2)
assert_close(area1, 6)
assert_close(area2, 2)
end)
test("math.geometry.point_in_triangle", function()
local inside1 = math.geometry.point_in_triangle(1, 1, 0, 0, 4, 0, 0, 3)
local inside2 = math.geometry.point_in_triangle(5, 5, 0, 0, 4, 0, 0, 3)
assert_equal(inside1, true)
assert_equal(inside2, false)
end)
test("math.geometry.line_intersect", function()
local intersects1, x1, y1 = math.geometry.line_intersect(0, 0, 2, 2, 0, 2, 2, 0)
local intersects2, x2, y2 = math.geometry.line_intersect(0, 0, 1, 1, 2, 2, 3, 3)
assert_equal(intersects1, true)
assert_close(x1, 1)
assert_close(y1, 1)
assert_equal(intersects2, false)
end)
test("math.geometry.closest_point_on_segment", function()
local x1, y1 = math.geometry.closest_point_on_segment(1, 3, 0, 0, 4, 0)
local x2, y2 = math.geometry.closest_point_on_segment(5, 1, 0, 0, 4, 0)
assert_close(x1, 1)
assert_close(y1, 0)
assert_close(x2, 4)
assert_close(y2, 0)
end)
-- Interpolation tests
test("math.interpolation.bezier", function()
local result1 = math.interpolation.bezier(0.5, 0, 1, 2, 3)
local result2 = math.interpolation.bezier(0, 0, 1, 2, 3)
assert_close(result1, 1.5)
assert_close(result2, 0)
end)
test("math.interpolation.quadratic_bezier", function()
local result1 = math.interpolation.quadratic_bezier(0.5, 0, 2, 4)
local result2 = math.interpolation.quadratic_bezier(1, 0, 2, 4)
assert_close(result1, 2)
assert_close(result2, 4)
end)
test("math.interpolation.smootherstep", function()
local result1 = math.interpolation.smootherstep(0, 1, 0.5)
local result2 = math.interpolation.smootherstep(0, 10, 5)
assert_close(result1, 0.5)
assert_close(result2, 0.5)
end)
test("math.interpolation.catmull_rom", function()
local result1 = math.interpolation.catmull_rom(0.5, 0, 1, 2, 3)
local result2 = math.interpolation.catmull_rom(0, 0, 1, 2, 3)
assert_close(result1, 1.5)
assert_close(result2, 1)
end)
summary()
test_exit()

397
tests/sessions.lua Normal file
View File

@ -0,0 +1,397 @@
require("tests")
local sessions = require("sessions")
-- Clean up test files
os.remove("test_sessions.json")
os.remove("test_sessions2.json")
-- ======================================================================
-- SESSION STORE INITIALIZATION
-- ======================================================================
test("Session store initialization", function()
assert(sessions.init("test_sessions", "test_sessions.json"))
assert(sessions.init("memory_sessions"))
-- Test with explicit store name
assert(sessions.init("named_sessions", "test_sessions2.json"))
end)
test("Session ID generation", function()
local id1 = sessions.generate_id()
local id2 = sessions.generate_id()
assert_equal(32, #id1)
assert_equal(32, #id2)
assert(id1 ~= id2, "Session IDs should be unique")
-- Check character set (alphanumeric)
assert(id1:match("^[a-zA-Z0-9]+$"), "Session ID should be alphanumeric")
end)
-- ======================================================================
-- BASIC SESSION OPERATIONS
-- ======================================================================
test("Session creation and retrieval", function()
sessions.init("basic_test")
local session_id = sessions.generate_id()
local session_data = {
user_id = 123,
username = "testuser",
role = "admin",
permissions = {"read", "write"}
}
assert(sessions.create(session_id, session_data))
local retrieved = sessions.get(session_id)
assert_equal(123, retrieved.user_id)
assert_equal("testuser", retrieved.username)
assert_equal("admin", retrieved.role)
assert_table_equal({"read", "write"}, retrieved.permissions)
assert(retrieved._created ~= nil)
assert(retrieved._last_accessed ~= nil)
-- Test non-existent session
assert_equal(nil, sessions.get("nonexistent"))
end)
test("Session updates", function()
sessions.init("update_test")
local session_id = sessions.generate_id()
sessions.create(session_id, {count = 1})
local session = sessions.get(session_id)
local new_data = {
count = 2,
new_field = "added"
}
assert(sessions.update(session_id, new_data))
local updated = sessions.get(session_id)
assert_equal(2, updated.count)
assert_equal("added", updated.new_field)
assert(updated._created ~= nil)
assert(updated._last_accessed ~= nil)
end)
test("Session deletion", function()
sessions.init("delete_test")
local session_id = sessions.generate_id()
sessions.create(session_id, {temp = true})
assert(sessions.get(session_id) ~= nil)
assert(sessions.delete(session_id))
assert_equal(nil, sessions.get(session_id))
-- Delete non-existent session
assert(not sessions.delete("nonexistent"))
end)
test("Session existence check", function()
sessions.init("exists_test")
local session_id = sessions.generate_id()
assert(not sessions.exists(session_id))
sessions.create(session_id, {test = true})
assert(sessions.exists(session_id))
sessions.delete(session_id)
assert(not sessions.exists(session_id))
end)
-- ======================================================================
-- MULTI-STORE OPERATIONS
-- ======================================================================
test("Multiple session stores", function()
assert(sessions.init("store1", "store1.json"))
assert(sessions.init("store2", "store2.json"))
local id1 = sessions.generate_id()
local id2 = sessions.generate_id()
-- Create sessions in different stores
assert(sessions.create(id1, {store = "store1"}, "store1"))
assert(sessions.create(id2, {store = "store2"}, "store2"))
-- Verify isolation
assert_equal(nil, sessions.get(id1, "store2"))
assert_equal(nil, sessions.get(id2, "store1"))
-- Verify correct retrieval
local s1 = sessions.get(id1, "store1")
local s2 = sessions.get(id2, "store2")
assert_equal("store1", s1.store)
assert_equal("store2", s2.store)
os.remove("store1.json")
os.remove("store2.json")
end)
test("Default store behavior", function()
sessions.reset()
sessions.init("default_store")
local session_id = sessions.generate_id()
sessions.create(session_id, {test = "default"})
-- Should work without specifying store
local retrieved = sessions.get(session_id)
assert_equal("default", retrieved.test)
end)
-- ======================================================================
-- SESSION CLEANUP
-- ======================================================================
test("Session cleanup", function()
sessions.init("cleanup_test")
local old_session = sessions.generate_id()
local new_session = sessions.generate_id()
sessions.create(old_session, {test = "old"}, "cleanup_test")
sessions.create(new_session, {test = "new"}, "cleanup_test")
-- Wait to create age difference
os.execute("sleep 2")
-- Access new session to update timestamp
sessions.get(new_session, "cleanup_test")
-- Clean up sessions older than 1 second
local deleted = sessions.cleanup(1, "cleanup_test")
assert(deleted >= 1)
-- New session should remain (recently accessed)
assert(sessions.get(new_session, "cleanup_test") ~= nil)
end)
-- ======================================================================
-- SESSION LISTING AND COUNTING
-- ======================================================================
test("Session listing and counting", function()
sessions.init("list_test")
local ids = {}
for i = 1, 5 do
local id = sessions.generate_id()
sessions.create(id, {index = i}, "list_test")
table.insert(ids, id)
end
local session_list = sessions.list("list_test")
assert_equal(5, #session_list)
local count = sessions.count("list_test")
assert_equal(5, count)
-- Verify all IDs are in the list
local id_set = {}
for _, id in ipairs(session_list) do
id_set[id] = true
end
for _, id in ipairs(ids) do
assert(id_set[id], "Session ID should be in list")
end
end)
-- ======================================================================
-- OBJECT-ORIENTED INTERFACE
-- ======================================================================
test("OOP session store", function()
local store = sessions.create_store("oop_sessions", "oop_test.json")
local session_id = sessions.generate_id()
local data = {user = "oop_test", role = "admin"}
assert(store:create(session_id, data))
local retrieved = store:get(session_id)
assert_equal("oop_test", retrieved.user)
assert_equal("admin", retrieved.role)
retrieved.last_action = "login"
assert(store:update(session_id, retrieved))
local updated = store:get(session_id)
assert_equal("login", updated.last_action)
assert(store:exists(session_id))
assert_equal(1, store:count())
local session_list = store:list()
assert_equal(1, #session_list)
assert_equal(session_id, session_list[1])
assert(store:delete(session_id))
assert(not store:exists(session_id))
store:close()
os.remove("oop_test.json")
end)
-- ======================================================================
-- ERROR HANDLING
-- ======================================================================
test("Session error handling", function()
sessions.reset()
-- Try operations without initialization
local success1, err1 = pcall(sessions.create, "test_id", {})
assert(not success1)
local success2, err2 = pcall(sessions.get, "test_id")
assert(not success2)
-- Initialize and test invalid inputs
sessions.init("error_test")
-- Invalid session ID type
local success3, err3 = pcall(sessions.create, 123, {})
assert(not success3)
-- Invalid data type
local success4, err4 = pcall(sessions.create, "test", "not_a_table")
assert(not success4)
-- Invalid store name
local success5, err5 = pcall(sessions.get, "test", 123)
assert(not success5)
end)
-- ======================================================================
-- DATA PERSISTENCE
-- ======================================================================
test("Session persistence", function()
sessions.init("persist_test", "persist_sessions.json")
local session_id = sessions.generate_id()
local data = {
user_id = 789,
settings = {theme = "dark", lang = "en"},
cart = {items = {"item1", "item2"}, total = 25.99}
}
sessions.create(session_id, data, "persist_test")
sessions.close("persist_test")
-- Reinitialize and verify data persists
sessions.init("persist_test", "persist_sessions.json")
local retrieved = sessions.get(session_id, "persist_test")
assert_equal(789, retrieved.user_id)
assert_equal("dark", retrieved.settings.theme)
assert_equal("en", retrieved.settings.lang)
assert_equal(2, #retrieved.cart.items)
assert_equal(25.99, retrieved.cart.total)
sessions.close("persist_test")
os.remove("persist_sessions.json")
end)
-- ======================================================================
-- COMPLEX DATA STRUCTURES
-- ======================================================================
test("Complex session data", function()
sessions.init("complex_test")
local session_id = sessions.generate_id()
local complex_data = {
user = {
id = 456,
profile = {
name = "Jane Doe",
email = "jane@example.com",
preferences = {
notifications = true,
privacy = "friends_only"
}
}
},
activity = {
pages_visited = {"home", "profile", "settings"},
actions = {
{type = "login", time = os.time()},
{type = "view_page", page = "profile", time = os.time()}
}
},
metadata = {
ip = "192.168.1.1",
user_agent = "Test Browser 1.0"
}
}
sessions.create(session_id, complex_data)
local retrieved = sessions.get(session_id)
assert_equal(456, retrieved.user.id)
assert_equal("Jane Doe", retrieved.user.profile.name)
assert_equal(true, retrieved.user.profile.preferences.notifications)
assert_equal(3, #retrieved.activity.pages_visited)
assert_equal("login", retrieved.activity.actions[1].type)
assert_equal("192.168.1.1", retrieved.metadata.ip)
end)
-- ======================================================================
-- WORKFLOW INTEGRATION
-- ======================================================================
test("Session workflow integration", function()
sessions.init("workflow_test")
-- Simulate user workflow
local session_id = sessions.generate_id()
-- User login
sessions.create(session_id, {
user_id = 999,
username = "workflow_user",
status = "logged_in"
})
-- User adds items to cart
local session = sessions.get(session_id)
session.cart = {"item1", "item2"}
session.cart_total = 19.99
sessions.update(session_id, session)
-- User proceeds to checkout
session = sessions.get(session_id)
session.checkout_step = "payment"
session.payment_method = "credit_card"
sessions.update(session_id, session)
-- Verify final state
local final_session = sessions.get(session_id)
assert_equal(999, final_session.user_id)
assert_equal("logged_in", final_session.status)
assert_equal(2, #final_session.cart)
assert_equal(19.99, final_session.cart_total)
assert_equal("payment", final_session.checkout_step)
assert_equal("credit_card", final_session.payment_method)
-- User completes order and logs out
sessions.delete(session_id)
assert_equal(nil, sessions.get(session_id))
end)
-- Clean up test files
os.remove("test_sessions.json")
os.remove("test_sessions2.json")
summary()
test_exit()

778
tests/string.lua Normal file
View File

@ -0,0 +1,778 @@
require("tests")
-- Test data
local test_string = "Hello, World!"
local multi_line = "Line 1\nLine 2\nLine 3"
local padded_string = " Hello World "
local unicode_string = "café naïve résumé"
local mixed_case = "hELLo WoRLd"
-- ======================================================================
-- BASIC STRING OPERATIONS
-- ======================================================================
test("string.split", function()
-- Basic split
local parts = string.split("a,b,c,d", ",")
assert_equal(4, #parts)
assert_equal("a", parts[1])
assert_equal("d", parts[4])
-- Empty delimiter (character split)
local chars = string.split("abc", "")
assert_equal(3, #chars)
assert_equal("a", chars[1])
assert_equal("c", chars[3])
-- Empty string
local empty_parts = string.split("", ",")
assert_equal(1, #empty_parts)
assert_equal("", empty_parts[1])
-- No matches
local no_match = string.split("hello", ",")
assert_equal(1, #no_match)
assert_equal("hello", no_match[1])
-- Multiple consecutive delimiters
local multiple = string.split("a,,b,", ",")
assert_equal(4, #multiple)
assert_equal("a", multiple[1])
assert_equal("", multiple[2])
assert_equal("b", multiple[3])
assert_equal("", multiple[4])
end)
test("string.join", function()
assert_equal("a-b-c", string.join({"a", "b", "c"}, "-"))
assert_equal("abc", string.join({"a", "b", "c"}, ""))
assert_equal("", string.join({}, ","))
assert_equal("a", string.join({"a"}, ","))
-- Mixed types (should convert to string)
assert_equal("1,2,3", string.join({1, 2, 3}, ","))
end)
test("string.trim", function()
assert_equal("Hello World", string.trim(padded_string))
assert_equal("", string.trim(""))
assert_equal("", string.trim(" "))
assert_equal("hello", string.trim("hello"))
assert_equal("a b", string.trim(" a b "))
assert_equal("a b", string.trim("xxxa bxxx", "x"))
end)
test("string.trim_left", function()
assert_equal("Hello World ", string.trim_left(padded_string))
assert_equal("hello", string.trim_left("hello"))
assert_equal("", string.trim_left(""))
-- Custom cutset
assert_equal("Helloxxx", string.trim_left("xxxHelloxxx", "x"))
assert_equal("yHelloxxx", string.trim_left("xyHelloxxx", "x"))
assert_equal("", string.trim_left("xxxx", "x"))
end)
test("string.trim_right", function()
assert_equal(" Hello World", string.trim_right(padded_string))
assert_equal("hello", string.trim_right("hello"))
assert_equal("", string.trim_right(""))
-- Custom cutset
assert_equal("xxxHello", string.trim_right("xxxHelloxxx", "x"))
assert_equal("", string.trim_right("xxxx", "x"))
end)
test("string.upper", function()
assert_equal("HELLO", string.upper("hello"))
assert_equal("HELLO123!", string.upper("Hello123!"))
assert_equal("", string.upper(""))
assert_equal("ABC", string.upper("abc"))
end)
test("string.lower", function()
assert_equal("hello", string.lower("HELLO"))
assert_equal("hello123!", string.lower("HELLO123!"))
assert_equal("", string.lower(""))
assert_equal("abc", string.lower("ABC"))
end)
test("string.title", function()
assert_equal("Hello World", string.title("hello world"))
assert_equal("Hello World", string.title("HELLO WORLD"))
assert_equal("", string.title(""))
assert_equal("A", string.title("a"))
assert_equal("Test_Case", string.title("test_case"))
end)
test("string.contains", function()
assert_equal(true, string.contains(test_string, "World"))
assert_equal(false, string.contains(test_string, "world"))
assert_equal(true, string.contains(test_string, ""))
assert_equal(false, string.contains("", "a"))
assert_equal(true, string.contains("abc", "b"))
end)
test("string.starts_with", function()
assert_equal(true, string.starts_with(test_string, "Hello"))
assert_equal(false, string.starts_with(test_string, "hello"))
assert_equal(true, string.starts_with(test_string, ""))
assert_equal(false, string.starts_with("", "a"))
assert_equal(true, string.starts_with("hello", "h"))
end)
test("string.ends_with", function()
assert_equal(true, string.ends_with(test_string, "!"))
assert_equal(false, string.ends_with(test_string, "?"))
assert_equal(true, string.ends_with(test_string, ""))
assert_equal(false, string.ends_with("", "a"))
assert_equal(true, string.ends_with("hello", "o"))
end)
test("string.replace", function()
assert_equal("hi world hi", string.replace("hello world hello", "hello", "hi"))
assert_equal("hello", string.replace("hello", "xyz", "abc"))
assert_equal("", string.replace("hello", "hello", ""))
assert_equal("xyzxyz", string.replace("abcabc", "abc", "xyz"))
-- Special characters
assert_equal("h*llo", string.replace("hello", "e", "*"))
end)
test("string.replace_n", function()
assert_equal("hi world hello", string.replace_n("hello world hello", "hello", "hi", 1))
assert_equal("hi world hi", string.replace_n("hello world hello", "hello", "hi", 2))
assert_equal("hello world hello", string.replace_n("hello world hello", "hello", "hi", 0))
assert_equal("hello world hello", string.replace_n("hello world hello", "xyz", "hi", 5))
end)
test("string.index", function()
assert_equal(7, string.index("hello world", "world"))
assert_equal(nil, string.index("hello world", "xyz"))
assert_equal(1, string.index("hello", "h"))
assert_equal(5, string.index("hello", "o"))
assert_equal(1, string.index("hello", ""))
end)
test("string.last_index", function()
assert_equal(7, string.last_index("hello hello", "hello"))
assert_equal(nil, string.last_index("hello world", "xyz"))
assert_equal(1, string.last_index("hello", "h"))
assert_equal(9, string.last_index("hello o o", "o"))
end)
test("string.count", function()
assert_equal(3, string.count("hello hello hello", "hello"))
assert_equal(0, string.count("hello world", "xyz"))
assert_equal(2, string.count("hello", "l"))
assert_equal(6, string.count("hello", ""))
end)
test("string.repeat_", function()
assert_equal("abcabcabc", string.repeat_("abc", 3))
assert_equal("", string.repeat_("x", 0))
assert_equal("x", string.repeat_("x", 1))
assert_equal("", string.repeat_("", 5))
end)
test("string.reverse", function()
assert_equal("olleh", string.reverse("hello"))
assert_equal("", string.reverse(""))
assert_equal("a", string.reverse("a"))
assert_equal("dcba", string.reverse("abcd"))
-- Test with Go fallback for longer strings
local long_str = string.rep("abc", 50)
local reversed = string.reverse(long_str)
assert_equal(string.length(long_str), string.length(reversed))
end)
test("string.length", function()
assert_equal(5, string.length("hello"))
assert_equal(0, string.length(""))
assert_equal(1, string.length("a"))
-- Unicode length
assert_equal(4, string.length("café"))
assert_equal(17, string.length(unicode_string))
end)
test("string.byte_length", function()
assert_equal(5, string.byte_length("hello"))
assert_equal(0, string.byte_length(""))
assert_equal(1, string.byte_length("a"))
-- Unicode byte length (UTF-8)
assert_equal(5, string.byte_length("café")) -- é is 2 bytes
assert_equal(21, string.byte_length(unicode_string)) -- accented chars are 2 bytes each
end)
test("string.lines", function()
local lines = string.lines(multi_line)
assert_equal(3, #lines)
assert_equal("Line 1", lines[1])
assert_equal("Line 3", lines[3])
-- Empty string
assert_table_equal({""}, string.lines(""))
-- Different line endings
assert_table_equal({"a", "b"}, string.lines("a\nb"))
assert_table_equal({"a", "b"}, string.lines("a\r\nb"))
assert_table_equal({"a", "b"}, string.lines("a\rb"))
-- Trailing newline
assert_table_equal({"a", "b"}, string.lines("a\nb\n"))
end)
test("string.words", function()
local words = string.words("Hello world test")
assert_equal(3, #words)
assert_equal("Hello", words[1])
assert_equal("test", words[3])
-- Extra whitespace
assert_table_equal({"Hello", "world"}, string.words(" Hello world "))
-- Empty string
assert_table_equal({}, string.words(""))
assert_table_equal({}, string.words(" "))
-- Single word
assert_table_equal({"hello"}, string.words("hello"))
end)
test("string.pad_left", function()
assert_equal(" hi", string.pad_left("hi", 5))
assert_equal("000hi", string.pad_left("hi", 5, "0"))
assert_equal("hello", string.pad_left("hello", 3))
assert_equal("hi", string.pad_left("hi", 2))
assert_equal("hi", string.pad_left("hi", 0))
-- Unicode padding
assert_equal(" café", string.pad_left("café", 6))
end)
test("string.pad_right", function()
assert_equal("hi ", string.pad_right("hi", 5))
assert_equal("hi***", string.pad_right("hi", 5, "*"))
assert_equal("hello", string.pad_right("hello", 3))
assert_equal("hi", string.pad_right("hi", 2))
assert_equal("hi", string.pad_right("hi", 0))
-- Unicode padding
assert_equal("café ", string.pad_right("café", 6))
end)
test("string.slice", function()
assert_equal("ell", string.slice("hello", 2, 4))
assert_equal("ello", string.slice("hello", 2))
assert_equal("", string.slice("hello", 10))
assert_equal("h", string.slice("hello", 1, 1))
assert_equal("", string.slice("hello", 3, 2))
-- Negative indices
assert_equal("lo", string.slice("hello", 4, -1))
-- Unicode slicing
assert_equal("afé", string.slice("café", 2, 4))
end)
-- ======================================================================
-- REGULAR EXPRESSIONS
-- ======================================================================
test("string.match", function()
assert_equal(true, string.match("hello123", "%d+") ~= nil)
assert_equal(false, string.match("hello", "%d+") ~= nil)
assert_equal(true, string.match("hello", "^[a-z]+$") ~= nil)
assert_equal(false, string.match("Hello", "^[a-z]+$") ~= nil)
assert_equal(true, string.match("testing", "test") ~= nil)
end)
test("string.find_match", function()
assert_equal("123", string.find_match("hello123world", "%d+"))
assert_equal(nil, string.find_match("hello", "%d+"))
assert_equal("test", string.find_match("testing", "test"))
assert_equal("world", string.find_match("hello world", "world"))
end)
test("string.find_all", function()
local matches = string.find_all("123 and 456 and 789", "%d+")
assert_equal(3, #matches)
assert_equal("123", matches[1])
assert_equal("789", matches[3])
-- No matches
assert_table_equal({}, string.find_all("hello", "%d+"))
-- Multiple matches
local overlaps = string.find_all("test test test", "test")
assert_equal(3, #overlaps)
end)
test("string.gsub", function()
assert_equal("helloXXXworldXXX", (string.gsub("hello123world456", "%d+", "XXX")))
assert_equal("hello world", (string.gsub("hello world", "%s+", " ")))
assert_equal("abc abc abc", (string.gsub("test abc test", "test", "abc")))
-- No matches
assert_equal("hello", (string.gsub("hello", "%d+", "XXX")))
end)
-- ======================================================================
-- TYPE CONVERSION & VALIDATION
-- ======================================================================
test("string.to_number", function()
assert_equal(123, string.to_number("123"))
assert_equal(123.45, string.to_number("123.45"))
assert_equal(-42, string.to_number("-42"))
assert_equal(nil, string.to_number("not_a_number"))
assert_equal(nil, string.to_number(""))
assert_equal(42, string.to_number(" 42 "))
end)
test("string.is_numeric", function()
assert_equal(true, string.is_numeric("123"))
assert_equal(true, string.is_numeric("123.45"))
assert_equal(true, string.is_numeric("-42"))
assert_equal(false, string.is_numeric("abc"))
assert_equal(false, string.is_numeric(""))
assert_equal(true, string.is_numeric(" 42 "))
end)
test("string.is_alpha", function()
assert_equal(true, string.is_alpha("hello"))
assert_equal(false, string.is_alpha("hello123"))
assert_equal(false, string.is_alpha(""))
assert_equal(false, string.is_alpha("hello!"))
assert_equal(true, string.is_alpha("ABC"))
end)
test("string.is_alphanumeric", function()
assert_equal(true, string.is_alphanumeric("hello123"))
assert_equal(false, string.is_alphanumeric("hello!"))
assert_equal(false, string.is_alphanumeric(""))
assert_equal(true, string.is_alphanumeric("ABC123"))
assert_equal(true, string.is_alphanumeric("hello"))
end)
test("string.is_empty", function()
assert_equal(true, string.is_empty(""))
assert_equal(true, string.is_empty(nil))
assert_equal(false, string.is_empty("hello"))
assert_equal(false, string.is_empty(" "))
end)
test("string.is_blank", function()
assert_equal(true, string.is_blank(""))
assert_equal(true, string.is_blank(" "))
assert_equal(true, string.is_blank(nil))
assert_equal(false, string.is_blank("hello"))
assert_equal(false, string.is_blank(" a "))
end)
test("string.is_utf8", function()
assert_equal(true, string.is_utf8("hello"))
assert_equal(true, string.is_utf8("café"))
assert_equal(true, string.is_utf8(""))
assert_equal(true, string.is_utf8(unicode_string))
end)
-- ======================================================================
-- ADVANCED STRING OPERATIONS
-- ======================================================================
test("string.capitalize", function()
assert_equal("Hello World", string.capitalize("hello world"))
assert_equal("Hello", string.capitalize("hello"))
assert_equal("", string.capitalize(""))
assert_equal("A", string.capitalize("a"))
end)
test("string.camel_case", function()
assert_equal("helloWorld", string.camel_case("hello world"))
assert_equal("hello", string.camel_case("hello"))
assert_equal("", string.camel_case(""))
assert_equal("testCaseExample", string.camel_case("test case example"))
end)
test("string.pascal_case", function()
assert_equal("HelloWorld", string.pascal_case("hello world"))
assert_equal("Hello", string.pascal_case("hello"))
assert_equal("", string.pascal_case(""))
assert_equal("TestCaseExample", string.pascal_case("test case example"))
end)
test("string.snake_case", function()
assert_equal("hello_world", string.snake_case("Hello World"))
assert_equal("hello", string.snake_case("hello"))
assert_equal("", string.snake_case(""))
assert_equal("test_case_example", string.snake_case("Test Case Example"))
end)
test("string.kebab_case", function()
assert_equal("hello-world", string.kebab_case("Hello World"))
assert_equal("hello", string.kebab_case("hello"))
assert_equal("", string.kebab_case(""))
assert_equal("test-case-example", string.kebab_case("Test Case Example"))
end)
test("string.screaming_snake_case", function()
assert_equal("HELLO_WORLD", string.screaming_snake_case("hello world"))
assert_equal("HELLO", string.screaming_snake_case("hello"))
assert_equal("", string.screaming_snake_case(""))
assert_equal("TEST_CASE", string.screaming_snake_case("test case"))
end)
test("string.center", function()
assert_equal(" hi ", string.center("hi", 6))
assert_equal("**hi***", string.center("hi", 7, "*"))
assert_equal("hello", string.center("hello", 3))
assert_equal("hi", string.center("hi", 2))
assert_equal(" hi ", string.center("hi", 4))
end)
test("string.truncate", function()
assert_equal("hello...", string.truncate("hello world", 8))
assert_equal("hello>>", string.truncate("hello world", 8, ">>"))
assert_equal("hi", string.truncate("hi", 10))
assert_equal("...", string.truncate("hello", 3))
assert_equal("h>", string.truncate("hello", 2, ">"))
end)
test("string.wrap", function()
local wrapped = string.wrap("The quick brown fox jumps over the lazy dog", 10)
assert_equal("table", type(wrapped))
assert(#wrapped > 1, "should wrap into multiple lines")
-- Each line should be within limit
for _, line in ipairs(wrapped) do
assert(string.length(line) <= 10, "line should be within width limit: " .. line)
end
-- Empty string
assert_table_equal({""}, string.wrap("", 10))
-- Single word longer than width
local long_word = string.wrap("supercalifragilisticexpialidocious", 10)
assert_equal(1, #long_word)
end)
test("string.dedent", function()
local indented = " line1\n line2\n line3"
local dedented = string.dedent(indented)
local lines = string.lines(dedented)
assert_equal("line1", lines[1])
assert_equal("line2", lines[2])
assert_equal("line3", lines[3])
-- Mixed indentation
local mixed = " a\n b\n c"
local mixed_result = string.dedent(mixed)
local mixed_lines = string.lines(mixed_result)
assert_equal("a", mixed_lines[1])
assert_equal(" b", mixed_lines[2])
assert_equal("c", mixed_lines[3])
end)
test("string.escape", function()
assert_equal("hello%.world", string.escape("hello.world"))
assert_equal("a%+b%*c%?", string.escape("a+b*c?"))
assert_equal("%[%]%(%)", string.escape("[]()"))
assert_equal("hello", string.escape("hello"))
end)
test("string.shell_quote", function()
assert_equal("'hello world'", string.shell_quote("hello world"))
assert_equal("'it'\"'\"'s great'", string.shell_quote("it's great"))
assert_equal("hello", string.shell_quote("hello"))
assert_equal("hello-world.txt", string.shell_quote("hello-world.txt"))
end)
test("string.url_encode", function()
assert_equal("hello%20world", string.url_encode("hello world"))
assert_equal("hello", string.url_encode("hello"))
assert_equal("hello%21%40%23", string.url_encode("hello!@#"))
end)
test("string.url_decode", function()
assert_equal("hello world", string.url_decode("hello%20world"))
assert_equal("hello world", string.url_decode("hello+world"))
assert_equal("hello", string.url_decode("hello"))
assert_equal("hello!@#", string.url_decode("hello%21%40%23"))
-- Round trip
local original = "hello world!@#"
local encoded = string.url_encode(original)
assert_equal(original, string.url_decode(encoded))
end)
test("string.template", function()
local context = {
user = {name = "Jane", role = "admin"},
count = 5
}
local template = "User ${user.name} (${user.role}) has ${count} items"
local result = string.template(template, context)
assert_equal("User Jane (admin) has 5 items", result)
-- Missing variables
local incomplete = string.template("Hello ${name} and ${unknown}", {name = "John"})
assert_equal("Hello John and ", incomplete)
-- Missing nested property
local missing = string.template("${user.missing}", context)
assert_equal("", missing)
-- No variables
assert_equal("Hello world", string.template("Hello world", {}))
end)
test("string.random", function()
local random1 = string.random(10)
local random2 = string.random(10)
assert_equal(10, string.length(random1))
assert_equal(10, string.length(random2))
assert(random1 ~= random2, "random strings should be different")
-- Custom charset
local custom = string.random(5, "abc")
assert_equal(5, string.length(custom))
-- Zero length
assert_equal("", string.random(0))
end)
test("string.slug", function()
assert_equal("hello-world", string.slug("Hello World"))
assert_equal("cafe-restaurant", string.slug("Café & Restaurant"))
assert_equal("specialcharacters", string.slug("Special!@#$%Characters"))
assert_equal("", string.slug(""))
assert_equal("test", string.slug("test"))
end)
test("string.iequals", function()
assert_equal(true, string.iequals("Hello", "HELLO"))
assert_equal(true, string.iequals("hello", "hello"))
assert_equal(false, string.iequals("Hello", "world"))
assert_equal(true, string.iequals("", ""))
end)
test("string.is_whitespace", function()
assert_equal(true, string.is_whitespace(" "))
assert_equal(true, string.is_whitespace(""))
assert_equal(true, string.is_whitespace("\t\n\r "))
assert_equal(false, string.is_whitespace("hello"))
assert_equal(false, string.is_whitespace(" a "))
end)
test("string.strip_whitespace", function()
assert_equal("hello", string.strip_whitespace("h e l l o"))
assert_equal("helloworld", string.strip_whitespace("hello world"))
assert_equal("", string.strip_whitespace(" "))
assert_equal("abc", string.strip_whitespace("a\tb\nc"))
end)
test("string.normalize_whitespace", function()
assert_equal("hello world test", string.normalize_whitespace("hello world test"))
assert_equal("a b c", string.normalize_whitespace(" a b c "))
assert_equal("", string.normalize_whitespace(" "))
assert_equal("hello", string.normalize_whitespace("hello"))
end)
test("string.extract_numbers", function()
local numbers = string.extract_numbers("The price is $123.45 and tax is 8.5%")
assert_equal(2, #numbers)
assert_close(123.45, numbers[1])
assert_close(8.5, numbers[2])
local negative = string.extract_numbers("Temperature: -15.5 degrees")
assert_equal(1, #negative)
assert_close(-15.5, negative[1])
-- No numbers
assert_table_equal({}, string.extract_numbers("hello world"))
end)
test("string.remove_accents", function()
assert_equal("cafe", string.remove_accents("café"))
assert_equal("resume", string.remove_accents("résumé"))
assert_equal("naive", string.remove_accents("naïve"))
assert_equal("hello", string.remove_accents("hello"))
assert_equal("", string.remove_accents(""))
-- Mixed case
assert_equal("Cafe", string.remove_accents("Café"))
end)
-- ======================================================================
-- EDGE CASES AND ERROR HANDLING
-- ======================================================================
test("Empty String Handling", function()
assert_table_equal({""}, string.split("", ","))
assert_equal("", string.join({}, ","))
assert_equal("", string.trim(""))
assert_equal("", string.reverse(""))
assert_equal("", string.repeat_("", 5))
assert_table_equal({""}, string.lines(""))
assert_table_equal({}, string.words(""))
assert_equal("", string.slice("", 1, 5))
assert_equal("", string.slug(""))
end)
test("Large String Handling", function()
local large_string = string.rep("test ", 1000)
assert_equal(5000, string.length(large_string))
assert_equal(1000, string.count(large_string, "test"))
local words = string.words(large_string)
assert_equal(1000, #words)
local trimmed = string.trim(large_string)
assert_equal(true, string.ends_with(trimmed, "test"))
-- Should use Go for reverse on large strings
local reversed = string.reverse(large_string)
assert_equal(string.length(large_string), string.length(reversed))
end)
test("Unicode Handling", function()
local unicode_str = "Hello 🌍 World 🚀"
assert_equal(true, string.contains(unicode_str, "🌍"))
assert_equal(true, string.starts_with(unicode_str, "Hello"))
assert_equal(true, string.ends_with(unicode_str, "🚀"))
local parts = string.split(unicode_str, " ")
assert_equal(4, #parts)
assert_equal("🌍", parts[2])
-- Length should count Unicode characters, not bytes
assert_equal(15, string.length(unicode_str))
assert(string.byte_length(unicode_str) > 15, "byte length should be larger")
end)
test("Boundary Conditions", function()
-- Index boundary tests
assert_equal("", string.slice("hello", 6))
assert_equal("", string.slice("hello", 1, 0))
assert_equal("hello", string.slice("hello", 1, 100))
-- Padding boundary tests
assert_equal("hi", string.pad_left("hi", 0))
assert_equal("hi", string.pad_right("hi", 1))
-- Count boundary tests
assert_equal(0, string.count("", "a"))
assert_equal(1, string.count("", ""))
-- Replace boundary tests
assert_equal("", string.replace_n("hello", "hello", "", 1))
assert_equal("hello", string.replace_n("hello", "x", "y", 0))
end)
-- ======================================================================
-- INTEGRATION TESTS
-- ======================================================================
test("String Processing Pipeline", function()
local messy_input = " HELLO, world! How ARE you? "
-- Clean and normalize
local cleaned = string.normalize_whitespace(string.trim(messy_input))
local lowered = string.lower(cleaned)
local words = string.words(lowered)
local filtered = {}
for _, word in ipairs(words) do
local clean_word = string.gsub(word, "[%p]", "") -- Remove punctuation using Lua pattern
if string.length(clean_word) > 2 then
table.insert(filtered, clean_word)
end
end
local result = string.join(filtered, "-")
assert_equal("hello-world-how-are-you", result)
end)
test("Text Analysis", function()
local text = "The quick brown fox jumps over the lazy dog. The dog was sleeping."
local word_count = #string.words(text)
local sentence_count = string.count(text, ".")
local the_count = string.count(string.lower(text), "the")
assert_equal(13, word_count)
assert_equal(2, sentence_count)
assert_equal(3, the_count)
-- Template processing
local template = "Found ${word_count} words and ${the_count} instances of 'the'"
local vars = {word_count = word_count, the_count = the_count}
local summary = string.template(template, vars)
assert_equal("Found 13 words and 3 instances of 'the'", summary)
end)
test("Case Conversion Chain", function()
local original = "Hello World Test Case"
-- Test conversion chain
local snake = string.snake_case(original)
local camel = string.camel_case(original)
local pascal = string.pascal_case(original)
local kebab = string.kebab_case(original)
local screaming = string.screaming_snake_case(original)
assert_equal("hello_world_test_case", snake)
assert_equal("helloWorldTestCase", camel)
assert_equal("HelloWorldTestCase", pascal)
assert_equal("hello-world-test-case", kebab)
assert_equal("HELLO_WORLD_TEST_CASE", screaming)
-- Convert back should be similar
local back_to_words = string.split(snake, "_")
local rejoined = string.join(back_to_words, " ")
local capitalized = string.title(rejoined)
assert_equal("Hello World Test Case", capitalized)
end)
-- ======================================================================
-- PERFORMANCE TESTS
-- ======================================================================
test("Performance Characteristics", function()
local large_text = string.rep("The quick brown fox jumps over the lazy dog. ", 1000)
-- Test that operations complete in reasonable time
local start = os.clock()
local words = string.words(large_text)
local lines = string.lines(large_text)
local replaced = string.replace(large_text, "fox", "cat")
local parts = string.split(large_text, " ")
local reversed = string.reverse(large_text)
local elapsed = os.clock() - start
-- Verify results
assert(#words > 8000, "should extract many words")
assert(string.contains(replaced, "cat"), "replacement should work")
assert(#parts > 8000, "should split into many parts")
assert_equal(string.length(large_text), string.length(reversed))
print(string.format(" Processed %d characters in %.3fs", string.length(large_text), elapsed))
-- Performance should be reasonable (< 1 second for this test)
assert(elapsed < 1.0, "operations should complete quickly")
end)
summary()
test_exit()

911
tests/table.lua Normal file
View File

@ -0,0 +1,911 @@
require("tests")
-- Test data
local simple_array = {1, 2, 3, 4, 5}
local simple_table = {a = 1, b = 2, c = 3}
local mixed_table = {1, 2, a = "hello", b = "world"}
local nested_table = {
a = {x = 1, y = 2},
b = {x = 3, y = 4},
c = {1, 2, 3}
}
-- ======================================================================
-- BUILT-IN TABLE FUNCTIONS
-- ======================================================================
test("Table Insert Operations", function()
local t = {1, 2, 3}
table.insert(t, 4)
assert_equal(4, #t)
assert_equal(4, t[4])
table.insert(t, 2, "inserted")
assert_equal(5, #t)
assert_equal("inserted", t[2])
assert_equal(2, t[3])
end)
test("Table Remove Operations", function()
local t = {1, 2, 3, 4, 5}
local removed = table.remove(t)
assert_equal(5, removed)
assert_equal(4, #t)
removed = table.remove(t, 2)
assert_equal(2, removed)
assert_equal(3, #t)
assert_equal(3, t[2])
end)
test("Table Concat", function()
local t = {"hello", "world", "test"}
assert_equal("helloworldtest", table.concat(t))
assert_equal("hello,world,test", table.concat(t, ","))
assert_equal("world,test", table.concat(t, ",", 2))
assert_equal("world", table.concat(t, ",", 2, 2))
end)
test("Table Sort", function()
local t = {3, 1, 4, 1, 5}
table.sort(t)
assert_table_equal({1, 1, 3, 4, 5}, t)
local t2 = {"c", "a", "b"}
table.sort(t2)
assert_table_equal({"a", "b", "c"}, t2)
local t3 = {3, 1, 4, 1, 5}
table.sort(t3, function(a, b) return a > b end)
assert_table_equal({5, 4, 3, 1, 1}, t3)
end)
-- ======================================================================
-- BASIC TABLE OPERATIONS
-- ======================================================================
test("Table Length and Size", function()
assert_equal(5, table.length(simple_array))
assert_equal(0, table.length({}))
assert_equal(3, table.size(simple_table))
assert_equal(4, table.size(mixed_table))
assert_equal(0, table.size({}))
end)
test("Table Empty Check", function()
assert_equal(true, table.is_empty({}))
assert_equal(false, table.is_empty(simple_array))
assert_equal(false, table.is_empty(simple_table))
end)
test("Table Array Check", function()
assert_equal(true, table.is_array(simple_array))
assert_equal(true, table.is_array({}))
assert_equal(false, table.is_array(simple_table))
assert_equal(false, table.is_array(mixed_table))
assert_equal(true, table.is_array({1, 2, 3}))
assert_equal(false, table.is_array({1, 2, nil, 4}))
assert_equal(false, table.is_array({[0] = 1, [1] = 2}))
end)
test("Table Clear", function()
local t = table.clone(simple_table)
table.clear(t)
assert_equal(true, table.is_empty(t))
end)
test("Table Clone", function()
local cloned = table.clone(simple_table)
assert_table_equal(simple_table, cloned)
-- Modify original shouldn't affect clone
simple_table.new_key = "test"
assert_equal(nil, cloned.new_key)
simple_table.new_key = nil
end)
test("Table Deep Copy", function()
local copied = table.deep_copy(nested_table)
assert_table_equal(nested_table, copied)
-- Modify nested part shouldn't affect copy
nested_table.a.z = 99
assert_equal(nil, copied.a.z)
nested_table.a.z = nil
end)
-- ======================================================================
-- SEARCHING AND FINDING
-- ======================================================================
test("Table Contains", function()
assert_equal(true, table.contains(simple_array, 3))
assert_equal(false, table.contains(simple_array, 6))
assert_equal(true, table.contains(simple_table, 2))
assert_equal(false, table.contains(simple_table, "hello"))
end)
test("Table Index Of", function()
assert_equal(3, table.index_of(simple_array, 3))
assert_equal(nil, table.index_of(simple_array, 6))
assert_equal("b", table.index_of(simple_table, 2))
assert_equal(nil, table.index_of(simple_table, "hello"))
end)
test("Table Find", function()
local value, key = table.find(simple_array, function(v) return v > 3 end)
assert_equal(4, value)
assert_equal(4, key)
local value2, key2 = table.find(simple_table, function(v, k) return k == "b" end)
assert_equal(2, value2)
assert_equal("b", key2)
local value3 = table.find(simple_array, function(v) return v > 10 end)
assert_equal(nil, value3)
end)
test("Table Find Index", function()
local idx = table.find_index(simple_array, function(v) return v > 3 end)
assert_equal(4, idx)
local idx2 = table.find_index(simple_table, function(v, k) return k == "c" end)
assert_equal("c", idx2)
local idx3 = table.find_index(simple_array, function(v) return v > 10 end)
assert_equal(nil, idx3)
end)
test("Table Count", function()
local arr = {1, 2, 3, 2, 4, 2}
assert_equal(3, table.count(arr, 2))
assert_equal(0, table.count(arr, 5))
assert_equal(2, table.count(arr, function(v) return v > 2 end))
assert_equal(2, table.count(arr, function(v) return v == 1 or v == 4 end))
end)
-- ======================================================================
-- FILTERING AND MAPPING
-- ======================================================================
test("Table Filter", function()
local evens = table.filter(simple_array, function(v) return v % 2 == 0 end)
assert_table_equal({2, 4}, evens)
local filtered_table = table.filter(simple_table, function(v) return v > 1 end)
assert_equal(2, table.size(filtered_table))
assert_equal(2, filtered_table.b)
assert_equal(3, filtered_table.c)
end)
test("Table Reject", function()
local odds = table.reject(simple_array, function(v) return v % 2 == 0 end)
assert_table_equal({1, 3, 5}, odds)
end)
test("Table Map", function()
local doubled = table.map(simple_array, function(v) return v * 2 end)
assert_table_equal({2, 4, 6, 8, 10}, doubled)
local mapped_table = table.map(simple_table, function(v) return v + 10 end)
assert_equal(11, mapped_table.a)
assert_equal(12, mapped_table.b)
assert_equal(13, mapped_table.c)
end)
test("Table Map Values", function()
local incremented = table.map_values(simple_table, function(v) return v + 1 end)
assert_equal(2, incremented.a)
assert_equal(3, incremented.b)
assert_equal(4, incremented.c)
end)
test("Table Map Keys", function()
local prefixed = table.map_keys(simple_table, function(k) return "key_" .. k end)
assert_equal(1, prefixed.key_a)
assert_equal(2, prefixed.key_b)
assert_equal(3, prefixed.key_c)
assert_equal(nil, prefixed.a)
end)
-- ======================================================================
-- REDUCING AND AGGREGATING
-- ======================================================================
test("Table Reduce", function()
local sum = table.reduce(simple_array, function(acc, v) return acc + v end)
assert_equal(15, sum)
local sum_with_initial = table.reduce(simple_array, function(acc, v) return acc + v end, 10)
assert_equal(25, sum_with_initial)
local product = table.reduce({2, 3, 4}, function(acc, v) return acc * v end)
assert_equal(24, product)
end)
test("Table Fold", function()
local sum = table.fold(simple_array, function(acc, v) return acc + v end, 0)
assert_equal(15, sum)
local concatenated = table.fold({"a", "b", "c"}, function(acc, v) return acc .. v end, "")
assert_equal("abc", concatenated)
end)
test("Table Math Operations", function()
assert_equal(15, table.sum(simple_array))
assert_equal(120, table.product(simple_array))
assert_equal(1, table.min(simple_array))
assert_equal(5, table.max(simple_array))
assert_equal(3, table.average(simple_array))
local floats = {1.5, 2.5, 3.0}
assert_close(7.0, table.sum(floats))
assert_close(2.33333, table.average(floats), 0.001)
end)
-- ======================================================================
-- BOOLEAN OPERATIONS
-- ======================================================================
test("Table All", function()
assert_equal(true, table.all({true, true, true}))
assert_equal(false, table.all({true, false, true}))
assert_equal(true, table.all({}))
assert_equal(true, table.all(simple_array, function(v) return v > 0 end))
assert_equal(false, table.all(simple_array, function(v) return v > 3 end))
end)
test("Table Any", function()
assert_equal(true, table.any({false, true, false}))
assert_equal(false, table.any({false, false, false}))
assert_equal(false, table.any({}))
assert_equal(true, table.any(simple_array, function(v) return v > 3 end))
assert_equal(false, table.any(simple_array, function(v) return v > 10 end))
end)
test("Table None", function()
assert_equal(true, table.none({false, false, false}))
assert_equal(false, table.none({false, true, false}))
assert_equal(true, table.none({}))
assert_equal(false, table.none(simple_array, function(v) return v > 3 end))
assert_equal(true, table.none(simple_array, function(v) return v > 10 end))
end)
-- ======================================================================
-- SET OPERATIONS
-- ======================================================================
test("Table Unique", function()
local duplicates = {1, 2, 2, 3, 3, 3, 4}
local unique = table.unique(duplicates)
assert_table_equal({1, 2, 3, 4}, unique)
local empty_unique = table.unique({})
assert_table_equal({}, empty_unique)
end)
test("Table Intersection", function()
local arr1 = {1, 2, 3, 4}
local arr2 = {3, 4, 5, 6}
local intersect = table.intersection(arr1, arr2)
assert_equal(2, #intersect)
assert_equal(true, table.contains(intersect, 3))
assert_equal(true, table.contains(intersect, 4))
end)
test("Table Union", function()
local arr1 = {1, 2, 3}
local arr2 = {3, 4, 5}
local union = table.union(arr1, arr2)
assert_equal(5, #union)
for i = 1, 5 do
assert_equal(true, table.contains(union, i))
end
end)
test("Table Difference", function()
local arr1 = {1, 2, 3, 4, 5}
local arr2 = {3, 4}
local diff = table.difference(arr1, arr2)
assert_table_equal({1, 2, 5}, diff)
end)
-- ======================================================================
-- ARRAY OPERATIONS
-- ======================================================================
test("Table Reverse", function()
local reversed = table.reverse(simple_array)
assert_table_equal({5, 4, 3, 2, 1}, reversed)
local single = table.reverse({42})
assert_table_equal({42}, single)
local empty = table.reverse({})
assert_table_equal({}, empty)
end)
test("Table Shuffle", function()
local shuffled = table.shuffle(simple_array)
assert_equal(5, #shuffled)
-- All original elements should still be present
for _, v in ipairs(simple_array) do
assert_equal(true, table.contains(shuffled, v))
end
-- Should be same length
assert_equal(#simple_array, #shuffled)
end)
test("Table Rotate", function()
local arr = {1, 2, 3, 4, 5}
local rotated_right = table.rotate(arr, 2)
assert_table_equal({4, 5, 1, 2, 3}, rotated_right)
local rotated_left = table.rotate(arr, -2)
assert_table_equal({3, 4, 5, 1, 2}, rotated_left)
local no_rotation = table.rotate(arr, 0)
assert_table_equal(arr, no_rotation)
local full_rotation = table.rotate(arr, 5)
assert_table_equal(arr, full_rotation)
end)
test("Table Slice", function()
local sliced = table.slice(simple_array, 2, 4)
assert_table_equal({2, 3, 4}, sliced)
local from_start = table.slice(simple_array, 1, 3)
assert_table_equal({1, 2, 3}, from_start)
local to_end = table.slice(simple_array, 3)
assert_table_equal({3, 4, 5}, to_end)
local negative_indices = table.slice(simple_array, -3, -1)
assert_table_equal({3, 4, 5}, negative_indices)
end)
test("Table Splice", function()
local arr = {1, 2, 3, 4, 5}
-- Remove elements
local removed = table.splice(arr, 2, 2)
assert_table_equal({2, 3}, removed)
assert_table_equal({1, 4, 5}, arr)
-- Insert elements
arr = {1, 2, 3, 4, 5}
removed = table.splice(arr, 3, 0, "a", "b")
assert_table_equal({}, removed)
assert_table_equal({1, 2, "a", "b", 3, 4, 5}, arr)
-- Replace elements
arr = {1, 2, 3, 4, 5}
removed = table.splice(arr, 2, 2, "x", "y", "z")
assert_table_equal({2, 3}, removed)
assert_table_equal({1, "x", "y", "z", 4, 5}, arr)
end)
-- ======================================================================
-- SORTING HELPERS
-- ======================================================================
test("Table Sort By", function()
local people = {
{name = "Alice", age = 30},
{name = "Bob", age = 25},
{name = "Charlie", age = 35}
}
local sorted_by_age = table.sort_by(people, function(p) return p.age end)
assert_equal("Bob", sorted_by_age[1].name)
assert_equal("Charlie", sorted_by_age[3].name)
local sorted_by_name = table.sort_by(people, function(p) return p.name end)
assert_equal("Alice", sorted_by_name[1].name)
assert_equal("Charlie", sorted_by_name[3].name)
end)
test("Table Is Sorted", function()
assert_equal(true, table.is_sorted({1, 2, 3, 4, 5}))
assert_equal(false, table.is_sorted({1, 3, 2, 4, 5}))
assert_equal(true, table.is_sorted({}))
assert_equal(true, table.is_sorted({42}))
assert_equal(true, table.is_sorted({5, 4, 3, 2, 1}, function(a, b) return a > b end))
assert_equal(false, table.is_sorted({1, 2, 3, 4, 5}, function(a, b) return a > b end))
end)
-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================
test("Table Keys and Values", function()
local keys = table.keys(simple_table)
assert_equal(3, #keys)
assert_equal(true, table.contains(keys, "a"))
assert_equal(true, table.contains(keys, "b"))
assert_equal(true, table.contains(keys, "c"))
local values = table.values(simple_table)
assert_equal(3, #values)
assert_equal(true, table.contains(values, 1))
assert_equal(true, table.contains(values, 2))
assert_equal(true, table.contains(values, 3))
end)
test("Table Pairs", function()
local pairs_list = table.pairs({a = 1, b = 2})
assert_equal(2, #pairs_list)
-- Should contain key-value pairs
local found_a, found_b = false, false
for _, pair in ipairs(pairs_list) do
if pair[1] == "a" and pair[2] == 1 then found_a = true end
if pair[1] == "b" and pair[2] == 2 then found_b = true end
end
assert_equal(true, found_a)
assert_equal(true, found_b)
end)
test("Table Merge", function()
local t1 = {a = 1, b = 2}
local t2 = {c = 3, d = 4}
local t3 = {b = 20, e = 5}
local merged = table.merge(t1, t2, t3)
assert_equal(5, table.size(merged))
assert_equal(1, merged.a)
assert_equal(20, merged.b) -- Last one wins
assert_equal(3, merged.c)
assert_equal(4, merged.d)
assert_equal(5, merged.e)
end)
test("Table Extend", function()
local t1 = {a = 1, b = 2}
local t2 = {c = 3, d = 4}
local extended = table.extend(t1, t2)
assert_equal(t1, extended) -- Should return t1
assert_equal(4, table.size(t1))
assert_equal(3, t1.c)
assert_equal(4, t1.d)
end)
test("Table Invert", function()
local inverted = table.invert(simple_table)
assert_equal("a", inverted[1])
assert_equal("b", inverted[2])
assert_equal("c", inverted[3])
end)
test("Table Pick and Omit", function()
local big_table = {a = 1, b = 2, c = 3, d = 4, e = 5}
local picked = table.pick(big_table, "a", "c", "e")
assert_equal(3, table.size(picked))
assert_equal(1, picked.a)
assert_equal(3, picked.c)
assert_equal(5, picked.e)
assert_equal(nil, picked.b)
assert_equal(nil, picked.d)
local omitted = table.omit(big_table, "b", "d")
assert_equal(3, table.size(omitted))
assert_equal(1, omitted.a)
assert_equal(3, omitted.c)
assert_equal(5, omitted.e)
assert_equal(nil, omitted.b)
assert_equal(nil, omitted.d)
end)
-- ======================================================================
-- DEEP OPERATIONS
-- ======================================================================
test("Table Deep Equals", function()
local t1 = {a = {x = 1, y = 2}, b = {1, 2, 3}}
local t2 = {a = {x = 1, y = 2}, b = {1, 2, 3}}
local t3 = {a = {x = 1, y = 3}, b = {1, 2, 3}}
assert_equal(true, table.deep_equals(t1, t2))
assert_equal(false, table.deep_equals(t1, t3))
assert_equal(true, table.deep_equals({}, {}))
assert_equal(false, table.deep_equals({a = 1}, {a = 1, b = 2}))
end)
test("Table Flatten", function()
local nested = {{1, 2}, {3, 4}, {5, {6, 7}}}
local flattened = table.flatten(nested)
assert_table_equal({1, 2, 3, 4, 5, {6, 7}}, flattened)
local deep_flattened = table.flatten(nested, 2)
assert_table_equal({1, 2, 3, 4, 5, 6, 7}, deep_flattened)
local already_flat = table.flatten({1, 2, 3})
assert_table_equal({1, 2, 3}, already_flat)
end)
test("Table Deep Merge", function()
local t1 = {a = {x = 1}, b = 2}
local t2 = {a = {y = 3}, c = 4}
local merged = table.deep_merge(t1, t2)
assert_equal(1, merged.a.x)
assert_equal(3, merged.a.y)
assert_equal(2, merged.b)
assert_equal(4, merged.c)
-- Original tables should be unchanged
assert_equal(nil, t1.a.y)
assert_equal(nil, t1.c)
end)
-- ======================================================================
-- ADVANCED OPERATIONS
-- ======================================================================
test("Table Chunk", function()
local chunks = table.chunk({1, 2, 3, 4, 5, 6, 7}, 3)
assert_equal(3, #chunks)
assert_table_equal({1, 2, 3}, chunks[1])
assert_table_equal({4, 5, 6}, chunks[2])
assert_table_equal({7}, chunks[3])
local exact_chunks = table.chunk({1, 2, 3, 4}, 2)
assert_equal(2, #exact_chunks)
assert_table_equal({1, 2}, exact_chunks[1])
assert_table_equal({3, 4}, exact_chunks[2])
end)
test("Table Partition", function()
local evens, odds = table.partition(simple_array, function(v) return v % 2 == 0 end)
assert_table_equal({2, 4}, evens)
assert_table_equal({1, 3, 5}, odds)
local empty_true, all_false = table.partition({1, 3, 5}, function(v) return v % 2 == 0 end)
assert_table_equal({}, empty_true)
assert_table_equal({1, 3, 5}, all_false)
end)
test("Table Group By", function()
local people = {
{name = "Alice", department = "engineering"},
{name = "Bob", department = "sales"},
{name = "Charlie", department = "engineering"},
{name = "David", department = "sales"}
}
local by_dept = table.group_by(people, function(person) return person.department end)
assert_equal(2, table.size(by_dept))
assert_equal(2, #by_dept.engineering)
assert_equal(2, #by_dept.sales)
assert_equal("Alice", by_dept.engineering[1].name)
assert_equal("Bob", by_dept.sales[1].name)
end)
test("Table Zip", function()
local names = {"Alice", "Bob", "Charlie"}
local ages = {25, 30, 35}
local cities = {"NYC", "LA", "Chicago"}
local zipped = table.zip(names, ages, cities)
assert_equal(3, #zipped)
assert_table_equal({"Alice", 25, "NYC"}, zipped[1])
assert_table_equal({"Bob", 30, "LA"}, zipped[2])
assert_table_equal({"Charlie", 35, "Chicago"}, zipped[3])
-- Different lengths
local short_zip = table.zip({1, 2, 3}, {"a", "b"})
assert_equal(2, #short_zip)
assert_table_equal({1, "a"}, short_zip[1])
assert_table_equal({2, "b"}, short_zip[2])
end)
test("Table Compact", function()
local messy = {1, nil, false, 2, nil, 3, false}
local compacted = table.compact(messy)
assert_table_equal({1, 2, 3}, compacted)
local clean = {1, 2, 3}
local unchanged = table.compact(clean)
assert_table_equal(clean, unchanged)
end)
test("Table Sample", function()
local sample1 = table.sample(simple_array, 3)
assert_equal(3, #sample1)
-- All sampled elements should be from original
for _, v in ipairs(sample1) do
assert_equal(true, table.contains(simple_array, v))
end
local single_sample = table.sample(simple_array)
assert_equal(1, #single_sample)
assert_equal(true, table.contains(simple_array, single_sample[1]))
local oversample = table.sample({1, 2}, 5)
assert_equal(2, #oversample)
end)
-- ======================================================================
-- EDGE CASES AND ERROR HANDLING
-- ======================================================================
test("Empty Table Handling", function()
local empty = {}
assert_equal(true, table.is_empty(empty))
assert_equal(0, table.length(empty))
assert_equal(0, table.size(empty))
assert_equal(true, table.is_array(empty))
assert_table_equal({}, table.filter(empty, function() return true end))
assert_table_equal({}, table.map(empty, function(v) return v * 2 end))
assert_table_equal({}, table.keys(empty))
assert_table_equal({}, table.values(empty))
assert_equal(true, table.all(empty))
assert_equal(false, table.any(empty))
assert_equal(true, table.none(empty))
end)
test("Single Element Tables", function()
local single = {42}
assert_equal(1, table.length(single))
assert_equal(1, table.size(single))
assert_equal(true, table.is_array(single))
assert_equal(false, table.is_empty(single))
assert_equal(42, table.sum(single))
assert_equal(42, table.product(single))
assert_equal(42, table.min(single))
assert_equal(42, table.max(single))
assert_equal(42, table.average(single))
assert_table_equal({42}, table.reverse(single))
assert_table_equal({84}, table.map(single, function(v) return v * 2 end))
end)
test("Circular Reference Handling", function()
local t1 = {a = 1}
local t2 = {b = 2}
t1.ref = t2
t2.ref = t1
-- Deep copy should handle circular references
local copied = table.deep_copy(t1)
assert_equal(1, copied.a)
assert_equal(2, copied.ref.b)
assert_equal(copied, copied.ref.ref) -- Should maintain circular structure
end)
test("Large Table Performance", function()
local large = {}
for i = 1, 10000 do
large[i] = i
end
assert_equal(10000, table.length(large))
assert_equal(true, table.is_array(large))
assert_equal(50005000, table.sum(large)) -- Sum of 1 to 10000
local evens = table.filter(large, function(v) return v % 2 == 0 end)
assert_equal(5000, #evens)
local doubled = table.map(large, function(v) return v * 2 end)
assert_equal(10000, #doubled)
assert_equal(2, doubled[1])
assert_equal(20000, doubled[10000])
end)
test("Mixed Type Table Handling", function()
local mixed = {1, "hello", true, {a = 1}, function() end}
assert_equal(5, table.length(mixed))
assert_equal(true, table.is_array(mixed))
assert_equal(true, table.contains(mixed, "hello"))
assert_equal(true, table.contains(mixed, true))
local strings_only = table.filter(mixed, function(v) return type(v) == "string" end)
assert_equal(1, #strings_only)
assert_equal("hello", strings_only[1])
end)
-- ======================================================================
-- PERFORMANCE TESTS
-- ======================================================================
test("Performance Test", function()
local large_array = {}
for i = 1, 10000 do
large_array[i] = math.random(1, 1000)
end
local start = os.clock()
local filtered = table.filter(large_array, function(v) return v > 500 end)
local filter_time = os.clock() - start
start = os.clock()
local mapped = table.map(large_array, function(v) return v * 2 end)
local map_time = os.clock() - start
start = os.clock()
local sum = table.sum(large_array)
local sum_time = os.clock() - start
start = os.clock()
local sorted = table.sort_by(large_array, function(v) return v end)
local sort_time = os.clock() - start
start = os.clock()
local unique = table.unique(large_array)
local unique_time = os.clock() - start
print(string.format(" Filter %d elements: %.3fs", #filtered, filter_time))
print(string.format(" Map %d elements: %.3fs", #mapped, map_time))
print(string.format(" Sum %d elements: %.3fs", #large_array, sum_time))
print(string.format(" Sort %d elements: %.3fs", #sorted, sort_time))
print(string.format(" Unique from %d to %d: %.3fs", #large_array, #unique, unique_time))
assert(#filtered > 0, "should filter some elements")
assert_equal(#large_array, #mapped)
assert(sum > 0, "sum should be positive")
assert_equal(#large_array, #sorted)
assert(table.is_sorted(sorted), "should be sorted")
end)
-- ======================================================================
-- INTEGRATION TESTS
-- ======================================================================
test("Data Processing Pipeline", function()
local sales_data = {
{product = "laptop", price = 1000, quantity = 2, category = "electronics"},
{product = "mouse", price = 25, quantity = 10, category = "electronics"},
{product = "book", price = 15, quantity = 5, category = "books"},
{product = "phone", price = 800, quantity = 3, category = "electronics"},
{product = "magazine", price = 5, quantity = 20, category = "books"}
}
-- Calculate total revenue per item
local with_revenue = table.map(sales_data, function(item)
local new_item = table.clone(item)
new_item.revenue = item.price * item.quantity
return new_item
end)
-- Filter high-value items (revenue >= 100)
local high_value = table.filter(with_revenue, function(item)
return item.revenue >= 100
end)
-- Group by category
local by_category = table.group_by(high_value, function(item)
return item.category
end)
-- Calculate total revenue by category
local category_totals = table.map_values(by_category, function(items)
return table.sum(table.map(items, function(item) return item.revenue end))
end)
assert_equal(2, table.size(category_totals))
assert_equal(4650, category_totals.electronics) -- laptop: 2000, mouse: 250, phone: 2400
assert_equal(100, category_totals.books) -- magazine: 100
end)
test("Complex Data Transformation", function()
local users = {
{id = 1, name = "Alice", age = 25, skills = {"lua", "python"}},
{id = 2, name = "Bob", age = 30, skills = {"javascript", "lua"}},
{id = 3, name = "Charlie", age = 35, skills = {"python", "java"}},
{id = 4, name = "David", age = 28, skills = {"lua", "go"}}
}
-- Find Lua developers
local lua_devs = table.filter(users, function(user)
return table.contains(user.skills, "lua")
end)
-- Sort by age
local sorted_lua_devs = table.sort_by(lua_devs, function(user) return user.age end)
-- Extract just names and ages
local simplified = table.map(sorted_lua_devs, function(user)
return {name = user.name, age = user.age}
end)
assert_equal(3, #simplified)
assert_equal("Alice", simplified[1].name) -- Youngest
assert_equal("David", simplified[2].name)
assert_equal("Bob", simplified[3].name) -- Oldest
-- Group all users by age ranges
local age_groups = table.group_by(users, function(user)
if user.age < 30 then return "young"
else return "experienced" end
end)
assert_equal(2, #age_groups.young) -- Alice, David
assert_equal(2, #age_groups.experienced) -- Bob, Charlie
end)
test("Statistical Analysis", function()
local test_scores = {
{student = "Alice", scores = {85, 92, 78, 95, 88}},
{student = "Bob", scores = {72, 85, 90, 77, 82}},
{student = "Charlie", scores = {95, 88, 92, 90, 85}},
{student = "David", scores = {68, 75, 80, 72, 78}}
}
-- Calculate average score for each student
local with_averages = table.map(test_scores, function(student)
local avg = table.average(student.scores)
return {
student = student.student,
scores = student.scores,
average = avg,
max_score = table.max(student.scores),
min_score = table.min(student.scores)
}
end)
-- Find top performer
local top_student = table.reduce(with_averages, function(best, current)
return current.average > best.average and current or best
end)
-- Students above class average
local all_averages = table.map(with_averages, function(s) return s.average end)
local class_average = table.average(all_averages)
local above_average = table.filter(with_averages, function(s)
return s.average > class_average
end)
assert_equal("Charlie", top_student.student)
assert_equal(90, top_student.average)
assert_equal(2, #above_average) -- Charlie and Alice
assert_close(83.4, class_average, 0.1)
end)
test("Table Metatable Method Chaining", function()
local t = {1, 2, 3, 4, 5}
-- Check if methods are available directly on table instances
assert_equal("function", type(t.filter), "table should have filter method via metatable")
assert_equal("function", type(t.map), "table should have map method via metatable")
assert_equal("function", type(t.length), "table should have length method via metatable")
-- Test actual method chaining
local result = t:filter(function(v) return v > 2 end)
:map(function(v) return v * 2 end)
assert_table_equal({6, 8, 10}, result)
-- Test chaining with method calls
local nums = {1, 2, 3, 4, 5}
local sum = nums:sum()
assert_equal(15, sum)
local sizes = {a = 1, b = 2}
local size = sizes:size()
assert_equal(2, size)
end)
summary()
test_exit()

192
tests/tests.lua Normal file
View File

@ -0,0 +1,192 @@
local passed = 0
local total = 0
function assert(condition, message, level)
if condition then
return true
end
level = level or 2
local info = debug.getinfo(level, "Sl")
local file = info.source
-- Extract filename from source or use generic name
if file:sub(1,1) == "@" then
file = file:sub(2) -- Remove @ prefix for files
else
file = "<script>" -- Generic name for inline scripts
end
local line = info.currentline or "unknown"
local error_msg = message or "assertion failed"
local full_msg = string.format("%s:%s: %s", file, line, error_msg)
error(full_msg, 0)
end
function assert_equal(expected, actual, message)
if expected == actual then
return true
end
local msg = message or string.format("Expected %s, got %s", tostring(expected), tostring(actual))
assert(false, msg, 3)
end
function assert_table_equal(expected, actual, message, path)
path = path or "root"
if type(expected) ~= type(actual) then
local msg = message or string.format("Type mismatch at %s: expected %s, got %s", path, type(expected), type(actual))
assert(false, msg, 3)
end
if type(expected) ~= "table" then
if expected ~= actual then
local msg = message or string.format("Value mismatch at %s: expected %s, got %s", path, tostring(expected), tostring(actual))
assert(false, msg, 3)
end
return true
end
-- Check all keys in a exist in b with same values
for k, v in pairs(expected) do
local new_path = path .. "." .. tostring(k)
if actual[k] == nil then
local msg = message or string.format("Missing key at %s", new_path)
assert(false, msg, 3)
end
assert_table_equal(v, actual[k], message, new_path)
end
-- Check all keys in b exist in a
for k, v in pairs(actual) do
if expected[k] == nil then
local new_path = path .. "." .. tostring(k)
local msg = message or string.format("Extra key at %s", new_path)
assert(false, msg, 3)
end
end
return true
end
function assert_close(expected, actual, tolerance, message)
tolerance = tolerance or 1e-10
local diff = math.abs(expected - actual)
if diff <= tolerance then
return true
end
local msg = message or string.format("Expected %g, got %g (diff: %g > %g)", expected, actual, diff, tolerance)
assert(false, msg, 3)
end
function test(name, fn)
print("Testing " .. name .. "...")
total = total + 1
local start_time = os.clock()
local ok, err = pcall(fn)
local end_time = os.clock()
local duration = end_time - start_time
if ok then
passed = passed + 1
print(string.format(" ✓ PASS (%.3fs)", duration))
return true
else
print(" ✗ FAIL: " .. err)
if duration > 0.001 then
print(string.format(" (%.3fs)", duration))
end
return false
end
end
function run_tests(tests)
print("Running test suite...")
print("=" .. string.rep("=", 50))
for name, test_fn in pairs(tests) do
test(name, test_fn)
end
return summary()
end
function reset_tests()
passed = 0
total = 0
end
function test_stats()
return {
passed = passed,
total = total,
failed = total - passed,
success_rate = total > 0 and (passed / total) or 0
}
end
function summary()
print("=" .. string.rep("=", 50))
print(string.format("Test Results: %d/%d passed", passed, total))
local success = passed == total
if success then
print("🎉 All tests passed!")
else
local failed = total - passed
local rate = total > 0 and (passed / total * 100) or 0
print(string.format("❌ %d test(s) failed! (%.1f%% success rate)", failed, rate))
end
return success
end
function test_exit()
local success = passed == total
os.exit(success and 0 or 1)
end
function run_and_exit(tests)
local success = run_tests(tests)
os.exit(success and 0 or 1)
end
function benchmark(name, fn, iterations)
iterations = iterations or 1000
print("Benchmarking " .. name .. " (" .. iterations .. " iterations)...")
-- Warmup
for i = 1, math.min(10, iterations) do
fn()
end
-- Actual benchmark
local start = os.clock()
for i = 1, iterations do
fn()
end
local total_time = os.clock() - start
local avg_time = total_time / iterations
print(string.format(" Total: %.3fs, Average: %.6fs, Rate: %.0f ops/sec",
total_time, avg_time, 1/avg_time))
return {
total_time = total_time,
avg_time = avg_time,
ops_per_sec = 1/avg_time,
iterations = iterations
}
end
function file_exists(filename)
local file = io.open(filename, "r")
if file then
file:close()
return true
end
return false
end

View File

@ -1,68 +0,0 @@
package color
// ANSI color codes
const (
Reset = "\033[0m"
Red = "\033[31m"
Green = "\033[32m"
Yellow = "\033[33m"
Blue = "\033[34m"
Purple = "\033[35m"
Cyan = "\033[36m"
White = "\033[37m"
Gray = "\033[90m"
)
var useColors = true
// SetColors enables or disables colors globally
func SetColors(enabled bool) {
useColors = enabled
}
// ColorsEnabled returns current global color setting
func ColorsEnabled() bool {
return useColors
}
// Apply adds color to text using global color setting
func Apply(text, color string) string {
if useColors {
return color + text + Reset
}
return text
}
// ApplyIf adds color to text if useColors is true (for backward compatibility)
func ApplyIf(text, color string, enabled bool) string {
if enabled {
return color + text + Reset
}
return text
}
// Set adds color to text (always applies color, ignores global setting)
func Set(text, color string) string {
return color + text + Reset
}
// Strip removes ANSI color codes from a string
func Strip(s string) string {
result := ""
inEscape := false
for _, c := range s {
if inEscape {
if c == 'm' {
inEscape = false
}
continue
}
if c == '\033' {
inEscape = true
continue
}
result += string(c)
}
return result
}

View File

@ -1,290 +0,0 @@
package config
import (
"errors"
"fmt"
"runtime"
"strconv"
"strings"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Config represents a configuration loaded from a Lua file
type Config struct {
// Server settings
Server struct {
Port int
Debug bool
HTTPLogging bool
StaticPrefix string
}
// Runner settings
Runner struct {
PoolSize int
}
// Directory paths
Dirs struct {
Routes string
Static string
FS string
Data string
Override string
Libs []string
}
// Raw values map for custom values
values map[string]any
}
// NewConfig creates a new configuration with default values
func New() *Config {
config := &Config{
// Initialize values map
values: make(map[string]any),
}
// Server defaults
config.Server.Port = 3117
config.Server.Debug = false
config.Server.HTTPLogging = false
config.Server.StaticPrefix = "static/"
// Runner defaults
config.Runner.PoolSize = runtime.GOMAXPROCS(0)
// Dirs defaults
config.Dirs.Routes = "routes"
config.Dirs.Static = "public"
config.Dirs.FS = "fs"
config.Dirs.Data = "data"
config.Dirs.Override = "override"
config.Dirs.Libs = []string{"libs"}
return config
}
// Load loads configuration from a Lua file
func Load(filePath string) (*Config, error) {
// Create Lua state
state := luajit.New(true)
if state == nil {
return nil, errors.New("failed to create Lua state")
}
defer state.Close()
defer state.Cleanup()
// Create config with default values
config := New()
// Execute the config file
if err := state.DoFile(filePath); err != nil {
return nil, fmt.Errorf("failed to load config file: %w", err)
}
// Store values directly to the config
config.values = make(map[string]any)
// Extract top-level tables
tables := []string{"server", "runner", "dirs"}
for _, table := range tables {
state.GetGlobal(table)
if state.IsTable(-1) {
tableMap, err := state.ToTable(-1)
if err == nil {
config.values[table] = tableMap
}
}
state.Pop(1)
}
// Apply configuration values
applyConfig(config, config.values)
return config, nil
}
// applyConfig applies configuration values from the globals map
func applyConfig(config *Config, values map[string]any) {
// Apply server settings
if serverTable, ok := values["server"].(map[string]any); ok {
if v, ok := serverTable["port"].(float64); ok {
config.Server.Port = int(v)
}
if v, ok := serverTable["debug"].(bool); ok {
config.Server.Debug = v
}
if v, ok := serverTable["http_logging"].(bool); ok {
config.Server.HTTPLogging = v
}
if v, ok := serverTable["static_prefix"].(string); ok {
config.Server.StaticPrefix = v
}
}
// Apply runner settings
if runnerTable, ok := values["runner"].(map[string]any); ok {
if v, ok := runnerTable["pool_size"].(float64); ok && v != 0 {
config.Runner.PoolSize = int(v)
}
}
// Apply dirs settings
if dirsTable, ok := values["dirs"].(map[string]any); ok {
if v, ok := dirsTable["routes"].(string); ok {
config.Dirs.Routes = v
}
if v, ok := dirsTable["static"].(string); ok {
config.Dirs.Static = v
}
if v, ok := dirsTable["fs"].(string); ok {
config.Dirs.FS = v
}
if v, ok := dirsTable["data"].(string); ok {
config.Dirs.Data = v
}
if v, ok := dirsTable["override"].(string); ok {
config.Dirs.Override = v
}
// Handle libs array
if libs, ok := dirsTable["libs"]; ok {
if libsArray := extractStringArray(libs); len(libsArray) > 0 {
config.Dirs.Libs = libsArray
}
}
}
}
// extractStringArray extracts a string array from a Lua table
func extractStringArray(value any) []string {
// Direct array case
if arr, ok := value.([]any); ok {
result := make([]string, 0, len(arr))
for _, v := range arr {
if str, ok := v.(string); ok {
result = append(result, str)
}
}
return result
}
// Map with numeric keys case
if tableMap, ok := value.(map[string]any); ok {
result := make([]string, 0, len(tableMap))
for _, v := range tableMap {
if str, ok := v.(string); ok {
result = append(result, str)
}
}
return result
}
return nil
}
// GetCustomValue returns any custom configuration value by key
// Key can be a dotted path like "server.port"
func (c *Config) GetCustomValue(key string) any {
parts := strings.Split(key, ".")
if len(parts) == 1 {
return c.values[key]
}
current := c.values
for _, part := range parts[:len(parts)-1] {
next, ok := current[part].(map[string]any)
if !ok {
return nil
}
current = next
}
return current[parts[len(parts)-1]]
}
// GetCustomString returns a custom string configuration value
func (c *Config) GetCustomString(key string, defaultValue string) string {
value := c.GetCustomValue(key)
if value == nil {
return defaultValue
}
// Convert to string
switch v := value.(type) {
case string:
return v
case float64:
return fmt.Sprintf("%g", v)
case int:
return strconv.Itoa(v)
case bool:
if v {
return "true"
}
return "false"
default:
return defaultValue
}
}
// GetCustomInt returns a custom integer configuration value
func (c *Config) GetCustomInt(key string, defaultValue int) int {
value := c.GetCustomValue(key)
if value == nil {
return defaultValue
}
// Convert to int
switch v := value.(type) {
case int:
return v
case float64:
return int(v)
case string:
if i, err := strconv.Atoi(v); err == nil {
return i
}
case bool:
if v {
return 1
}
return 0
default:
return defaultValue
}
return defaultValue
}
// GetCustomBool returns a custom boolean configuration value
func (c *Config) GetCustomBool(key string, defaultValue bool) bool {
value := c.GetCustomValue(key)
if value == nil {
return defaultValue
}
// Convert to bool
switch v := value.(type) {
case bool:
return v
case string:
switch v {
case "true", "yes", "1":
return true
case "false", "no", "0", "":
return false
}
case int:
return v != 0
case float64:
return v != 0
default:
return defaultValue
}
return defaultValue
}

View File

@ -1,217 +0,0 @@
package utils
import (
"fmt"
"html/template"
"runtime"
"strings"
"time"
"Moonshark/utils/config"
"Moonshark/utils/metadata"
)
// ComponentStats holds stats from various system components
type ComponentStats struct {
RouteCount int // Number of routes
BytecodeBytes int64 // Total size of bytecode in bytes
ModuleCount int // Number of loaded modules
SessionStats map[string]uint64 // Session cache statistics
}
// SystemStats represents system statistics for debugging
type SystemStats struct {
Timestamp time.Time
GoVersion string
GoRoutines int
Memory runtime.MemStats
Components ComponentStats
Version string
Config *config.Config
}
// CollectSystemStats gathers basic system statistics
func CollectSystemStats(cfg *config.Config) SystemStats {
var stats SystemStats
var mem runtime.MemStats
stats.Timestamp = time.Now()
stats.GoVersion = runtime.Version()
stats.GoRoutines = runtime.NumGoroutine()
stats.Version = metadata.Version
stats.Config = cfg
runtime.ReadMemStats(&mem)
stats.Memory = mem
return stats
}
// DebugStatsPage generates an HTML debug stats page
func DebugStatsPage(stats SystemStats) string {
const debugTemplate = `
<!DOCTYPE html>
<html>
<head>
<title>Moonshark</title>
<style>
body {
font-family: sans-serif;
max-width: 900px;
margin: 0 auto;
background-color: #333;
color: white;
}
h1 {
padding: 1rem;
background-color: #4F5B93;
box-shadow: 0 2px 4px 0px rgba(0, 0, 0, 0.2);
margin-top: 0;
}
h2 { margin-top: 0; margin-bottom: 0.5rem; }
table { width: 100%; border-collapse: collapse; }
th { width: 1%; white-space: nowrap; border-right: 1px solid rgba(0, 0, 0, 0.1); }
th, td { text-align: left; padding: 0.75rem 0.5rem; border-bottom: 1px solid #ddd; }
tr:last-child th, tr:last-child td { border-bottom: none; }
table tr:nth-child(even), tbody tr:nth-child(even) { background-color: rgba(0, 0, 0, 0.1); }
.card {
background: #F2F2F2;
color: #333;
border-radius: 4px;
margin-bottom: 1rem;
box-shadow: 0 2px 4px 0px rgba(0, 0, 0, 0.2);
}
.timestamp { color: #999; font-size: 0.9em; margin-bottom: 1rem; }
.section { margin-bottom: 30px; }
</style>
</head>
<body>
<h1>Moonshark</h1>
<div class="timestamp">Generated at: {{.Timestamp.Format "2006-01-02 15:04:05"}}</div>
<div class="section">
<h2>Server</h2>
<div class="card">
<table>
<tr><th>Version</th><td>{{.Version}}</td></tr>
</table>
</div>
</div>
<div class="section">
<h2>System</h2>
<div class="card">
<table>
<tr><th>Go Version</th><td>{{.GoVersion}}</td></tr>
<tr><th>Goroutines</th><td>{{.GoRoutines}}</td></tr>
</table>
</div>
</div>
<div class="section">
<h2>Memory</h2>
<div class="card">
<table>
<tr><th>Allocated</th><td>{{ByteCount .Memory.Alloc}}</td></tr>
<tr><th>Total Allocated</th><td>{{ByteCount .Memory.TotalAlloc}}</td></tr>
<tr><th>System Memory</th><td>{{ByteCount .Memory.Sys}}</td></tr>
<tr><th>GC Cycles</th><td>{{.Memory.NumGC}}</td></tr>
</table>
</div>
</div>
<div class="section">
<h2>Sessions</h2>
<div class="card">
<table>
<tr><th>Active Sessions</th><td>{{index .Components.SessionStats "entries"}}</td></tr>
<tr><th>Cache Size</th><td>{{ByteCount (index .Components.SessionStats "bytes")}}</td></tr>
<tr><th>Max Cache Size</th><td>{{ByteCount (index .Components.SessionStats "max_bytes")}}</td></tr>
<tr><th>Cache Gets</th><td>{{index .Components.SessionStats "gets"}}</td></tr>
<tr><th>Cache Sets</th><td>{{index .Components.SessionStats "sets"}}</td></tr>
<tr><th>Cache Misses</th><td>{{index .Components.SessionStats "misses"}}</td></tr>
</table>
</div>
</div>
<div class="section">
<h2>LuaRunner</h2>
<div class="card">
<table>
<tr><th>Interpreter</th><td>LuaJIT 2.1 (Lua 5.1)</td></tr>
<tr><th>Active Routes</th><td>{{.Components.RouteCount}}</td></tr>
<tr><th>Bytecode Size</th><td>{{ByteCount .Components.BytecodeBytes}}</td></tr>
<tr><th>Loaded Modules</th><td>{{.Components.ModuleCount}}</td></tr>
<tr><th>State Pool Size</th><td>{{.Config.Runner.PoolSize}}</td></tr>
</table>
</div>
</div>
<div class="section">
<h2>Config</h2>
<div class="card">
<table>
<tr><th>Port</th><td>{{.Config.Server.Port}}</td></tr>
<tr><th>Pool Size</th><td>{{.Config.Runner.PoolSize}}</td></tr>
<tr><th>Debug Mode</th><td>{{.Config.Server.Debug}}</td></tr>
<tr><th>Log Level</th><td>{{.Config.Server.LogLevel}}</td></tr>
<tr><th>HTTP Logging</th><td>{{.Config.Server.HTTPLogging}}</td></tr>
<tr>
<th>Directories</th>
<td>
<div>Routes: {{.Config.Dirs.Routes}}</div>
<div>Static: {{.Config.Dirs.Static}}</div>
<div>Override: {{.Config.Dirs.Override}}</div>
<div>Libs: {{range .Config.Dirs.Libs}}{{.}}, {{end}}</div>
</td>
</tr>
</table>
</div>
</div>
</body>
</html>
`
// Create a template function map
funcMap := template.FuncMap{
"ByteCount": func(b any) string {
var bytes uint64
switch v := b.(type) {
case uint64:
bytes = v
case int64:
bytes = uint64(v)
case int:
bytes = uint64(v)
default:
return fmt.Sprintf("%T: %v", b, b)
}
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := uint64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
},
}
// Parse the template
tmpl, err := template.New("debug").Funcs(funcMap).Parse(debugTemplate)
if err != nil {
return fmt.Sprintf("Error parsing template: %v", err)
}
// Execute the template
var output strings.Builder
if err := tmpl.Execute(&output, stats); err != nil {
return fmt.Sprintf("Error executing template: %v", err)
}
return output.String()
}

View File

@ -1,257 +0,0 @@
package utils
import (
"math/rand"
"os"
"path/filepath"
)
// ErrorPageConfig holds configuration for generating error pages
type ErrorPageConfig struct {
OverrideDir string // Directory where override templates are stored
DebugMode bool // Whether to show debug information
}
// ErrorType represents HTTP error types
type ErrorType int
const (
ErrorTypeNotFound ErrorType = 404
ErrorTypeMethodNotAllowed ErrorType = 405
ErrorTypeInternalError ErrorType = 500
ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type
)
// ErrorPage generates an HTML error page based on the error type
// It first checks for an override file, and if not found, generates a default page
func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg string) string {
// Check for override file
if config.OverrideDir != "" {
var filename string
switch errorType {
case ErrorTypeNotFound:
filename = "404.html"
case ErrorTypeMethodNotAllowed:
filename = "405.html"
case ErrorTypeInternalError:
filename = "500.html"
case ErrorTypeForbidden:
filename = "403.html"
}
if filename != "" {
overridePath := filepath.Join(config.OverrideDir, filename)
if content, err := os.ReadFile(overridePath); err == nil {
return string(content)
}
}
}
// No override found, generate default page
switch errorType {
case ErrorTypeNotFound:
return generateNotFoundHTML(url)
case ErrorTypeMethodNotAllowed:
return generateMethodNotAllowedHTML(url)
case ErrorTypeInternalError:
return generateInternalErrorHTML(config.DebugMode, url, errMsg)
case ErrorTypeForbidden:
return generateForbiddenHTML(config.DebugMode, url, errMsg)
default:
// Fallback to internal error
return generateInternalErrorHTML(config.DebugMode, url, errMsg)
}
}
// NotFoundPage generates a 404 Not Found error page
func NotFoundPage(config ErrorPageConfig, url string) string {
return ErrorPage(config, ErrorTypeNotFound, url, "")
}
// MethodNotAllowedPage generates a 405 Method Not Allowed error page
func MethodNotAllowedPage(config ErrorPageConfig, url string) string {
return ErrorPage(config, ErrorTypeMethodNotAllowed, url, "")
}
// InternalErrorPage generates a 500 Internal Server Error page
func InternalErrorPage(config ErrorPageConfig, url string, errMsg string) string {
return ErrorPage(config, ErrorTypeInternalError, url, errMsg)
}
// ForbiddenPage generates a 403 Forbidden error page
func ForbiddenPage(config ErrorPageConfig, url string, errMsg string) string {
return ErrorPage(config, ErrorTypeForbidden, url, errMsg)
}
// generateInternalErrorHTML creates a 500 Internal Server Error page
func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string {
errorMessages := []string{
"Oops! Something went wrong",
"Oh no! The server choked",
"Well, this is embarrassing...",
"Houston, we have a problem",
"Gremlins in the system",
"The server is taking a coffee break",
"Moonshark encountered a lunar eclipse",
"Our code monkeys are working on it",
"The server is feeling under the weather",
"500 Brain Not Found",
}
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg)
}
// generateForbiddenHTML creates a 403 Forbidden error page
func generateForbiddenHTML(debugMode bool, url string, errMsg string) string {
errorMessages := []string{
"Access denied",
"You shall not pass",
"This area is off-limits",
"Security check failed",
"Invalid security token",
"Request blocked for security reasons",
"Permission denied",
"Security violation detected",
"This request was rejected",
"Security first, access second",
}
defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
if errMsg == "" {
errMsg = defaultMsg
}
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg)
}
// generateNotFoundHTML creates a 404 Not Found error page
func generateNotFoundHTML(url string) string {
errorMessages := []string{
"Nothing to see here",
"This page is on vacation",
"The page is missing in action",
"This page has left the building",
"This page is in another castle",
"Sorry, we can't find that",
"The page you're looking for doesn't exist",
"Lost in space",
"That's a 404",
"Page not found",
}
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
return generateErrorHTML("404", randomMessage, "Page Not Found", false, url)
}
// generateMethodNotAllowedHTML creates a 405 Method Not Allowed error page
func generateMethodNotAllowedHTML(url string) string {
errorMessages := []string{
"That's not how this works",
"Method not allowed",
"Wrong way!",
"This method is not supported",
"You can't do that here",
"Sorry, wrong door",
"That method won't work here",
"Try a different approach",
"Access denied for this method",
"Method mismatch",
}
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
return generateErrorHTML("405", randomMessage, "Method Not Allowed", false, url)
}
// generateErrorHTML creates the common HTML structure for error pages
func generateErrorHTML(errorCode, mainMessage, subMessage string, showDebugInfo bool, codeContent string) string {
errorHTML := `<!doctype html>
<html>
<head>
<title>` + errorCode + `</title>
<style>
:root {
--bg-color: #2d2e2d;
--bg-gradient: linear-gradient(to bottom, #2d2e2d 0%, #000 100%);
--text-color: white;
--code-bg: rgba(255, 255, 255, 0.1);
}
@media (prefers-color-scheme: light) {
:root {
--bg-color: #f5f5f5;
--bg-gradient: linear-gradient(to bottom, #f5f5f5 0%, #ddd 100%);
--text-color: #333;
--code-bg: rgba(0, 0, 0, 0.1);
}
}
body {
font-family: sans-serif;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 100vh;
margin: 0;
background-color: var(--bg-color);
color: var(--text-color);
background: var(--bg-gradient);
}
h1 {
font-size: 4rem;
margin: 0;
padding: 0;
}
p {
font-size: 1.5rem;
margin: 0.5rem 0;
padding: 0;
}
.sub-message {
font-size: 1.2rem;
margin-bottom: 1rem;
opacity: 0.8;
}
code {
display: inline-block;
font-size: 1rem;
font-family: monospace;
background-color: var(--code-bg);
padding: 0.25em 0.5em;
border-radius: 0.25em;
margin-top: 1rem;
max-width: 90vw;
overflow-wrap: break-word;
word-break: break-all;
}
</style>
</head>
<body>
<div>
<h1>` + errorCode + `</h1>
<p>` + mainMessage + `</p>
<div class="sub-message">` + subMessage + `</div>`
if codeContent != "" {
errorHTML += `
<code>` + codeContent + `</code>`
}
// Add a note for debug mode
if showDebugInfo {
errorHTML += `
<p style="font-size: 0.9rem; margin-top: 1rem;">
An error occurred while processing your request.<br>
Please check the server logs for details.
</p>`
}
errorHTML += `
</div>
</body>
</html>`
return errorHTML
}

View File

@ -1,352 +0,0 @@
package logger
import (
"fmt"
"io"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
// ANSI color codes
const (
colorReset = "\033[0m"
colorRed = "\033[31m"
colorGreen = "\033[32m"
colorYellow = "\033[33m"
colorBlue = "\033[34m"
colorPurple = "\033[35m"
colorCyan = "\033[36m"
colorGray = "\033[90m"
)
// Log types
const (
TypeNone = iota
TypeDebug
TypeInfo
TypeWarning
TypeError
TypeServer
TypeFatal
)
// Type properties
var typeProps = map[int]struct {
tag string
color string
}{
TypeDebug: {"D", colorCyan},
TypeInfo: {"I", colorBlue},
TypeWarning: {"W", colorYellow},
TypeError: {"E", colorRed},
TypeServer: {"S", colorGreen},
TypeFatal: {"F", colorPurple},
}
const timeFormat = "15:04:05"
var (
globalLogger *Logger
globalLoggerOnce sync.Once
)
type Logger struct {
writer io.Writer
useColors bool
timeFormat string
showTimestamp bool
mu sync.Mutex
debugMode atomic.Bool
}
func GetLogger() *Logger {
globalLoggerOnce.Do(func() {
globalLogger = newLogger(true, true)
})
return globalLogger
}
func InitGlobalLogger(useColors bool, showTimestamp bool) {
globalLogger = newLogger(useColors, showTimestamp)
}
func newLogger(useColors bool, showTimestamp bool) *Logger {
return &Logger{
writer: os.Stdout,
useColors: useColors,
timeFormat: timeFormat,
showTimestamp: showTimestamp,
}
}
func New(useColors bool, showTimestamp bool) *Logger {
return newLogger(useColors, showTimestamp)
}
func (l *Logger) SetOutput(w io.Writer) {
l.mu.Lock()
l.writer = w
l.mu.Unlock()
}
func (l *Logger) TimeFormat() string {
return l.timeFormat
}
func (l *Logger) SetTimeFormat(format string) {
l.mu.Lock()
l.timeFormat = format
l.mu.Unlock()
}
func (l *Logger) EnableTimestamp() {
l.showTimestamp = true
}
func (l *Logger) DisableTimestamp() {
l.showTimestamp = false
}
func (l *Logger) EnableColors() {
l.useColors = true
}
func (l *Logger) DisableColors() {
l.useColors = false
}
func (l *Logger) EnableDebug() {
l.debugMode.Store(true)
}
func (l *Logger) DisableDebug() {
l.debugMode.Store(false)
}
func (l *Logger) IsDebugEnabled() bool {
return l.debugMode.Load()
}
func (l *Logger) applyColor(text, color string) string {
if l.useColors {
return color + text + colorReset
}
return text
}
func stripAnsiColors(s string) string {
result := ""
inEscape := false
for _, c := range s {
if inEscape {
if c == 'm' {
inEscape = false
}
continue
}
if c == '\033' {
inEscape = true
continue
}
result += string(c)
}
return result
}
func (l *Logger) writeMessage(logType int, message string, rawMode bool) {
if rawMode {
l.mu.Lock()
_, _ = fmt.Fprint(l.writer, message+"\n")
l.mu.Unlock()
return
}
parts := []string{}
if l.showTimestamp {
timestamp := time.Now().Format(l.timeFormat)
if l.useColors {
timestamp = l.applyColor(timestamp, colorGray)
}
parts = append(parts, timestamp)
}
if logType != TypeNone {
props := typeProps[logType]
tag := "[" + props.tag + "]"
if l.useColors {
tag = l.applyColor(tag, props.color)
}
parts = append(parts, tag)
}
parts = append(parts, message)
logLine := strings.Join(parts, " ") + "\n"
l.mu.Lock()
_, _ = fmt.Fprint(l.writer, logLine)
if logType == TypeFatal {
if f, ok := l.writer.(*os.File); ok {
_ = f.Sync()
}
}
l.mu.Unlock()
}
func (l *Logger) log(logType int, format string, args ...any) {
// Only filter debug messages
if logType == TypeDebug && !l.debugMode.Load() {
return
}
var message string
if len(args) > 0 {
message = fmt.Sprintf(format, args...)
} else {
message = format
}
l.writeMessage(logType, message, false)
if logType == TypeFatal {
os.Exit(1)
}
}
func (l *Logger) LogRaw(format string, args ...any) {
var message string
if len(args) > 0 {
message = fmt.Sprintf(format, args...)
} else {
message = format
}
if !l.useColors {
message = stripAnsiColors(message)
}
l.writeMessage(TypeInfo, message, true)
}
func (l *Logger) Debug(format string, args ...any) {
l.log(TypeDebug, format, args...)
}
func (l *Logger) Info(format string, args ...any) {
l.log(TypeInfo, format, args...)
}
func (l *Logger) Warning(format string, args ...any) {
l.log(TypeWarning, format, args...)
}
func (l *Logger) Error(format string, args ...any) {
l.log(TypeError, format, args...)
}
func (l *Logger) Fatal(format string, args ...any) {
l.log(TypeFatal, format, args...)
}
func (l *Logger) Server(format string, args ...any) {
l.log(TypeServer, format, args...)
}
func (l *Logger) LogRequest(statusCode int, method, path string, duration time.Duration) {
var statusColor string
switch {
case statusCode < 300:
statusColor = colorGreen
case statusCode < 400:
statusColor = colorCyan
case statusCode < 500:
statusColor = colorYellow
default:
statusColor = colorRed
}
var durationStr string
micros := duration.Microseconds()
if micros < 1000 {
durationStr = fmt.Sprintf("%.0fµs", float64(micros))
} else if micros < 1000000 {
durationStr = fmt.Sprintf("%.1fms", float64(micros)/1000)
} else {
durationStr = fmt.Sprintf("%.2fs", duration.Seconds())
}
message := fmt.Sprintf("%s %s %s %s",
l.applyColor("["+method+"]", colorGray),
l.applyColor(fmt.Sprintf("%d", statusCode), statusColor),
l.applyColor(path, colorGray),
l.applyColor(durationStr, colorGray),
)
l.writeMessage(TypeNone, message, false)
}
// Global functions
func Debug(format string, args ...any) {
GetLogger().Debug(format, args...)
}
func Info(format string, args ...any) {
GetLogger().Info(format, args...)
}
func Warning(format string, args ...any) {
GetLogger().Warning(format, args...)
}
func Error(format string, args ...any) {
GetLogger().Error(format, args...)
}
func Fatal(format string, args ...any) {
GetLogger().Fatal(format, args...)
}
func Server(format string, args ...any) {
GetLogger().Server(format, args...)
}
func LogRaw(format string, args ...any) {
GetLogger().LogRaw(format, args...)
}
func SetOutput(w io.Writer) {
GetLogger().SetOutput(w)
}
func TimeFormat() string {
return GetLogger().TimeFormat()
}
func EnableDebug() {
GetLogger().EnableDebug()
}
func DisableDebug() {
GetLogger().DisableDebug()
}
func IsDebugEnabled() bool {
return GetLogger().IsDebugEnabled()
}
func EnableTimestamp() {
GetLogger().EnableTimestamp()
}
func DisableTimestamp() {
GetLogger().DisableTimestamp()
}
func LogRequest(statusCode int, method, path string, duration time.Duration) {
GetLogger().LogRequest(statusCode, method, path, duration)
}

View File

@ -1,11 +0,0 @@
package metadata
// Version holds the current Moonshark version
const Version = "1.0"
// Build time information
var (
BuildTime = "unknown"
GitCommit = "unknown"
GoVersion = "unknown"
)

197
watcher.go Normal file
View File

@ -0,0 +1,197 @@
package main
import (
"fmt"
"maps"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
type FileWatcher struct {
files map[string]time.Time
dirs map[string]bool // Track watched directories
mu sync.RWMutex
restartCh chan bool
stopCh chan bool
debounceMs int
lastEvent time.Time
pollMs int
}
func NewFileWatcher(debounceMs int) (*FileWatcher, error) {
return &FileWatcher{
files: make(map[string]time.Time),
dirs: make(map[string]bool),
restartCh: make(chan bool, 1),
stopCh: make(chan bool, 1),
debounceMs: debounceMs,
pollMs: 250, // Poll every 250ms
}, nil
}
func (fw *FileWatcher) AddFile(path string) error {
absPath, err := filepath.Abs(path)
if err != nil {
return err
}
info, err := os.Stat(absPath)
if err != nil {
return err
}
fw.mu.Lock()
fw.files[absPath] = info.ModTime()
fw.mu.Unlock()
fmt.Printf("Watching: %s\n", absPath)
return nil
}
func (fw *FileWatcher) AddDirectory(dir string) error {
absDir, err := filepath.Abs(dir)
if err != nil {
return err
}
fw.mu.Lock()
fw.dirs[absDir] = true
fw.mu.Unlock()
return filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".lua") {
return fw.AddFile(path)
}
return nil
})
}
func (fw *FileWatcher) Start() <-chan bool {
go fw.pollLoop()
return fw.restartCh
}
func (fw *FileWatcher) pollLoop() {
ticker := time.NewTicker(time.Duration(fw.pollMs) * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-fw.stopCh:
return
case <-ticker.C:
fw.checkFiles()
fw.scanForNewFiles()
}
}
}
func (fw *FileWatcher) checkFiles() {
fw.mu.RLock()
files := make(map[string]time.Time, len(fw.files))
maps.Copy(files, fw.files)
fw.mu.RUnlock()
changed := false
for path, lastMod := range files {
info, err := os.Stat(path)
if err != nil {
continue
}
if info.ModTime().After(lastMod) {
fw.mu.Lock()
fw.files[path] = info.ModTime()
fw.mu.Unlock()
now := time.Now()
if now.Sub(fw.lastEvent) > time.Duration(fw.debounceMs)*time.Millisecond {
fw.lastEvent = now
// log.Printf("File changed: %s", path)
changed = true
}
}
}
if changed {
select {
case fw.restartCh <- true:
default:
}
}
}
func (fw *FileWatcher) scanForNewFiles() {
fw.mu.RLock()
dirs := make(map[string]bool, len(fw.dirs))
maps.Copy(dirs, fw.dirs)
existingFiles := make(map[string]bool, len(fw.files))
for path := range fw.files {
existingFiles[path] = true
}
fw.mu.RUnlock()
newFilesFound := false
for dir := range dirs {
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".lua") {
absPath, err := filepath.Abs(path)
if err != nil {
return err
}
if !existingFiles[absPath] {
fw.mu.Lock()
fw.files[absPath] = info.ModTime()
fw.mu.Unlock()
fmt.Printf("New file detected: %s\n", absPath)
newFilesFound = true
}
}
return nil
})
if err != nil {
continue
}
}
if newFilesFound {
now := time.Now()
if now.Sub(fw.lastEvent) > time.Duration(fw.debounceMs)*time.Millisecond {
fw.lastEvent = now
select {
case fw.restartCh <- true:
default:
}
}
}
}
func (fw *FileWatcher) Close() error {
select {
case fw.stopCh <- true:
default:
}
return nil
}
func (fw *FileWatcher) DiscoverRequiredFiles(scriptPath string) error {
if err := fw.AddFile(scriptPath); err != nil {
return err
}
scriptDir := filepath.Dir(scriptPath)
return fw.AddDirectory(scriptDir)
}

View File

@ -1,84 +0,0 @@
package watchers
import (
"fmt"
"strings"
"sync"
"Moonshark/router"
"Moonshark/runner"
"Moonshark/utils/color"
"Moonshark/utils/logger"
)
// Global watcher manager instance with explicit creation
var (
globalManager *WatcherManager
globalManagerOnce sync.Once
)
// GetWatcherManager returns the global watcher manager, creating it if needed
func GetWatcherManager() *WatcherManager {
globalManagerOnce.Do(func() {
globalManager = NewWatcherManager(DefaultPollInterval)
})
return globalManager
}
// ShutdownWatcherManager closes the global watcher manager if it exists
func ShutdownWatcherManager() {
if globalManager != nil {
globalManager.Close()
globalManager = nil
}
}
// WatchLuaRouter sets up a watcher for a LuaRouter's routes directory
func WatchLuaRouter(router *router.LuaRouter, runner *runner.Runner, routesDir string) (*DirectoryWatcher, error) {
manager := GetWatcherManager()
config := DirectoryWatcherConfig{
Dir: routesDir,
Callback: router.Refresh,
Recursive: true,
}
watcher, err := manager.WatchDirectory(config)
if err != nil {
return nil, fmt.Errorf("failed to watch directory: %w", err)
}
logger.Info("Started watching Lua routes! %s", color.Apply(routesDir, color.Yellow))
return watcher, nil
}
// WatchLuaModules sets up watchers for Lua module directories
func WatchLuaModules(luaRunner *runner.Runner, libDirs []string) ([]*DirectoryWatcher, error) {
manager := GetWatcherManager()
watchers := make([]*DirectoryWatcher, 0, len(libDirs))
for _, dir := range libDirs {
config := DirectoryWatcherConfig{
Dir: dir,
EnhancedCallback: func(changes []FileChange) error {
for _, change := range changes {
if !change.IsDeleted && strings.HasSuffix(change.Path, ".lua") {
luaRunner.NotifyFileChanged(change.Path)
}
}
return nil
},
Recursive: true,
}
watcher, err := manager.WatchDirectory(config)
if err != nil {
// Error handling...
}
watchers = append(watchers, watcher)
logger.Info("Started watching Lua modules! %s", color.Apply(dir, color.Yellow))
}
return watchers, nil
}

View File

@ -1,212 +0,0 @@
package watchers
import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
"Moonshark/utils/logger"
)
// Default debounce time between detected change and callback
const defaultDebounceTime = 300 * time.Millisecond
// FileChange represents a detected file change
type FileChange struct {
Path string
IsNew bool
IsDeleted bool
}
// FileInfo stores minimal metadata about a file for change detection
type FileInfo struct {
ModTime time.Time
}
// DirectoryWatcher watches a specific directory for changes
type DirectoryWatcher struct {
// Directory to watch
dir string
// Map of file paths to their metadata
files map[string]FileInfo
filesMu sync.RWMutex
// Track changed files during a check cycle
changedFiles []FileChange
// Enhanced callback that receives changes (optional)
enhancedCallback func([]FileChange) error
// Configuration
callback func() error
debounceTime time.Duration
recursive bool
// Debounce timer
debounceTimer *time.Timer
debouncing bool
debounceMu sync.Mutex
// Error tracking
consecutiveErrors int
lastError error
}
// DirectoryWatcherConfig contains configuration for a directory watcher
type DirectoryWatcherConfig struct {
Dir string // Directory to watch
Callback func() error // Callback function to call when changes are detected
DebounceTime time.Duration // Debounce time (0 means use default)
Recursive bool // Recursive watching (watch subdirectories)
EnhancedCallback func([]FileChange) error // Enhanced callback that receives file changes
}
// scanDirectory builds the initial file list
func (w *DirectoryWatcher) scanDirectory() error {
w.filesMu.Lock()
defer w.filesMu.Unlock()
w.files = make(map[string]FileInfo)
return filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Skip files with errors
}
if !w.recursive && info.IsDir() && path != w.dir {
return filepath.SkipDir
}
w.files[path] = FileInfo{
ModTime: info.ModTime(),
}
return nil
})
}
// checkForChanges detects if any files have been added, modified, or deleted
func (w *DirectoryWatcher) checkForChanges() (bool, error) {
w.filesMu.RLock()
prevFileCount := len(w.files)
w.filesMu.RUnlock()
newFiles := make(map[string]FileInfo)
changed := false
w.changedFiles = nil // Reset changed files list
err := filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error {
// Skip errors
if err != nil {
return nil
}
if !w.recursive && info.IsDir() && path != w.dir {
return filepath.SkipDir
}
currentInfo := FileInfo{
ModTime: info.ModTime(),
}
newFiles[path] = currentInfo
w.filesMu.RLock()
prevInfo, exists := w.files[path]
w.filesMu.RUnlock()
if !exists {
changed = true
w.changedFiles = append(w.changedFiles, FileChange{
Path: path,
IsNew: true,
})
w.logDebug("File added: %s", path)
} else if currentInfo.ModTime != prevInfo.ModTime {
changed = true
w.changedFiles = append(w.changedFiles, FileChange{
Path: path,
})
w.logDebug("File changed: %s", path)
}
return nil
})
// Only check for deleted files if needed
if err == nil && (!changed && len(newFiles) != prevFileCount) {
w.filesMu.RLock()
for path := range w.files {
if _, exists := newFiles[path]; !exists {
changed = true
w.changedFiles = append(w.changedFiles, FileChange{
Path: path,
IsDeleted: true,
})
w.logDebug("File deleted: %s", path)
break // We already know changes happened
}
}
w.filesMu.RUnlock()
}
if changed {
w.filesMu.Lock()
w.files = newFiles
w.filesMu.Unlock()
}
return changed, err
}
// notifyChange triggers the callback with debouncing
func (w *DirectoryWatcher) notifyChange() {
w.debounceMu.Lock()
defer w.debounceMu.Unlock()
if w.debouncing {
if w.debounceTimer != nil {
w.debounceTimer.Stop()
}
} else {
w.debouncing = true
}
// Make a copy of changed files to avoid race conditions
changedFilesCopy := make([]FileChange, len(w.changedFiles))
copy(changedFilesCopy, w.changedFiles)
w.debounceTimer = time.AfterFunc(w.debounceTime, func() {
var err error
if w.enhancedCallback != nil {
err = w.enhancedCallback(changedFilesCopy)
} else if w.callback != nil {
err = w.callback()
}
if err != nil {
w.logError("Callback error: %v", err)
}
w.debounceMu.Lock()
w.debouncing = false
w.debounceMu.Unlock()
})
}
// logDebug logs a debug message with the watcher's directory prefix
func (w *DirectoryWatcher) logDebug(format string, args ...any) {
logger.Debug("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
}
// logError logs an error message with the watcher's directory prefix
func (w *DirectoryWatcher) logError(format string, args ...any) {
logger.Error("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
}
// GetDir gets the DirectoryWatcher's current directory
func (w *DirectoryWatcher) GetDir() string {
return w.dir
}

View File

@ -1,154 +0,0 @@
package watchers
import (
"errors"
"sync"
"time"
"Moonshark/utils/logger"
)
// DefaultPollInterval is the time between directory checks
const DefaultPollInterval = 1 * time.Second
// Common errors
var (
ErrDirectoryNotFound = errors.New("directory not found")
ErrAlreadyWatching = errors.New("already watching directory")
)
// WatcherManager coordinates file watching across multiple directories
type WatcherManager struct {
watchers map[string]*DirectoryWatcher
mu sync.RWMutex
done chan struct{}
ticker *time.Ticker
interval time.Duration
wg sync.WaitGroup
}
// NewWatcherManager creates a new watcher manager with a specified poll interval
func NewWatcherManager(pollInterval time.Duration) *WatcherManager {
if pollInterval <= 0 {
pollInterval = DefaultPollInterval
}
manager := &WatcherManager{
watchers: make(map[string]*DirectoryWatcher),
done: make(chan struct{}),
interval: pollInterval,
}
manager.ticker = time.NewTicker(pollInterval)
manager.wg.Add(1)
go manager.pollLoop()
return manager
}
// Close stops all watchers and the manager
func (m *WatcherManager) Close() error {
close(m.done)
if m.ticker != nil {
m.ticker.Stop()
}
m.wg.Wait()
return nil
}
// WatchDirectory adds a new directory to watch and returns the watcher
func (m *WatcherManager) WatchDirectory(config DirectoryWatcherConfig) (*DirectoryWatcher, error) {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.watchers[config.Dir]; exists {
return nil, ErrAlreadyWatching
}
if config.DebounceTime == 0 {
config.DebounceTime = defaultDebounceTime
}
watcher := &DirectoryWatcher{
dir: config.Dir,
files: make(map[string]FileInfo),
callback: config.Callback,
enhancedCallback: config.EnhancedCallback,
debounceTime: config.DebounceTime,
recursive: config.Recursive,
}
// Perform initial scan
if err := watcher.scanDirectory(); err != nil {
return nil, err
}
m.watchers[config.Dir] = watcher
logger.Debug("WatcherManager added watcher for %s", config.Dir)
return watcher, nil
}
// UnwatchDirectory removes a directory from being watched
func (m *WatcherManager) UnwatchDirectory(dir string) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.watchers[dir]; !exists {
return ErrDirectoryNotFound
}
delete(m.watchers, dir)
logger.Debug("WatcherManager removed watcher for %s", dir)
return nil
}
// pollLoop is the main polling loop that checks all watched directories
func (m *WatcherManager) pollLoop() {
defer m.wg.Done()
for {
select {
case <-m.ticker.C:
m.checkAllDirectories()
case <-m.done:
return
}
}
}
// checkAllDirectories polls all registered directories for changes
func (m *WatcherManager) checkAllDirectories() {
m.mu.RLock()
watchers := make([]*DirectoryWatcher, 0, len(m.watchers))
for _, w := range m.watchers {
watchers = append(watchers, w)
}
m.mu.RUnlock()
changesDetected := 0
for _, watcher := range watchers {
if watcher.consecutiveErrors > 3 {
if watcher.consecutiveErrors == 4 {
logger.Error("Temporarily skipping directory %s due to errors: %v",
watcher.dir, watcher.lastError)
watcher.consecutiveErrors++
}
continue
}
changed, err := watcher.checkForChanges()
if err != nil {
logger.Error("Error checking directory %s: %v", watcher.dir, err)
continue
}
if changed {
changesDetected++
watcher.notifyChange()
}
}
}