Compare commits
No commits in common. "55e6b5478940558d5094e13be9254cce0ffabfcb" and "31a3625873f27c893c8abd634137fa9d721bc32b" have entirely different histories.
55e6b54789 ... 31a3625873
.gitignore (vendored, 14 changes)
@@ -22,14 +22,8 @@ luajit/.git
# Go workspace file
go.work

# Claude workspace files
.claude
CLAUDE.md

# Test directories and files
/*.lua
test_fs_dir
public
test
test.db
build
/config.lua
test/
/init.lua
/moonshark
LICENSE (19 changes)
@@ -1,2 +1,19 @@
## Sharkk Open License

### Version 1.0, March 2025

Copyright (c) Sharkk, Skylear Johnson
DO NOT USE THIS SOFTWARE

Hey there, code surfer! You're free to ride this wave—use, modify, and share this software however you like, as long as you stick to these chill but important rules:

1. **Share Your Changes**: If you tweak, remix, or build on this software, you’ve gotta share your work with the world under the same license. That means making your modified source code available in a reasonable way—like linking to a public repo. Keep the stoke alive!

2. **Keep This License**: Whenever you pass this software along (whether you’ve changed it or not), you need to include this license in full. No sneaky restrictions that limit the freedom to ride the digital waves.

3. **Give Credit Where It’s Due**: Show some love to the original author(s) by keeping the copyright notice and, if possible, linking back to the original source. Good vibes and respect go a long way.

4. **Make It Your Own**: If you add your own original code or features, you’re totally free to monetize those additions. Sell it, license it, or turn it into the next big thing—just keep the original parts open for everyone.

5. **No Guarantees**: This software comes "as is." No promises, no warranties—just pure, unfiltered code. If things go sideways, you’re riding that wave at your own risk. The authors aren’t responsible for any wipeouts.

By using, modifying, or sharing this software, you’re agreeing to these terms. Keep it open, keep it flowing, and most of all—have fun!
README.md
@@ -1,5 +1,8 @@
# Moonshark

```bash
go build -trimpath -ldflags="-s -w" -o build/moonshark .
git submodule update --init --recursive
git submodule update --remote --recursive

go build -trimpath -ldflags="-s -w" -o moonshark .
```
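The README hunk above swaps the build output from `build/moonshark` to a plain `moonshark` binary. As a rough, non-authoritative sketch, the commands below assemble those snippets into one checkout-and-build sequence; `<repo-url>` is a placeholder, and the submodule step is an assumption based on the `luajit/.git` ignore entry, not documented behavior.

```bash
# Hypothetical end-to-end build using only the commands shown in the README hunk.
git clone <repo-url> moonshark && cd moonshark      # <repo-url> is a placeholder
git submodule update --init --recursive              # fetch vendored LuaJIT sources
go build -trimpath -ldflags="-s -w" -o moonshark .   # new build command from this diff
```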
build.sh (deleted, 4 lines)
@@ -1,4 +0,0 @@
#!/bin/bash

mkdir -p build
go build -trimpath -ldflags="-s -w" -o build/moonshark .
config-example.lua (new file, 21 lines)
@@ -0,0 +1,21 @@
server = {
    port = 3117,
    debug = false,
    log_level = "info",
    http_logging = false
}

runner = {
    pool_size = 0 -- 0 defaults to GOMAXPROCS
}

dirs = {
    routes = "routes",
    static = "public",
    fs = "fs",
    data = "data",
    override = "override",
    libs = {
        "libs"
    }
}
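The `dirs` table above maps each runtime directory to a path relative to the working directory. A minimal sketch of the layout those defaults imply follows; the directory names come straight from the table, and copying the example to `config.lua` is an assumption based on the default `-config config.lua` flag defined in main.go later in this diff.

```bash
# Sketch: create the directories named in the dirs table above.
mkdir -p routes public fs data override libs
# main.go defaults to -config config.lua, so start from the example file.
cp config-example.lua config.lua
```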
go.mod (41 changes)
@@ -2,39 +2,34 @@ module Moonshark

go 1.24.1

require git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6

require (
	github.com/go-sql-driver/mysql v1.9.3
	git.sharkk.net/Go/LRU v1.0.0
	git.sharkk.net/Sky/LuaJIT-to-Go v0.4.1
	github.com/VictoriaMetrics/fastcache v1.12.4
	github.com/alexedwards/argon2id v1.0.0
	github.com/deneonet/benc v1.1.8
	github.com/goccy/go-json v0.10.5
	github.com/google/uuid v1.6.0
	github.com/jackc/pgx/v5 v5.7.5
	golang.org/x/crypto v0.40.0
	github.com/golang/snappy v1.0.0
	github.com/matoous/go-nanoid/v2 v2.1.0
	github.com/valyala/bytebufferpool v1.0.0
	github.com/valyala/fasthttp v1.62.0
	zombiezen.com/go/sqlite v1.4.2
)

require (
	filippo.io/edwards25519 v1.1.0 // indirect
	github.com/andybalholm/brotli v1.1.1 // indirect
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/dustin/go-humanize v1.0.1 // indirect
	github.com/jackc/pgpassfile v1.0.0 // indirect
	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
	github.com/jackc/puddle/v2 v2.2.2 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/klauspost/compress v1.18.0 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/ncruces/go-strftime v0.1.9 // indirect
	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
	golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect
	golang.org/x/sync v0.16.0 // indirect
	golang.org/x/sys v0.34.0 // indirect
	golang.org/x/text v0.27.0 // indirect
	modernc.org/libc v1.66.3 // indirect
	golang.org/x/crypto v0.38.0 // indirect
	golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 // indirect
	golang.org/x/sys v0.33.0 // indirect
	modernc.org/libc v1.65.8 // indirect
	modernc.org/mathutil v1.7.1 // indirect
	modernc.org/memory v1.11.0 // indirect
	modernc.org/sqlite v1.38.0 // indirect
)

require (
	github.com/andybalholm/brotli v1.2.0 // indirect
	github.com/klauspost/compress v1.18.0 // indirect
	github.com/valyala/bytebufferpool v1.0.0 // indirect
	github.com/valyala/fasthttp v1.64.0
	modernc.org/sqlite v1.37.1 // indirect
)
go.sum (134 changes)
@@ -1,32 +1,35 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6 h1:XytP9R2fWykv0MXIzxggPx5S/PmTkjyZVvUX2sn4EaU=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.6/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc=
|
||||
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.4.1 h1:CAYt+C6Vgo4JxK876j0ApQ2GDFFvy9FKO0OoZBVD18k=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.4.1/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
|
||||
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=
|
||||
github.com/alexedwards/argon2id v1.0.0/go.mod h1:tYKkqIjzXvZdzPvADMWOEZ+l6+BD6CtBXMj5fnJppiw=
|
||||
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
|
||||
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/deneonet/benc v1.1.8 h1:Qk9diyH0UcnduvCrZ62mBrwUeSZzte4kQxMbclVdhW4=
|
||||
github.com/deneonet/benc v1.1.8/go.mod h1:UCfkM5Od0B2huwv/ZItvtUb7QnALFt9YXtX8NXX4Lts=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs=
|
||||
github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
@ -35,48 +38,79 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.64.0 h1:QBygLLQmiAyiXuRhthf0tuRkqAFcrC42dckN2S+N3og=
|
||||
github.com/valyala/fasthttp v1.64.0/go.mod h1:dGmFxwkWXSK0NbOSJuF7AMVzU+lkHz0wQVvVITv2UQA=
|
||||
github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0=
|
||||
github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
|
||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
|
||||
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
||||
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
|
||||
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||
modernc.org/libc v1.65.8 h1:7PXRJai0TXZ8uNA3srsmYzmTyrLoHImV5QxHeni108Q=
|
||||
modernc.org/libc v1.65.8/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
@ -85,8 +119,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
|
||||
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
http/server.go (new file, 344 lines)
@@ -0,0 +1,344 @@
package http

import (
	"bytes"
	"context"
	"sync"
	"time"

	"Moonshark/router"
	"Moonshark/runner"
	"Moonshark/sessions"
	"Moonshark/utils"
	"Moonshark/utils/color"
	"Moonshark/utils/config"
	"Moonshark/utils/logger"
	"Moonshark/utils/metadata"

	"github.com/valyala/fasthttp"
)

var (
	//methodGET = []byte("GET")
	methodPOST = []byte("POST")
	methodPUT = []byte("PUT")
	methodPATCH = []byte("PATCH")
	debugPath = []byte("/debug/stats")
)

type Server struct {
	luaRouter *router.LuaRouter
	staticHandler fasthttp.RequestHandler
	staticFS *fasthttp.FS
	luaRunner *runner.Runner
	fasthttpServer *fasthttp.Server
	loggingEnabled bool
	debugMode bool
	config *config.Config
	sessionManager *sessions.SessionManager
	errorConfig utils.ErrorPageConfig
	ctxPool sync.Pool
	paramsPool sync.Pool
	staticDir string
	staticPrefix string
	staticPrefixBytes []byte

	// Cached error pages
	cached404 []byte
	cached500 []byte
	errorCacheMu sync.RWMutex
}

func New(luaRouter *router.LuaRouter, staticDir string,
	runner *runner.Runner, loggingEnabled bool, debugMode bool,
	overrideDir string, config *config.Config) *Server {

	staticPrefix := config.Server.StaticPrefix
	if staticPrefix == "" {
		staticPrefix = "/static/"
	}

	if staticPrefix[0] != '/' {
		staticPrefix = "/" + staticPrefix
	}
	if staticPrefix[len(staticPrefix)-1] != '/' {
		staticPrefix = staticPrefix + "/"
	}

	s := &Server{
		luaRouter: luaRouter,
		luaRunner: runner,
		loggingEnabled: loggingEnabled,
		debugMode: debugMode,
		config: config,
		sessionManager: sessions.GlobalSessionManager,
		staticDir: staticDir,
		staticPrefix: staticPrefix,
		staticPrefixBytes: []byte(staticPrefix),
		errorConfig: utils.ErrorPageConfig{
			OverrideDir: overrideDir,
			DebugMode: debugMode,
		},
		ctxPool: sync.Pool{
			New: func() any {
				return make(map[string]any, 6)
			},
		},
		paramsPool: sync.Pool{
			New: func() any {
				return make(map[string]any, 4)
			},
		},
	}

	// Pre-cache error pages
	s.cached404 = []byte(utils.NotFoundPage(s.errorConfig, ""))
	s.cached500 = []byte(utils.InternalErrorPage(s.errorConfig, "", "Internal Server Error"))

	// Setup static file serving
	if staticDir != "" {
		s.staticFS = &fasthttp.FS{
			Root: staticDir,
			IndexNames: []string{"index.html"},
			GenerateIndexPages: false,
			AcceptByteRange: true,
			Compress: true,
			CompressedFileSuffix: ".gz",
			CompressBrotli: true,
			CompressZstd: true,
			PathRewrite: fasthttp.NewPathPrefixStripper(len(staticPrefix) - 1),
		}
		s.staticHandler = s.staticFS.NewRequestHandler()
	}

	s.fasthttpServer = &fasthttp.Server{
		Handler: s.handleRequest,
		Name: "Moonshark/" + metadata.Version,
		ReadTimeout: 30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout: 120 * time.Second,
		MaxRequestBodySize: 16 << 20,
		TCPKeepalive: true,
		TCPKeepalivePeriod: 60 * time.Second,
		ReduceMemoryUsage: true,
		DisablePreParseMultipartForm: true,
		DisableHeaderNamesNormalizing: true,
		NoDefaultServerHeader: true,
		StreamRequestBody: true,
	}

	return s
}

func (s *Server) ListenAndServe(addr string) error {
	logger.Info("Catch the swell at %s", color.Apply("http://localhost"+addr, color.Cyan))
	return s.fasthttpServer.ListenAndServe(addr)
}

func (s *Server) Shutdown(ctx context.Context) error {
	return s.fasthttpServer.ShutdownWithContext(ctx)
}

func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
	start := time.Now()
	methodBytes := ctx.Method()
	pathBytes := ctx.Path()

	// Fast path for debug stats
	if s.debugMode && bytes.Equal(pathBytes, debugPath) {
		s.handleDebugStats(ctx)
		if s.loggingEnabled {
			logger.LogRequest(ctx.Response.StatusCode(), string(methodBytes), string(pathBytes), time.Since(start))
		}
		return
	}

	// Fast path for static files
	if s.staticHandler != nil && bytes.HasPrefix(pathBytes, s.staticPrefixBytes) {
		s.staticHandler(ctx)
		if s.loggingEnabled {
			logger.LogRequest(ctx.Response.StatusCode(), string(methodBytes), string(pathBytes), time.Since(start))
		}
		return
	}

	// Lua route lookup - only allocate params if found
	bytecode, scriptPath, routeErr, params, found := s.luaRouter.GetRouteInfo(methodBytes, pathBytes)

	if found {
		if len(bytecode) == 0 || routeErr != nil {
			s.sendError(ctx, fasthttp.StatusInternalServerError, pathBytes, routeErr)
		} else {
			s.handleLuaRoute(ctx, bytecode, scriptPath, params, methodBytes, pathBytes)
		}
	} else {
		s.send404(ctx, pathBytes)
	}

	if s.loggingEnabled {
		logger.LogRequest(ctx.Response.StatusCode(), string(methodBytes), string(pathBytes), time.Since(start))
	}
}

func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string,
	params *router.Params, methodBytes, pathBytes []byte) {

	luaCtx := runner.NewHTTPContext(ctx)
	defer luaCtx.Release()

	if runner.GetGlobalEnvManager() != nil {
		luaCtx.Set("env", runner.GetGlobalEnvManager().GetAll())
	}

	sessionMap := s.ctxPool.Get().(map[string]any)
	defer func() {
		for k := range sessionMap {
			delete(sessionMap, k)
		}
		s.ctxPool.Put(sessionMap)
	}()

	session := s.sessionManager.GetSessionFromRequest(ctx)
	sessionMap["id"] = session.ID

	// Only get session data if not empty
	if !session.IsEmpty() {
		sessionMap["data"] = session.GetAll()
	} else {
		sessionMap["data"] = emptyMap
	}

	// Set basic context
	luaCtx.Set("method", string(methodBytes))
	luaCtx.Set("path", string(pathBytes))
	luaCtx.Set("host", string(ctx.Host()))
	luaCtx.Set("session", sessionMap)

	// Add headers to context
	headers := make(map[string]any)
	ctx.Request.Header.VisitAll(func(key, value []byte) {
		headers[string(key)] = string(value)
	})
	luaCtx.Set("headers", headers)

	// Handle params
	if params != nil && params.Count > 0 {
		paramMap := s.paramsPool.Get().(map[string]any)
		for i := range params.Count {
			paramMap[params.Keys[i]] = params.Values[i]
		}
		luaCtx.Set("params", paramMap)
		defer func() {
			for k := range paramMap {
				delete(paramMap, k)
			}
			s.paramsPool.Put(paramMap)
		}()
	} else {
		luaCtx.Set("params", emptyMap)
	}

	// Parse form data for POST/PUT/PATCH
	if bytes.Equal(methodBytes, methodPOST) ||
		bytes.Equal(methodBytes, methodPUT) ||
		bytes.Equal(methodBytes, methodPATCH) {
		if formData, err := ParseForm(ctx); err == nil {
			luaCtx.Set("form", formData)
		} else {
			if s.debugMode {
				logger.Warning("Error parsing form: %v", err)
			}
			luaCtx.Set("form", emptyMap)
		}
	} else {
		luaCtx.Set("form", emptyMap)
	}

	response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
	if err != nil {
		logger.Error("Lua execution error: %v", err)
		s.sendError(ctx, fasthttp.StatusInternalServerError, pathBytes, err)
		return
	}

	// Handle session updates
	if len(response.SessionData) > 0 {
		if _, clearAll := response.SessionData["__clear_all"]; clearAll {
			session.Clear()
			delete(response.SessionData, "__clear_all")
		}

		for k, v := range response.SessionData {
			if v == "__SESSION_DELETE_MARKER__" {
				session.Delete(k)
			} else {
				session.Set(k, v)
			}
		}
	}

	s.sessionManager.ApplySessionCookie(ctx, session)
	runner.ApplyResponse(response, ctx)
	runner.ReleaseResponse(response)
}

func (s *Server) send404(ctx *fasthttp.RequestCtx, pathBytes []byte) {
	ctx.SetContentType("text/html; charset=utf-8")
	ctx.SetStatusCode(fasthttp.StatusNotFound)

	// Use cached 404 for common case
	if len(pathBytes) == 1 && pathBytes[0] == '/' {
		s.errorCacheMu.RLock()
		ctx.SetBody(s.cached404)
		s.errorCacheMu.RUnlock()
	} else {
		ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, string(pathBytes))))
	}
}

func (s *Server) sendError(ctx *fasthttp.RequestCtx, status int, pathBytes []byte, err error) {
	ctx.SetContentType("text/html; charset=utf-8")
	ctx.SetStatusCode(status)

	if err == nil {
		s.errorCacheMu.RLock()
		ctx.SetBody(s.cached500)
		s.errorCacheMu.RUnlock()
	} else {
		ctx.SetBody([]byte(utils.InternalErrorPage(s.errorConfig, string(pathBytes), err.Error())))
	}
}

func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
	stats := utils.CollectSystemStats(s.config)
	routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
	stats.Components = utils.ComponentStats{
		RouteCount: routeCount,
		BytecodeBytes: bytecodeBytes,
		SessionStats: sessions.GlobalSessionManager.GetCacheStats(),
	}
	ctx.SetContentType("text/html; charset=utf-8")
	ctx.SetStatusCode(fasthttp.StatusOK)
	ctx.SetBody([]byte(utils.DebugStatsPage(stats)))
}

// SetStaticCaching enables/disables static file caching
func (s *Server) SetStaticCaching(duration time.Duration) {
	if s.staticFS != nil {
		s.staticFS.CacheDuration = duration
		s.staticHandler = s.staticFS.NewRequestHandler()
	}
}

// GetStaticPrefix returns the URL prefix for static files
func (s *Server) GetStaticPrefix() string {
	return s.staticPrefix
}

// UpdateErrorCache refreshes cached error pages
func (s *Server) UpdateErrorCache() {
	s.errorCacheMu.Lock()
	s.cached404 = []byte(utils.NotFoundPage(s.errorConfig, ""))
	s.cached500 = []byte(utils.InternalErrorPage(s.errorConfig, "", "Internal Server Error"))
	s.errorCacheMu.Unlock()
}
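handleRequest short-circuits two paths before any Lua route lookup: `/debug/stats` when debug mode is on, and anything under the static prefix (default `/static/`). A quick way to poke those branches is sketched below; it assumes the example config's port 3117, `debug = true`, and a file at public/index.html, all of which are assumptions rather than guaranteed setup.

```bash
# Assumes the server is running with debug = true and the default static prefix.
curl -i http://localhost:3117/debug/stats        # HTML stats page from handleDebugStats
curl -i http://localhost:3117/static/index.html  # served by the fasthttp.FS handler
curl -i http://localhost:3117/no-such-route      # no matching Lua route, handled by send404
```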
http/utils.go (new file, 147 lines)
@@ -0,0 +1,147 @@
package http

import (
	"crypto/rand"
	"encoding/base64"
	"mime/multipart"
	"strings"
	"sync"

	"github.com/valyala/fasthttp"
)

var (
	emptyMap = make(map[string]any)
	formDataPool = sync.Pool{
		New: func() any {
			return make(map[string]any, 16)
		},
	}
)

func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
	args := ctx.QueryArgs()
	if args.Len() == 0 {
		return emptyMap
	}

	queryMap := make(map[string]any, args.Len())
	args.VisitAll(func(key, value []byte) {
		k := string(key)
		v := string(value)
		appendValue(queryMap, k, v)
	})
	return queryMap
}

func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
	if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") {
		return parseMultipartForm(ctx)
	}

	args := ctx.PostArgs()
	if args.Len() == 0 {
		return emptyMap, nil
	}

	formData := formDataPool.Get().(map[string]any)
	for k := range formData {
		delete(formData, k)
	}

	args.VisitAll(func(key, value []byte) {
		k := string(key)
		v := string(value)
		appendValue(formData, k, v)
	})
	return formData, nil
}

func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
	form, err := ctx.MultipartForm()
	if err != nil {
		return nil, err
	}

	formData := formDataPool.Get().(map[string]any)
	for k := range formData {
		delete(formData, k)
	}

	for key, values := range form.Value {
		if len(values) == 1 {
			formData[key] = values[0]
		} else if len(values) > 1 {
			formData[key] = values
		}
	}

	if len(form.File) > 0 {
		files := make(map[string]any, len(form.File))
		for fieldName, fileHeaders := range form.File {
			if len(fileHeaders) == 1 {
				files[fieldName] = fileInfoToMap(fileHeaders[0])
			} else {
				fileInfos := make([]map[string]any, len(fileHeaders))
				for i, fh := range fileHeaders {
					fileInfos[i] = fileInfoToMap(fh)
				}
				files[fieldName] = fileInfos
			}
		}
		formData["_files"] = files
	}

	return formData, nil
}

func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
	ct := fh.Header.Get("Content-Type")
	if ct == "" {
		ct = getMimeType(fh.Filename)
	}
	return map[string]any{
		"filename": fh.Filename,
		"size": fh.Size,
		"mimetype": ct,
	}
}

func getMimeType(filename string) string {
	if i := strings.LastIndex(filename, "."); i >= 0 {
		switch filename[i:] {
		case ".pdf":
			return "application/pdf"
		case ".png":
			return "image/png"
		case ".jpg", ".jpeg":
			return "image/jpeg"
		case ".gif":
			return "image/gif"
		case ".svg":
			return "image/svg+xml"
		}
	}
	return "application/octet-stream"
}

func appendValue(m map[string]any, k, v string) {
	if existing, exists := m[k]; exists {
		switch typed := existing.(type) {
		case []string:
			m[k] = append(typed, v)
		case string:
			m[k] = []string{typed, v}
		}
	} else {
		m[k] = v
	}
}

func GenerateSecureToken(length int) (string, error) {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b)[:length], nil
}
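ParseForm flattens both urlencoded and multipart bodies into a single map: repeated field names collapse into a string slice via appendValue, and uploaded files are summarized under a `_files` key with filename, size, and mimetype. The requests below exercise both branches as a sketch only; the `/submit` route, field names, and avatar.png file are made up for illustration and are not part of this diff.

```bash
# Repeated "tag" fields: appendValue turns them into []string{"go", "lua"}.
curl -X POST http://localhost:3117/submit -d 'tag=go' -d 'tag=lua'

# Multipart upload: parseMultipartForm records avatar.png under the _files key.
curl -X POST http://localhost:3117/submit -F 'name=skylear' -F 'avatar=@avatar.png'
```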
main.go (new file, 393 lines)
@@ -0,0 +1,393 @@
package main

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"syscall"
	"time"

	"Moonshark/http"
	"Moonshark/router"
	"Moonshark/runner"
	"Moonshark/sessions"
	"Moonshark/utils/color"
	"Moonshark/utils/config"
	"Moonshark/utils/logger"
	"Moonshark/utils/metadata"
	"Moonshark/watchers"
)

// Moonshark represents the server and all its dependencies
type Moonshark struct {
	Config *config.Config
	LuaRouter *router.LuaRouter
	LuaRunner *runner.Runner
	HTTPServer *http.Server
	cleanupFuncs []func() error
	scriptMode bool
}

func main() {
	configPath := flag.String("config", "config.lua", "Path to configuration file")
	debugFlag := flag.Bool("debug", false, "Enable debug mode")
	scriptPath := flag.String("script", "", "Path to Lua script to execute once")
	flag.Parse()
	scriptMode := *scriptPath != ""

	banner()
	mode := ""
	if scriptMode {
		mode = "[Script Mode]"
	} else {
		mode = "[Server Mode]"
	}
	fmt.Printf("%s %s\n\n", color.Apply(mode, color.Gray), color.Apply("v"+metadata.Version, color.Blue))

	// Initialize logger
	logger.InitGlobalLogger(true, false)

	// Load config
	cfg, err := config.Load(*configPath)
	if err != nil {
		logger.Warning("Config load failed: %v, using defaults", color.Apply(err.Error(), color.Red))
		cfg = config.New()
	}

	// Setup logging with debug mode
	if *debugFlag || cfg.Server.Debug {
		logger.EnableDebug()
		logger.Debug("Debug logging enabled")
	}

	var moonshark *Moonshark

	if scriptMode {
		moonshark, err = initScriptMode(cfg)
	} else {
		moonshark, err = initServerMode(cfg, *debugFlag)
	}

	if err != nil {
		logger.Fatal("Initialization failed: %v", err)
	}

	defer func() {
		if err := moonshark.Shutdown(); err != nil {
			logger.Error("Error during shutdown: %v", err)
			os.Exit(1)
		}
	}()

	if scriptMode {
		// Run the script and exit
		if err := moonshark.RunScript(*scriptPath); err != nil {
			logger.Fatal("Script execution failed: %v", err)
		}
		return
	}

	// Start the server
	if err := moonshark.Start(); err != nil {
		logger.Fatal("Failed to start server: %v", err)
	}

	// Wait for shutdown signal
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
	<-stop

	fmt.Print("\n")
	logger.Info("Shutdown signal received")
}

// initScriptMode initializes minimal components needed for script execution
func initScriptMode(cfg *config.Config) (*Moonshark, error) {
	moonshark := &Moonshark{
		Config: cfg,
		scriptMode: true,
	}

	// Only initialize the Lua runner with required paths
	runnerOpts := []runner.RunnerOption{
		runner.WithPoolSize(1), // Only need one state for script mode
		runner.WithLibDirs(cfg.Dirs.Libs...),
		runner.WithFsDir(cfg.Dirs.FS),
		runner.WithDataDir(cfg.Dirs.Data),
	}

	var err error
	moonshark.LuaRunner, err = runner.NewRunner(runnerOpts...)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize Lua runner: %v", err)
	}

	logger.Debug("Script mode initialized with minimized components")
	return moonshark, nil
}

// initServerMode initializes all components needed for server operation
func initServerMode(cfg *config.Config, debug bool) (*Moonshark, error) {
	moonshark := &Moonshark{
		Config: cfg,
		scriptMode: false,
	}

	if debug {
		cfg.Server.Debug = true
	}

	if err := initLuaRouter(moonshark); err != nil {
		return nil, err
	}

	if err := initRunner(moonshark); err != nil {
		return nil, err
	}

	if err := setupWatchers(moonshark); err != nil {
		logger.Warning("Watcher setup failed: %v", err)
	}

	// Get static directory - empty string if it doesn't exist
	staticDir := ""
	if dirExists(cfg.Dirs.Static) {
		staticDir = cfg.Dirs.Static
		logger.Info("Static files enabled: %s", color.Apply(staticDir, color.Yellow))
	} else {
		logger.Warning("Static directory not found: %s", color.Apply(cfg.Dirs.Static, color.Yellow))
	}

	moonshark.HTTPServer = http.New(
		moonshark.LuaRouter,
		staticDir,
		moonshark.LuaRunner,
		cfg.Server.HTTPLogging,
		cfg.Server.Debug,
		cfg.Dirs.Override,
		cfg,
	)

	// For development, disable caching. For production, enable it
	if cfg.Server.Debug {
		moonshark.HTTPServer.SetStaticCaching(0) // No caching in debug mode
	} else {
		moonshark.HTTPServer.SetStaticCaching(1 * time.Hour) // Cache for 1 hour in production
	}

	return moonshark, nil
}

// RunScript executes a Lua script in the sandbox environment
func (s *Moonshark) RunScript(scriptPath string) error {
	scriptPath, err := filepath.Abs(scriptPath)
	if err != nil {
		return fmt.Errorf("failed to resolve script path: %v", err)
	}

	if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
		return fmt.Errorf("script file not found: %s", scriptPath)
	}

	logger.Info("Executing: %s", scriptPath)

	resp, err := s.LuaRunner.RunScriptFile(scriptPath)
	if err != nil {
		return fmt.Errorf("execution failed: %v", err)
	}

	if resp != nil && resp.Body != nil {
		logger.Info("Script result: %v", resp.Body)
	} else {
		logger.Info("Script executed successfully (no return value)")
	}

	return nil
}

// Start starts the HTTP server
func (s *Moonshark) Start() error {
	if s.scriptMode {
		return errors.New("cannot start server in script mode")
	}

	logger.Info("Surf's up on port %s!", color.Apply(strconv.Itoa(s.Config.Server.Port), color.Cyan))

	go func() {
		if err := s.HTTPServer.ListenAndServe(fmt.Sprintf(":%d", s.Config.Server.Port)); err != nil {
			if err.Error() != "http: Server closed" {
				logger.Error("Server error: %v", err)
			}
		}
	}()

	return nil
}

// Shutdown gracefully shuts down Moonshark
func (s *Moonshark) Shutdown() error {
	logger.Info("Shutting down...")

	if !s.scriptMode && s.HTTPServer != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		if err := s.HTTPServer.Shutdown(ctx); err != nil {
			logger.Error("HTTP server shutdown error: %v", err)
		}
	}

	for _, cleanup := range s.cleanupFuncs {
		if err := cleanup(); err != nil {
			logger.Warning("Cleanup error: %v", err)
		}
	}

	if s.LuaRunner != nil {
		s.LuaRunner.Close()
	}

	if err := runner.CleanupEnv(); err != nil {
		logger.Warning("Environment cleanup failed: %v", err)
	}

	logger.Info("Shutdown complete")
	return nil
}

func dirExists(path string) bool {
	info, err := os.Stat(path)
	if err != nil {
		return false
	}
	return info.IsDir()
}

func initLuaRouter(s *Moonshark) error {
	if !dirExists(s.Config.Dirs.Routes) {
		return fmt.Errorf("routes directory doesn't exist: %s", s.Config.Dirs.Routes)
	}

	var err error
	s.LuaRouter, err = router.NewLuaRouter(s.Config.Dirs.Routes)
	if err != nil {
		if errors.Is(err, router.ErrRoutesCompilationErrors) {
			// Non-fatal, some routes failed
			logger.Warning("Some routes failed to compile")

			if failedRoutes := s.LuaRouter.ReportFailedRoutes(); len(failedRoutes) > 0 {
				for _, re := range failedRoutes {
					logger.Error("Route %s %s: %v", re.Method, re.Path, re.Err)
				}
			}
		} else {
			return fmt.Errorf("lua router init failed: %v", err)
		}
	}
	logger.Info("LuaRouter is g2g! %s", color.Set(s.Config.Dirs.Routes, color.Yellow))

	return nil
}

func initRunner(s *Moonshark) error {
	if !dirExists(s.Config.Dirs.Override) {
		logger.Warning("Override directory not found... %s", color.Apply(s.Config.Dirs.Override, color.Yellow))
		s.Config.Dirs.Override = ""
	}

	for _, dir := range s.Config.Dirs.Libs {
		if !dirExists(dir) {
			logger.Warning("Lib directory not found... %s", color.Apply(dir, color.Yellow))
		}
	}

	// Initialize environment manager
	if err := runner.InitEnv(s.Config.Dirs.Data); err != nil {
		logger.Warning("Environment initialization failed: %v", err)
	}

	sessionManager := sessions.GlobalSessionManager
	sessionManager.SetCookieOptions(
		"MoonsharkSID",
		"/",
		"",
		false,
		true,
		86400,
	)

	poolSize := s.Config.Runner.PoolSize
	if s.scriptMode {
		poolSize = 1 // Only need one state for script mode
	}

	runnerOpts := []runner.RunnerOption{
		runner.WithPoolSize(poolSize),
		runner.WithLibDirs(s.Config.Dirs.Libs...),
		runner.WithFsDir(s.Config.Dirs.FS),
		runner.WithDataDir(s.Config.Dirs.Data),
	}

	var err error
	s.LuaRunner, err = runner.NewRunner(runnerOpts...)
	if err != nil {
		return fmt.Errorf("lua runner init failed: %v", err)
	}

	logger.Info("LuaRunner is g2g with %s states!", color.Apply(strconv.Itoa(poolSize), color.Yellow))
	return nil
}

func setupWatchers(s *Moonshark) error {
	manager := watchers.GetWatcherManager()

	// Watch routes directory
	routeWatcher, err := watchers.WatchLuaRouter(s.LuaRouter, s.LuaRunner, s.Config.Dirs.Routes)
	if err != nil {
		logger.Warning("Routes directory watch failed: %v", err)
	} else {
		routesDir := routeWatcher.GetDir()
		s.cleanupFuncs = append(s.cleanupFuncs, func() error {
			return manager.UnwatchDirectory(routesDir)
		})
	}

	// Watch module directories
	moduleWatchers, err := watchers.WatchLuaModules(s.LuaRunner, s.Config.Dirs.Libs)
	if err != nil {
		logger.Warning("Module directories watch failed: %v", err)
	} else {
		for _, watcher := range moduleWatchers {
			dirPath := watcher.GetDir()
			s.cleanupFuncs = append(s.cleanupFuncs, func() error {
				return manager.UnwatchDirectory(dirPath)
			})
		}
		plural := ""
		if len(moduleWatchers) == 1 {
			plural = "directory"
		} else {
			plural = "directories"
		}
		logger.Info("Watching %s module %s.", color.Apply(strconv.Itoa(len(moduleWatchers)), color.Yellow), plural)
	}

	return nil
}

func banner() {
	banner := `
 _____ _________.__ __
/ \ ____ ____ ____ / _____/| |__ _____ _______| | __
/ \ / \ / _ \ / _ \ / \ \_____ \ | | \\__ \\_ __ \ |/ /
/ Y ( <_> | <_> ) | \/ \| Y \/ __ \| | \/ <
\____|__ /\____/ \____/|___| /_______ /|___| (____ /__| |__|_ \
 \/ \/ \/ \/ \/ \/
`
	fmt.Println(color.Apply(banner, color.Blue))
}
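main.go wires up two modes from its flags: the default server mode, and a one-shot script mode that builds a single-state runner and executes one file through RunScript before exiting. A hedged sketch of both invocations follows; hello.lua is a placeholder script, not a file in this diff.

```bash
# Server mode: load config.lua (or built-in defaults) and serve until SIGINT/SIGTERM.
./moonshark -config config.lua -debug

# Script mode: run one Lua file through the sandbox and exit.
./moonshark -script hello.lua
```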
@@ -1,5 +0,0 @@
package metadata

const (
	Version = "1.0.0"
)
@ -1,548 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"base64_encode": base64_encode,
|
||||
"base64_decode": base64_decode,
|
||||
"base64_url_encode": base64_url_encode,
|
||||
"base64_url_decode": base64_url_decode,
|
||||
"hex_encode": hex_encode,
|
||||
"hex_decode": hex_decode,
|
||||
"md5_hash": md5_hash,
|
||||
"sha1_hash": sha1_hash,
|
||||
"sha256_hash": sha256_hash,
|
||||
"sha512_hash": sha512_hash,
|
||||
"hmac_sha256": hmac_sha256,
|
||||
"hmac_sha1": hmac_sha1,
|
||||
"uuid_generate": uuid_generate,
|
||||
"uuid_generate_v4": uuid_generate_v4,
|
||||
"uuid_validate": uuid_validate,
|
||||
"random_bytes": random_bytes,
|
||||
"random_hex": random_hex,
|
||||
"random_string": random_string,
|
||||
"secure_compare": secure_compare,
|
||||
"argon2_hash": argon2_hash,
|
||||
"argon2_verify": argon2_verify,
|
||||
"bcrypt_hash": bcrypt_hash,
|
||||
"bcrypt_verify": bcrypt_verify,
|
||||
"scrypt_hash": scrypt_hash,
|
||||
"scrypt_verify": scrypt_verify,
|
||||
"pbkdf2_hash": pbkdf2_hash,
|
||||
"pbkdf2_verify": pbkdf2_verify,
|
||||
"password_hash": password_hash,
|
||||
"password_verify": password_verify,
|
||||
}
|
||||
}
|
||||
|
||||
func base64_encode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
}
|
||||
|
||||
func base64_decode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
decoded, err := base64.StdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("invalid base64 data")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
}
|
||||
|
||||
func base64_url_encode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
encoded := base64.URLEncoding.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
}
|
||||
|
||||
func base64_url_decode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
decoded, err := base64.URLEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("invalid base64url data")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
}
|
||||
|
||||
func hex_encode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
encoded := hex.EncodeToString([]byte(str))
|
||||
s.PushString(encoded)
|
||||
return 1
|
||||
}
|
||||
|
||||
func hex_decode(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
decoded, err := hex.DecodeString(str)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("invalid hex data")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(decoded))
|
||||
return 1
|
||||
}
|
||||
|
||||
func md5_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := md5.Sum([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func sha1_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := sha1.Sum([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func sha256_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := sha256.Sum256([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func sha512_hash(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
hash := sha512.Sum512([]byte(str))
|
||||
s.PushString(hex.EncodeToString(hash[:]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func hmac_sha256(s *luajit.State) int {
|
||||
message := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
h := hmac.New(sha256.New, []byte(key))
|
||||
h.Write([]byte(message))
|
||||
s.PushString(hex.EncodeToString(h.Sum(nil)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func hmac_sha1(s *luajit.State) int {
|
||||
message := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
h := hmac.New(sha1.New, []byte(key))
|
||||
h.Write([]byte(message))
|
||||
s.PushString(hex.EncodeToString(h.Sum(nil)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func uuid_generate(s *luajit.State) int {
|
||||
id := uuid.New()
|
||||
s.PushString(id.String())
|
||||
return 1
|
||||
}
|
||||
|
||||
func uuid_generate_v4(s *luajit.State) int {
|
||||
id := uuid.New()
|
||||
s.PushString(id.String())
|
||||
return 1
|
||||
}
|
||||
|
||||
func uuid_validate(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
_, err := uuid.Parse(str)
|
||||
s.PushBoolean(err == nil)
|
||||
return 1
|
||||
}
|
||||
|
||||
func random_bytes(s *luajit.State) int {
|
||||
length := int(s.ToNumber(1))
|
||||
if length < 0 || length > 65536 {
|
||||
s.PushNil()
|
||||
s.PushString("invalid length")
|
||||
return 2
|
||||
}
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("failed to generate random bytes")
|
||||
return 2
|
||||
}
|
||||
s.PushString(string(bytes))
|
||||
return 1
|
||||
}
|
||||
|
||||
func random_hex(s *luajit.State) int {
|
||||
length := int(s.ToNumber(1))
|
||||
if length < 0 || length > 32768 {
|
||||
s.PushNil()
|
||||
s.PushString("invalid length")
|
||||
return 2
|
||||
}
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("failed to generate random bytes")
|
||||
return 2
|
||||
}
|
||||
s.PushString(hex.EncodeToString(bytes))
|
||||
return 1
|
||||
}
|
||||
|
||||
func random_string(s *luajit.State) int {
|
||||
length := int(s.ToNumber(1))
|
||||
if length < 0 || length > 65536 {
|
||||
s.PushNil()
|
||||
s.PushString("invalid length")
|
||||
return 2
|
||||
}
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
if s.GetTop() >= 2 && !s.IsNil(2) {
|
||||
charset = s.ToString(2)
|
||||
}
|
||||
|
||||
if len(charset) == 0 {
|
||||
s.PushNil()
|
||||
s.PushString("empty charset")
|
||||
return 2
|
||||
}
|
||||
|
||||
result := make([]byte, length)
|
||||
charsetLen := big.NewInt(int64(len(charset)))
|
||||
for i := range result {
|
||||
n, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
s.PushNil()
|
||||
s.PushString("failed to generate random number")
|
||||
return 2
|
||||
}
|
||||
result[i] = charset[n.Int64()]
|
||||
}
|
||||
s.PushString(string(result))
|
||||
return 1
|
||||
}
|
||||
|
||||
func secure_compare(s *luajit.State) int {
|
||||
a := s.ToString(1)
|
||||
b := s.ToString(2)
|
||||
s.PushBoolean(hmac.Equal([]byte(a), []byte(b)))
|
||||
return 1
|
||||
}
|
||||
func argon2_hash(s *luajit.State) int {
    password := s.ToString(1)
    time := uint32(1)
    memory := uint32(64 * 1024)
    threads := uint8(4)
    keyLen := uint32(32)

    if s.GetTop() >= 2 && !s.IsNil(2) {
        time = uint32(s.ToNumber(2))
    }
    if s.GetTop() >= 3 && !s.IsNil(3) {
        memory = uint32(s.ToNumber(3))
    }
    if s.GetTop() >= 4 && !s.IsNil(4) {
        threads = uint8(s.ToNumber(4))
    }
    if s.GetTop() >= 5 && !s.IsNil(5) {
        keyLen = uint32(s.ToNumber(5))
    }

    salt := make([]byte, 16)
    if _, err := rand.Read(salt); err != nil {
        s.PushNil()
        s.PushString("failed to generate salt")
        return 2
    }

    hash := argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen)
    encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
    encodedHash := base64.RawStdEncoding.EncodeToString(hash)

    result := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
        memory, time, threads, encodedSalt, encodedHash)

    s.PushString(result)
    return 1
}

func argon2_verify(s *luajit.State) int {
    password := s.ToString(1)
    hash := s.ToString(2)

    parts := strings.Split(hash, "$")
    if len(parts) != 6 || parts[1] != "argon2id" {
        s.PushBoolean(false)
        return 1
    }

    var memory, time uint32
    var threads uint8
    if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads); err != nil {
        s.PushBoolean(false)
        return 1
    }

    salt, err := base64.RawStdEncoding.DecodeString(parts[4])
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    actualHash := argon2.IDKey([]byte(password), salt, time, memory, threads, uint32(len(expectedHash)))
    s.PushBoolean(hmac.Equal(actualHash, expectedHash))
    return 1
}

func bcrypt_hash(s *luajit.State) int {
    password := s.ToString(1)
    cost := 12

    if s.GetTop() >= 2 && !s.IsNil(2) {
        cost = int(s.ToNumber(2))
        if cost < 4 || cost > 31 {
            s.PushNil()
            s.PushString("invalid cost (must be 4-31)")
            return 2
        }
    }

    hash, err := bcrypt.GenerateFromPassword([]byte(password), cost)
    if err != nil {
        s.PushNil()
        s.PushString("bcrypt hash failed")
        return 2
    }

    s.PushString(string(hash))
    return 1
}

func bcrypt_verify(s *luajit.State) int {
    password := s.ToString(1)
    hash := s.ToString(2)

    err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
    s.PushBoolean(err == nil)
    return 1
}

func scrypt_hash(s *luajit.State) int {
    password := s.ToString(1)
    N := 32768   // CPU cost
    r := 8       // block size
    p := 1       // parallelization
    keyLen := 32 // key length

    if s.GetTop() >= 2 && !s.IsNil(2) {
        N = int(s.ToNumber(2))
    }
    if s.GetTop() >= 3 && !s.IsNil(3) {
        r = int(s.ToNumber(3))
    }
    if s.GetTop() >= 4 && !s.IsNil(4) {
        p = int(s.ToNumber(4))
    }
    if s.GetTop() >= 5 && !s.IsNil(5) {
        keyLen = int(s.ToNumber(5))
    }

    salt := make([]byte, 16)
    if _, err := rand.Read(salt); err != nil {
        s.PushNil()
        s.PushString("failed to generate salt")
        return 2
    }

    hash, err := scrypt.Key([]byte(password), salt, N, r, p, keyLen)
    if err != nil {
        s.PushNil()
        s.PushString("scrypt hash failed")
        return 2
    }

    encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
    encodedHash := base64.RawStdEncoding.EncodeToString(hash)

    result := fmt.Sprintf("$scrypt$N=%d,r=%d,p=%d$%s$%s", N, r, p, encodedSalt, encodedHash)
    s.PushString(result)
    return 1
}

func scrypt_verify(s *luajit.State) int {
    password := s.ToString(1)
    hash := s.ToString(2)

    parts := strings.Split(hash, "$")
    if len(parts) != 5 || parts[1] != "scrypt" {
        s.PushBoolean(false)
        return 1
    }

    var N, r, p int
    if _, err := fmt.Sscanf(parts[2], "N=%d,r=%d,p=%d", &N, &r, &p); err != nil {
        s.PushBoolean(false)
        return 1
    }

    salt, err := base64.RawStdEncoding.DecodeString(parts[3])
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    expectedHash, err := base64.RawStdEncoding.DecodeString(parts[4])
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    actualHash, err := scrypt.Key([]byte(password), salt, N, r, p, len(expectedHash))
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    s.PushBoolean(hmac.Equal(actualHash, expectedHash))
    return 1
}
func pbkdf2_hash(s *luajit.State) int {
    password := s.ToString(1)
    iterations := 100000
    keyLen := 32

    if s.GetTop() >= 2 && !s.IsNil(2) {
        iterations = int(s.ToNumber(2))
    }
    if s.GetTop() >= 3 && !s.IsNil(3) {
        keyLen = int(s.ToNumber(3))
    }

    salt := make([]byte, 16)
    if _, err := rand.Read(salt); err != nil {
        s.PushNil()
        s.PushString("failed to generate salt")
        return 2
    }

    hash := pbkdf2.Key([]byte(password), salt, iterations, keyLen, sha256.New)
    encodedSalt := base64.RawStdEncoding.EncodeToString(salt)
    encodedHash := base64.RawStdEncoding.EncodeToString(hash)

    result := fmt.Sprintf("$pbkdf2-sha256$i=%d$%s$%s", iterations, encodedSalt, encodedHash)
    s.PushString(result)
    return 1
}

func pbkdf2_verify(s *luajit.State) int {
    password := s.ToString(1)
    hash := s.ToString(2)

    parts := strings.Split(hash, "$")
    if len(parts) != 5 || parts[1] != "pbkdf2-sha256" {
        s.PushBoolean(false)
        return 1
    }

    var iterations int
    if _, err := fmt.Sscanf(parts[2], "i=%d", &iterations); err != nil {
        s.PushBoolean(false)
        return 1
    }

    salt, err := base64.RawStdEncoding.DecodeString(parts[3])
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    expectedHash, err := base64.RawStdEncoding.DecodeString(parts[4])
    if err != nil {
        s.PushBoolean(false)
        return 1
    }

    actualHash := pbkdf2.Key([]byte(password), salt, iterations, len(expectedHash), sha256.New)
    s.PushBoolean(hmac.Equal(actualHash, expectedHash))
    return 1
}

func password_hash(s *luajit.State) int {
    password := s.ToString(1)
    algorithm := "argon2id" // default

    if s.GetTop() >= 2 && !s.IsNil(2) {
        algorithm = s.ToString(2)
    }

    switch algorithm {
    case "argon2id":
        s.PushString(password)
        return argon2_hash(s)
    case "bcrypt":
        s.PushString(password)
        if s.GetTop() >= 3 {
            s.PushNumber(s.ToNumber(3))
        }
        return bcrypt_hash(s)
    case "scrypt":
        s.PushString(password)
        return scrypt_hash(s)
    case "pbkdf2":
        s.PushString(password)
        return pbkdf2_hash(s)
    default:
        s.PushNil()
        s.PushString("unsupported algorithm: " + algorithm)
        return 2
    }
}

func password_verify(s *luajit.State) int {
    hash := s.ToString(2)

    // Auto-detect algorithm from hash format
    if strings.HasPrefix(hash, "$argon2id$") {
        return argon2_verify(s)
    } else if strings.HasPrefix(hash, "$2a$") || strings.HasPrefix(hash, "$2b$") || strings.HasPrefix(hash, "$2y$") {
        return bcrypt_verify(s)
    } else if strings.HasPrefix(hash, "$scrypt$") {
        return scrypt_verify(s)
    } else if strings.HasPrefix(hash, "$pbkdf2-sha256$") {
        return pbkdf2_verify(s)
    }

    s.PushBoolean(false)
    return 1
}
@ -1,530 +0,0 @@
|
||||
local crypto = {}
|
||||
|
||||
-- ======================================================================
|
||||
-- ENCODING / DECODING
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.base64_encode(data)
|
||||
local result, err = moonshark.base64_encode(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.base64_decode(data)
|
||||
local result, err = moonshark.base64_decode(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.base64_url_encode(data)
|
||||
local result, err = moonshark.base64_url_encode(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.base64_url_decode(data)
|
||||
local result, err = moonshark.base64_url_decode(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.hex_encode(data)
|
||||
local result, err = moonshark.hex_encode(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.hex_decode(data)
|
||||
local result, err = moonshark.hex_decode(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
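-- Usage sketch (application code, not part of this module); the literal
-- output strings in the comments are examples only.
local function _example_encoding()
    local b64 = crypto.base64_encode("hello")   -- "aGVsbG8="
    assert(crypto.base64_decode(b64) == "hello")
    local hexed = crypto.hex_encode("hello")    -- "68656c6c6f"
    assert(crypto.hex_decode(hexed) == "hello")
end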
|
||||
|
||||
-- ======================================================================
|
||||
-- HASHING FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.md5(data)
|
||||
local result, err = moonshark.md5_hash(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.sha1(data)
|
||||
local result, err = moonshark.sha1_hash(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.sha256(data)
|
||||
local result, err = moonshark.sha256_hash(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.sha512(data)
|
||||
local result, err = moonshark.sha512_hash(data)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Hash file contents
|
||||
function crypto.hash_file(path, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
|
||||
if not moonshark.file_exists(path) then
|
||||
error("File not found: " .. path)
|
||||
end
|
||||
|
||||
local content = moonshark.file_read(path)
|
||||
if not content then
|
||||
error("Failed to read file: " .. path)
|
||||
end
|
||||
|
||||
if algorithm == "md5" then
|
||||
return crypto.md5(content)
|
||||
elseif algorithm == "sha1" then
|
||||
return crypto.sha1(content)
|
||||
elseif algorithm == "sha256" then
|
||||
return crypto.sha256(content)
|
||||
elseif algorithm == "sha512" then
|
||||
return crypto.sha512(content)
|
||||
else
|
||||
error("Unsupported hash algorithm: " .. algorithm)
|
||||
end
|
||||
end
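-- Usage sketch: hashing a string and a file with the helpers above.
-- "config.lua" is a hypothetical path; sha256 is the default algorithm.
local function _example_hashing()
    local digest = crypto.sha256("moonshark")            -- 64 hex chars
    local file_digest = crypto.hash_file("config.lua", "sha256")
    return digest, file_digest
end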
|
||||
|
||||
-- ======================================================================
|
||||
-- HMAC FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.hmac_sha1(message, key)
|
||||
local result, err = moonshark.hmac_sha1(message, key)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.hmac_sha256(message, key)
|
||||
local result, err = moonshark.hmac_sha256(message, key)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
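-- Usage sketch: tagging a message with HMAC-SHA256. secure_compare (defined
-- later in this module) gives a constant-time check of the tag.
local function _example_hmac()
    local tag = crypto.hmac_sha256("payload", "secret-key")
    return crypto.secure_compare(tag, crypto.hmac_sha256("payload", "secret-key"))
end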
|
||||
|
||||
-- ======================================================================
|
||||
-- UUID FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.uuid()
|
||||
return moonshark.uuid_generate()
|
||||
end
|
||||
|
||||
function crypto.uuid_v4()
|
||||
return moonshark.uuid_generate_v4()
|
||||
end
|
||||
|
||||
function crypto.is_uuid(str)
|
||||
return moonshark.uuid_validate(str)
|
||||
end
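-- Usage sketch: generating and validating a UUID (the value shown is only
-- an example of the format).
local function _example_uuid()
    local id = crypto.uuid()   -- e.g. "550e8400-e29b-41d4-a716-446655440000"
    assert(crypto.is_uuid(id))
    assert(not crypto.is_uuid("not-a-uuid"))
end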
|
||||
|
||||
-- ======================================================================
|
||||
-- RANDOM GENERATORS
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.random_bytes(length)
|
||||
local result, err = moonshark.random_bytes(length)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.random_hex(length)
|
||||
local result, err = moonshark.random_hex(length)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.random_string(length, charset)
|
||||
local result, err = moonshark.random_string(length, charset)
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Generate random alphanumeric string
|
||||
function crypto.random_alphanumeric(length)
|
||||
return crypto.random_string(length, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
end
|
||||
|
||||
-- Generate random password with mixed characters
|
||||
function crypto.random_password(length, include_symbols)
|
||||
length = length or 12
|
||||
local charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
if include_symbols then
|
||||
charset = charset .. "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
end
|
||||
|
||||
return crypto.random_string(length, charset)
|
||||
end
|
||||
|
||||
-- Generate cryptographically secure token
|
||||
function crypto.token(length)
|
||||
length = length or 32
|
||||
return crypto.random_hex(length)
|
||||
end
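-- Usage sketch: the random generators above. Lengths are example values;
-- random_hex(n) returns 2*n hex characters, token() defaults to 32 bytes.
local function _example_random()
    local key  = crypto.random_hex(16)             -- 32 hex chars
    local pin  = crypto.random_string(6, "0123456789")
    local pass = crypto.random_password(16, true)  -- includes symbols
    local tok  = crypto.token()                    -- 64 hex chars
    return key, pin, pass, tok
end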
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.secure_compare(a, b)
|
||||
return moonshark.secure_compare(a, b)
|
||||
end
|
||||
|
||||
-- Generate checksum for data integrity
|
||||
function crypto.checksum(data, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
return crypto.hash_file and crypto[algorithm] and crypto[algorithm](data) or error("Invalid algorithm")
|
||||
end
|
||||
|
||||
-- Verify data against checksum
|
||||
function crypto.verify_checksum(data, expected, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
local actual = crypto[algorithm](data)
|
||||
return crypto.secure_compare(actual, expected)
|
||||
end
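-- Usage sketch: integrity checksum of a payload and verification (sha256 by
-- default). The payload is an example value.
local function _example_checksum()
    local payload = '{"id":1}'
    local sum = crypto.checksum(payload)
    return crypto.verify_checksum(payload, sum)   -- true
end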
|
||||
|
||||
-- Simple encryption using XOR (not cryptographically secure)
|
||||
function crypto.xor_encrypt(data, key)
|
||||
local result = {}
|
||||
local key_len = #key
|
||||
|
||||
for i = 1, #data do
|
||||
local data_byte = string.byte(data, i)
|
||||
local key_byte = string.byte(key, ((i - 1) % key_len) + 1)
|
||||
table.insert(result, string.char(bit32 and bit32.bxor(data_byte, key_byte) or bit.bxor(data_byte, key_byte)))
|
||||
end
|
||||
|
||||
return table.concat(result)
|
||||
end
|
||||
|
||||
-- XOR decryption (same as encryption)
|
||||
function crypto.xor_decrypt(data, key)
|
||||
return crypto.xor_encrypt(data, key)
|
||||
end
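-- Usage sketch: XOR obfuscation round trip. As noted above, this is not
-- cryptographically secure; suitable only for trivial obfuscation.
local function _example_xor()
    local scrambled = crypto.xor_encrypt("plain text", "k3y")
    return crypto.xor_decrypt(scrambled, "k3y") == "plain text"   -- true
end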
|
||||
|
||||
-- Generate hash chain for proof of work
|
||||
function crypto.hash_chain(data, iterations, algorithm)
|
||||
iterations = iterations or 1000
|
||||
algorithm = algorithm or "sha256"
|
||||
|
||||
local result = data
|
||||
for i = 1, iterations do
|
||||
result = crypto[algorithm](result)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Key derivation using PBKDF2-like approach (simplified)
|
||||
function crypto.derive_key(password, salt, iterations, algorithm)
|
||||
iterations = iterations or 10000
|
||||
algorithm = algorithm or "sha256"
|
||||
salt = salt or crypto.random_hex(16)
|
||||
|
||||
local derived = password .. salt
|
||||
for i = 1, iterations do
|
||||
derived = crypto[algorithm](derived)
|
||||
end
|
||||
|
||||
return derived, salt
|
||||
end
|
||||
|
||||
-- Generate nonce (number used once)
|
||||
function crypto.nonce(length)
|
||||
length = length or 16
|
||||
return crypto.random_hex(length)
|
||||
end
|
||||
|
||||
-- Create message authentication code
|
||||
function crypto.mac(message, key, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
return crypto["hmac_" .. algorithm](message, key)
|
||||
end
|
||||
|
||||
-- Verify message authentication code
|
||||
function crypto.verify_mac(message, key, mac, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
local expected = crypto.mac(message, key, algorithm)
|
||||
return crypto.secure_compare(expected, mac)
|
||||
end
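-- Usage sketch: creating and verifying a message authentication code with
-- the helpers above (defaults to hmac_sha256). Values are examples only.
local function _example_mac()
    local tag = crypto.mac("order:42", "shared-secret")
    return crypto.verify_mac("order:42", "shared-secret", tag)   -- true
end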
|
||||
|
||||
-- ======================================================================
|
||||
-- CONVENIENCE FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- One-shot encoding chain
|
||||
function crypto.encode_chain(data, formats)
|
||||
formats = formats or {"base64"}
|
||||
local result = data
|
||||
|
||||
for _, format in ipairs(formats) do
|
||||
if format == "base64" then
|
||||
result = crypto.base64_encode(result)
|
||||
elseif format == "base64url" then
|
||||
result = crypto.base64_url_encode(result)
|
||||
elseif format == "hex" then
|
||||
result = crypto.hex_encode(result)
|
||||
else
|
||||
error("Unknown encoding format: " .. format)
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- One-shot decoding chain (reverse order)
|
||||
function crypto.decode_chain(data, formats)
|
||||
formats = formats or {"base64"}
|
||||
local result = data
|
||||
|
||||
-- Reverse the formats for decoding
|
||||
for i = #formats, 1, -1 do
|
||||
local format = formats[i]
|
||||
if format == "base64" then
|
||||
result = crypto.base64_decode(result)
|
||||
elseif format == "base64url" then
|
||||
result = crypto.base64_url_decode(result)
|
||||
elseif format == "hex" then
|
||||
result = crypto.hex_decode(result)
|
||||
else
|
||||
error("Unknown decoding format: " .. format)
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
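-- Usage sketch: chained encoding and the matching decode; formats are applied
-- in order on encode and reversed on decode.
local function _example_chain()
    local wrapped = crypto.encode_chain("hello", {"base64", "hex"})
    return crypto.decode_chain(wrapped, {"base64", "hex"}) == "hello"   -- true
end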
|
||||
|
||||
-- Hash multiple inputs
|
||||
function crypto.hash_multiple(inputs, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
local combined = table.concat(inputs, "")
|
||||
return crypto[algorithm](combined)
|
||||
end
|
||||
|
||||
-- Create fingerprint from table data
|
||||
function crypto.fingerprint(data, algorithm)
|
||||
algorithm = algorithm or "sha256"
|
||||
return crypto[algorithm](json.encode(data))
|
||||
end
|
||||
|
||||
-- Simple data integrity check
|
||||
function crypto.integrity_check(data)
|
||||
return {
|
||||
data = data,
|
||||
hash = crypto.sha256(data),
|
||||
timestamp = os.time(),
|
||||
uuid = crypto.uuid()
|
||||
}
|
||||
end
|
||||
|
||||
-- Verify integrity check
|
||||
function crypto.verify_integrity(check)
|
||||
if not check.data or not check.hash then
|
||||
return false
|
||||
end
|
||||
|
||||
local expected = crypto.sha256(check.data)
|
||||
return crypto.secure_compare(expected, check.hash)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- PASSWORD HASHING
|
||||
-- ======================================================================
|
||||
|
||||
-- Generic password hashing (defaults to argon2id)
|
||||
function crypto.hash_password(password, algorithm, options)
|
||||
algorithm = algorithm or "argon2id"
|
||||
options = options or {}
|
||||
|
||||
local result, err
|
||||
|
||||
if algorithm == "argon2id" then
|
||||
local time = options.time or 1
|
||||
local memory = options.memory or 65536 -- 64MB in KB
|
||||
local threads = options.threads or 4
|
||||
local keylen = options.keylen or 32
|
||||
result, err = moonshark.argon2_hash(password, time, memory, threads, keylen)
|
||||
|
||||
elseif algorithm == "bcrypt" then
|
||||
local cost = options.cost or 12
|
||||
result, err = moonshark.bcrypt_hash(password, cost)
|
||||
|
||||
elseif algorithm == "scrypt" then
|
||||
local N = options.N or 32768
|
||||
local r = options.r or 8
|
||||
local p = options.p or 1
|
||||
local keylen = options.keylen or 32
|
||||
result, err = moonshark.scrypt_hash(password, N, r, p, keylen)
|
||||
|
||||
elseif algorithm == "pbkdf2" then
|
||||
local iterations = options.iterations or 100000
|
||||
local keylen = options.keylen or 32
|
||||
result, err = moonshark.pbkdf2_hash(password, iterations, keylen)
|
||||
|
||||
else
|
||||
error("unsupported algorithm: " .. algorithm)
|
||||
end
|
||||
|
||||
if not result then
|
||||
error(err)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Generic password verification (auto-detects algorithm)
|
||||
function crypto.verify_password(password, hash)
|
||||
return moonshark.password_verify(password, hash)
|
||||
end
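-- Usage sketch: hashing a password with the default argon2id parameters and
-- verifying it; verify_password auto-detects the algorithm from the prefix.
local function _example_password(plain)
    local stored = crypto.hash_password(plain)    -- "$argon2id$v=19$..."
    return crypto.verify_password(plain, stored)  -- true
end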
|
||||
|
||||
-- ======================================================================
|
||||
-- ALGORITHM-SPECIFIC FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Argon2id hashing
|
||||
function crypto.argon2_hash(password, options)
|
||||
options = options or {}
|
||||
local time = options.time or 1
|
||||
local memory = options.memory or 65536
|
||||
local threads = options.threads or 4
|
||||
local keylen = options.keylen or 32
|
||||
|
||||
local result, err = moonshark.argon2_hash(password, time, memory, threads, keylen)
|
||||
if not result then error(err) end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.argon2_verify(password, hash)
|
||||
return moonshark.argon2_verify(password, hash)
|
||||
end
|
||||
|
||||
-- bcrypt hashing
|
||||
function crypto.bcrypt_hash(password, cost)
|
||||
cost = cost or 12
|
||||
local result, err = moonshark.bcrypt_hash(password, cost)
|
||||
if not result then error(err) end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.bcrypt_verify(password, hash)
|
||||
return moonshark.bcrypt_verify(password, hash)
|
||||
end
|
||||
|
||||
-- scrypt hashing
|
||||
function crypto.scrypt_hash(password, options)
|
||||
options = options or {}
|
||||
local N = options.N or 32768
|
||||
local r = options.r or 8
|
||||
local p = options.p or 1
|
||||
local keylen = options.keylen or 32
|
||||
|
||||
local result, err = moonshark.scrypt_hash(password, N, r, p, keylen)
|
||||
if not result then error(err) end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.scrypt_verify(password, hash)
|
||||
return moonshark.scrypt_verify(password, hash)
|
||||
end
|
||||
|
||||
-- PBKDF2 hashing
|
||||
function crypto.pbkdf2_hash(password, iterations, keylen)
|
||||
iterations = iterations or 100000
|
||||
keylen = keylen or 32
|
||||
|
||||
local result, err = moonshark.pbkdf2_hash(password, iterations, keylen)
|
||||
if not result then error(err) end
|
||||
return result
|
||||
end
|
||||
|
||||
function crypto.pbkdf2_verify(password, hash)
|
||||
return moonshark.pbkdf2_verify(password, hash)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- PASSWORD CONFIG PRESETS
|
||||
-- ======================================================================
|
||||
|
||||
function crypto.hash_password_fast(password, algorithm)
|
||||
algorithm = algorithm or "argon2id"
|
||||
|
||||
local options = {
|
||||
argon2id = { time = 1, memory = 8192, threads = 1 },
|
||||
bcrypt = { cost = 10 },
|
||||
scrypt = { N = 16384, r = 8, p = 1 },
|
||||
pbkdf2 = { iterations = 50000 }
|
||||
}
|
||||
|
||||
return crypto.hash_password(password, algorithm, options[algorithm])
|
||||
end
|
||||
|
||||
function crypto.hash_password_strong(password, algorithm)
|
||||
algorithm = algorithm or "argon2id"
|
||||
|
||||
local options = {
|
||||
argon2id = { time = 3, memory = 131072, threads = 4 },
|
||||
bcrypt = { cost = 14 },
|
||||
scrypt = { N = 65536, r = 8, p = 2 },
|
||||
pbkdf2 = { iterations = 200000 }
|
||||
}
|
||||
|
||||
return crypto.hash_password(password, algorithm, options[algorithm])
|
||||
end
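-- Usage sketch: the presets above trade speed against hardness; the
-- parameter values are the ones in the option tables above.
local function _example_presets(plain)
    local quick  = crypto.hash_password_fast(plain, "bcrypt")      -- cost 10
    local strong = crypto.hash_password_strong(plain, "argon2id")  -- t=3, m=131072 KB
    return quick, strong
end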
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Detect algorithm from hash
|
||||
function crypto.detect_algorithm(hash)
|
||||
if hash:match("^%$argon2id%$") then
|
||||
return "argon2id"
|
||||
elseif hash:match("^%$2[aby]%$") then
|
||||
return "bcrypt"
|
||||
elseif hash:match("^%$scrypt%$") then
|
||||
return "scrypt"
|
||||
elseif hash:match("^%$pbkdf2%-sha256%$") then
|
||||
return "pbkdf2"
|
||||
else
|
||||
return "unknown"
|
||||
end
|
||||
end
|
||||
|
||||
return crypto
|
||||
@ -1,37 +0,0 @@
package fs

import (
    "os"

    luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)

func GetFunctionList() map[string]luajit.GoFunction {
    return map[string]luajit.GoFunction{
        "getcwd": getcwd,
        "chdir":  chdir,
    }
}

func getcwd(s *luajit.State) int {
    cwd, err := os.Getwd()
    if err != nil {
        s.PushNil()
        s.PushString(err.Error())
        return 2
    }
    s.PushString(cwd)
    return 1
}

func chdir(s *luajit.State) int {
    path := s.ToString(1)
    err := os.Chdir(path)
    if err != nil {
        s.PushBoolean(false)
        s.PushString(err.Error())
        return 2
    }
    s.PushBoolean(true)
    return 1
}
@ -1,723 +0,0 @@
|
||||
local fs = {}
|
||||
|
||||
local is_windows = package.config:sub(1,1) == '\\'
|
||||
local path_sep = is_windows and '\\' or '/'
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
local function shell_escape(str)
|
||||
if is_windows then
|
||||
-- Windows: escape quotes and wrap in quotes
|
||||
return '"' .. str:gsub('"', '""') .. '"'
|
||||
else
|
||||
-- Unix: escape shell metacharacters
|
||||
return "'" .. str:gsub("'", "'\"'\"'") .. "'"
|
||||
end
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- FILE OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
function fs.exists(path)
|
||||
local file = io.open(path, "r")
|
||||
if file then
|
||||
file:close()
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function fs.size(path)
|
||||
local file = io.open(path, "r")
|
||||
if not file then return nil end
|
||||
|
||||
local size = file:seek("end")
|
||||
file:close()
|
||||
return size
|
||||
end
|
||||
|
||||
function fs.is_dir(path)
|
||||
if not fs.exists(path) then return false end
|
||||
|
||||
local cmd
|
||||
if is_windows then
|
||||
cmd = 'dir ' .. shell_escape(path) .. ' >nul 2>&1 && echo dir'
|
||||
else
|
||||
cmd = 'test -d ' .. shell_escape(path) .. ' && echo dir'
|
||||
end
|
||||
|
||||
local handle = io.popen(cmd)
|
||||
if handle then
|
||||
local result = handle:read("*l")
|
||||
handle:close()
|
||||
return result == "dir"
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
function fs.is_file(path)
|
||||
return fs.exists(path) and not fs.is_dir(path)
|
||||
end
|
||||
|
||||
function fs.read(path)
|
||||
local file = io.open(path, "r")
|
||||
if not file then
|
||||
error("Failed to read file '" .. path .. "': file not found or permission denied")
|
||||
end
|
||||
|
||||
local content = file:read("*all")
|
||||
file:close()
|
||||
|
||||
if not content then
|
||||
error("Failed to read file '" .. path .. "': read error")
|
||||
end
|
||||
|
||||
return content
|
||||
end
|
||||
|
||||
function fs.write(path, content)
|
||||
if path == "" or path:find("\0") then
|
||||
error("Failed to write file '" .. path .. "': invalid path")
|
||||
end
|
||||
|
||||
local file = io.open(path, "w")
|
||||
if not file then
|
||||
error("Failed to write file '" .. path .. "': permission denied or invalid path")
|
||||
end
|
||||
|
||||
local success = file:write(content)
|
||||
file:close()
|
||||
|
||||
if not success then
|
||||
error("Failed to write file '" .. path .. "': write error")
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.append(path, content)
|
||||
local file = io.open(path, "a")
|
||||
if not file then
|
||||
error("Failed to append to file '" .. path .. "': permission denied or invalid path")
|
||||
end
|
||||
|
||||
local success = file:write(content)
|
||||
file:close()
|
||||
|
||||
if not success then
|
||||
error("Failed to append to file '" .. path .. "': write error")
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.copy(src, dst)
|
||||
local src_file = io.open(src, "rb")
|
||||
if not src_file then
|
||||
error("Failed to copy '" .. src .. "' to '" .. dst .. "': source file not found")
|
||||
end
|
||||
|
||||
local dst_file = io.open(dst, "wb")
|
||||
if not dst_file then
|
||||
src_file:close()
|
||||
error("Failed to copy '" .. src .. "' to '" .. dst .. "': cannot create destination")
|
||||
end
|
||||
|
||||
local chunk_size = 8192
|
||||
while true do
|
||||
local chunk = src_file:read(chunk_size)
|
||||
if not chunk then break end
|
||||
|
||||
if not dst_file:write(chunk) then
|
||||
src_file:close()
|
||||
dst_file:close()
|
||||
error("Failed to copy '" .. src .. "' to '" .. dst .. "': write error")
|
||||
end
|
||||
end
|
||||
|
||||
src_file:close()
|
||||
dst_file:close()
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.move(src, dst)
|
||||
local success = os.rename(src, dst)
|
||||
if success then return true end
|
||||
|
||||
-- Fallback to copy + delete
|
||||
fs.copy(src, dst)
|
||||
fs.remove(src)
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.remove(path)
|
||||
local success = os.remove(path)
|
||||
if not success then
|
||||
error("Failed to remove '" .. path .. "': file not found or permission denied")
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.mtime(path)
|
||||
local cmd
|
||||
if is_windows then
|
||||
cmd = 'forfiles /p . /m ' .. shell_escape(fs.basename(path)) .. ' /c "cmd /c echo @fdate @ftime" 2>nul'
|
||||
local handle = io.popen(cmd)
|
||||
if handle then
|
||||
local result = handle:read("*line")
|
||||
handle:close()
|
||||
if result then
|
||||
-- Basic Windows timestamp parsing - fallback to file existence check
|
||||
return fs.exists(path) and os.time() or nil
|
||||
end
|
||||
end
|
||||
else
|
||||
cmd = 'stat -c %Y ' .. shell_escape(path) .. ' 2>/dev/null'
|
||||
local handle = io.popen(cmd)
|
||||
if handle then
|
||||
local result = handle:read("*line")
|
||||
handle:close()
|
||||
if result and result:match("^%d+$") then
|
||||
return tonumber(result)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
function fs.touch(path)
|
||||
if fs.exists(path) then
|
||||
-- Update timestamp
|
||||
local content = fs.read(path)
|
||||
fs.write(path, content)
|
||||
else
|
||||
-- Create empty file
|
||||
fs.write(path, "")
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.lines(path)
|
||||
local lines = {}
|
||||
local file = io.open(path, "r")
|
||||
if not file then
|
||||
error("Failed to read lines from '" .. path .. "': file not found")
|
||||
end
|
||||
|
||||
for line in file:lines() do
|
||||
lines[#lines + 1] = line
|
||||
end
|
||||
file:close()
|
||||
|
||||
return lines
|
||||
end
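-- Usage sketch: basic file round trip with the helpers above. "notes.txt" is
-- a hypothetical path; these functions raise on failure, so wrap in pcall
-- when errors should be handled instead.
local function _example_file_ops()
    fs.write("notes.txt", "first line\n")
    fs.append("notes.txt", "second line\n")
    local all   = fs.read("notes.txt")
    local lines = fs.lines("notes.txt")   -- { "first line", "second line" }
    return all, lines, fs.size("notes.txt")
end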
|
||||
|
||||
-- ======================================================================
|
||||
-- DIRECTORY OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
function fs.mkdir(path)
|
||||
local cmd
|
||||
if is_windows then
|
||||
cmd = 'mkdir ' .. shell_escape(path) .. ' >nul 2>&1'
|
||||
else
|
||||
cmd = 'mkdir -p ' .. shell_escape(path)
|
||||
end
|
||||
|
||||
local result = os.execute(cmd)
|
||||
-- Handle different Lua version return values
|
||||
local success = (result == 0) or (result == true)
|
||||
if not success then
|
||||
error("Failed to create directory '" .. path .. "'")
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.rmdir(path)
|
||||
local cmd
|
||||
if is_windows then
|
||||
cmd = 'rmdir /s /q ' .. shell_escape(path) .. ' >nul 2>&1'
|
||||
else
|
||||
cmd = 'rm -rf ' .. shell_escape(path)
|
||||
end
|
||||
|
||||
local result = os.execute(cmd)
|
||||
local success = (result == 0) or (result == true)
|
||||
if not success then
|
||||
error("Failed to remove directory '" .. path .. "'")
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.list(path)
|
||||
if not fs.exists(path) then
|
||||
error("Failed to list directory '" .. path .. "': directory not found")
|
||||
end
|
||||
|
||||
if not fs.is_dir(path) then
|
||||
error("Failed to list directory '" .. path .. "': not a directory")
|
||||
end
|
||||
|
||||
local entries = {}
|
||||
local cmd
|
||||
|
||||
if is_windows then
|
||||
cmd = 'dir /b ' .. shell_escape(path) .. ' 2>nul'
|
||||
else
|
||||
cmd = 'ls -1 ' .. shell_escape(path) .. ' 2>/dev/null'
|
||||
end
|
||||
|
||||
local handle = io.popen(cmd)
|
||||
if not handle then
|
||||
error("Failed to list directory '" .. path .. "': command failed")
|
||||
end
|
||||
|
||||
for name in handle:lines() do
|
||||
local full_path = fs.join(path, name)
|
||||
entries[#entries + 1] = {
|
||||
name = name,
|
||||
is_dir = fs.is_dir(full_path),
|
||||
size = fs.is_file(full_path) and fs.size(full_path) or 0,
|
||||
mtime = fs.mtime(full_path)
|
||||
}
|
||||
end
|
||||
handle:close()
|
||||
|
||||
return entries
|
||||
end
|
||||
|
||||
function fs.list_files(path)
|
||||
local entries = fs.list(path)
|
||||
local files = {}
|
||||
for _, entry in ipairs(entries) do
|
||||
if not entry.is_dir then
|
||||
files[#files + 1] = entry
|
||||
end
|
||||
end
|
||||
return files
|
||||
end
|
||||
|
||||
function fs.list_dirs(path)
|
||||
local entries = fs.list(path)
|
||||
local dirs = {}
|
||||
for _, entry in ipairs(entries) do
|
||||
if entry.is_dir then
|
||||
dirs[#dirs + 1] = entry
|
||||
end
|
||||
end
|
||||
return dirs
|
||||
end
|
||||
|
||||
function fs.list_names(path)
|
||||
local entries = fs.list(path)
|
||||
local names = {}
|
||||
for _, entry in ipairs(entries) do
|
||||
names[#names + 1] = entry.name
|
||||
end
|
||||
return names
|
||||
end
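-- Usage sketch: listing a directory. Each entry from fs.list carries name,
-- is_dir, size and mtime, as built above. "public" is a hypothetical path.
local function _example_listing()
    for _, entry in ipairs(fs.list("public")) do
        print(entry.name, entry.is_dir and "dir" or entry.size .. " bytes")
    end
    return fs.list_names("public")
end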
|
||||
|
||||
-- ======================================================================
|
||||
-- PATH OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
function fs.join(...)
|
||||
local parts = {...}
|
||||
if #parts == 0 then return "" end
|
||||
if #parts == 1 then return parts[1] or "" end
|
||||
|
||||
local result = parts[1] or ""
|
||||
for i = 2, #parts do
|
||||
local part = parts[i]
|
||||
if part and part ~= "" then
|
||||
if result == "" then
|
||||
result = part
|
||||
elseif result:sub(-1) == path_sep then
|
||||
result = result .. part
|
||||
else
|
||||
result = result .. path_sep .. part
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
function fs.dirname(path)
|
||||
if path == "" then return "." end
|
||||
|
||||
-- Remove trailing separators
|
||||
path = path:gsub("[/\\]+$", "")
|
||||
|
||||
local pos = path:find("[/\\][^/\\]*$")
|
||||
if pos then
|
||||
local dir = path:sub(1, pos - 1)
|
||||
return dir ~= "" and dir or path_sep
|
||||
end
|
||||
return "."
|
||||
end
|
||||
|
||||
function fs.basename(path)
|
||||
if path == "" then return "" end
|
||||
|
||||
-- Remove trailing separators
|
||||
path = path:gsub("[/\\]+$", "")
|
||||
|
||||
return path:match("[^/\\]*$") or ""
|
||||
end
|
||||
|
||||
function fs.ext(path)
|
||||
local base = fs.basename(path)
|
||||
local dot_pos = base:find("%.[^%.]*$")
|
||||
return dot_pos and base:sub(dot_pos) or ""
|
||||
end
|
||||
|
||||
function fs.abs(path)
|
||||
if path == "" then path = "." end
|
||||
|
||||
if is_windows then
|
||||
if path:match("^[A-Za-z]:") or path:match("^\\\\") then
|
||||
return fs.clean(path)
|
||||
end
|
||||
else
|
||||
if path:sub(1, 1) == "/" then
|
||||
return fs.clean(path)
|
||||
end
|
||||
end
|
||||
|
||||
local cwd = fs.getcwd()
|
||||
return fs.clean(fs.join(cwd, path))
|
||||
end
|
||||
|
||||
function fs.clean(path)
|
||||
if path == "" then return "." end
|
||||
|
||||
-- Normalize path separators
|
||||
path = path:gsub("[/\\]+", path_sep)
|
||||
|
||||
-- Track if path was absolute
|
||||
local is_absolute = false
|
||||
if is_windows then
|
||||
is_absolute = path:match("^[A-Za-z]:") or path:match("^\\\\")
|
||||
else
|
||||
is_absolute = path:sub(1, 1) == "/"
|
||||
end
|
||||
|
||||
-- Split into components
|
||||
local parts = {}
|
||||
for part in path:gmatch("[^" .. path_sep:gsub("\\", "\\\\") .. "]+") do
|
||||
if part == ".." then
|
||||
if #parts > 0 and parts[#parts] ~= ".." then
|
||||
if not is_absolute or #parts > 1 then
|
||||
parts[#parts] = nil
|
||||
end
|
||||
elseif not is_absolute then
|
||||
parts[#parts + 1] = ".."
|
||||
end
|
||||
elseif part ~= "." and part ~= "" then
|
||||
parts[#parts + 1] = part
|
||||
end
|
||||
end
|
||||
|
||||
local result = table.concat(parts, path_sep)
|
||||
|
||||
if is_absolute then
|
||||
if is_windows and path:match("^[A-Za-z]:") then
|
||||
-- Windows drive letter
|
||||
local drive = path:match("^[A-Za-z]:")
|
||||
result = drive .. path_sep .. result
|
||||
elseif is_windows and path:match("^\\\\") then
|
||||
-- UNC path
|
||||
result = "\\\\" .. result
|
||||
else
|
||||
-- Unix absolute
|
||||
result = path_sep .. result
|
||||
end
|
||||
end
|
||||
|
||||
return result ~= "" and result or (is_absolute and path_sep or ".")
|
||||
end
|
||||
|
||||
function fs.split(path)
|
||||
local pos = path:find("[/\\][^/\\]*$")
|
||||
if pos then
|
||||
return path:sub(1, pos), path:sub(pos + 1)
|
||||
end
|
||||
return "", path
|
||||
end
|
||||
|
||||
function fs.splitext(path)
|
||||
local dir = fs.dirname(path)
|
||||
local base = fs.basename(path)
|
||||
local ext = fs.ext(path)
|
||||
local name = base
|
||||
|
||||
if ext ~= "" then
|
||||
name = base:sub(1, -(#ext + 1))
|
||||
end
|
||||
|
||||
return dir, name, ext
|
||||
end
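-- Usage sketch: path helpers. Results in the comments assume a Unix
-- separator; on Windows the platform separator is used instead.
local function _example_paths()
    local p = fs.join("routes", "api", "users.lua")   -- "routes/api/users.lua"
    return fs.dirname(p),    -- "routes/api"
           fs.basename(p),   -- "users.lua"
           fs.ext(p)         -- ".lua"
end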
|
||||
|
||||
-- ======================================================================
|
||||
-- WORKING DIRECTORY
|
||||
-- ======================================================================
|
||||
|
||||
function fs.getcwd()
|
||||
local cwd, err = moonshark.getcwd()
|
||||
if not cwd then
|
||||
error("Failed to get current directory: " .. (err or "unknown error"))
|
||||
end
|
||||
return cwd
|
||||
end
|
||||
|
||||
function fs.chdir(path)
|
||||
local success, err = moonshark.chdir(path)
|
||||
if not success then
|
||||
error("Failed to change directory to '" .. path .. "': " .. (err or "unknown error"))
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- TEMPORARY FILES
|
||||
-- ======================================================================
|
||||
|
||||
function fs.tempfile(prefix)
|
||||
prefix = prefix or "tmp"
|
||||
local temp_name = prefix .. "_" .. os.time() .. "_" .. math.random(10000)
|
||||
|
||||
local temp_path
|
||||
if is_windows then
|
||||
local temp_dir = os.getenv("TEMP") or os.getenv("TMP") or "C:\\temp"
|
||||
temp_path = fs.join(temp_dir, temp_name)
|
||||
else
|
||||
temp_path = fs.join("/tmp", temp_name)
|
||||
end
|
||||
|
||||
fs.write(temp_path, "")
|
||||
return temp_path
|
||||
end
|
||||
|
||||
function fs.tempdir(prefix)
|
||||
prefix = prefix or "tmp"
|
||||
local temp_name = prefix .. "_" .. os.time() .. "_" .. math.random(10000)
|
||||
|
||||
local temp_path
|
||||
if is_windows then
|
||||
local temp_dir = os.getenv("TEMP") or os.getenv("TMP") or "C:\\temp"
|
||||
temp_path = fs.join(temp_dir, temp_name)
|
||||
else
|
||||
temp_path = fs.join("/tmp", temp_name)
|
||||
end
|
||||
|
||||
fs.mkdir(temp_path)
|
||||
return temp_path
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- PATTERN MATCHING
|
||||
-- ======================================================================
|
||||
|
||||
function fs.glob(pattern)
|
||||
-- Simple validation to prevent obvious shell injection
|
||||
if pattern:find("[;&|`$(){}]") then
|
||||
return {}
|
||||
end
|
||||
|
||||
local cmd
|
||||
if is_windows then
|
||||
cmd = 'for %f in (' .. pattern .. ') do @echo %f 2>nul'
|
||||
else
|
||||
cmd = 'ls -1d ' .. pattern .. ' 2>/dev/null'
|
||||
end
|
||||
|
||||
local matches = {}
|
||||
local handle = io.popen(cmd)
|
||||
if handle then
|
||||
for match in handle:lines() do
|
||||
matches[#matches + 1] = match
|
||||
end
|
||||
handle:close()
|
||||
end
|
||||
|
||||
return matches
|
||||
end
|
||||
|
||||
function fs.walk(root)
|
||||
local files = {}
|
||||
|
||||
local function walk_recursive(path)
|
||||
if not fs.exists(path) then return end
|
||||
|
||||
files[#files + 1] = path
|
||||
|
||||
if fs.is_dir(path) then
|
||||
local success, entries = pcall(fs.list, path)
|
||||
if success then
|
||||
for _, entry in ipairs(entries) do
|
||||
walk_recursive(fs.join(path, entry.name))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
walk_recursive(root)
|
||||
return files
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function fs.extension(path)
|
||||
local ext = fs.ext(path)
|
||||
return ext:sub(2) -- Remove leading dot
|
||||
end
|
||||
|
||||
function fs.change_ext(path, new_ext)
|
||||
local dir, name, _ = fs.splitext(path)
|
||||
if not new_ext:match("^%.") then
|
||||
new_ext = "." .. new_ext
|
||||
end
|
||||
if dir == "." then
|
||||
return name .. new_ext
|
||||
end
|
||||
return fs.join(dir, name .. new_ext)
|
||||
end
|
||||
|
||||
function fs.ensure_dir(path)
|
||||
if not fs.exists(path) then
|
||||
fs.mkdir(path)
|
||||
elseif not fs.is_dir(path) then
|
||||
error("Path exists but is not a directory: " .. path)
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.size_human(path)
|
||||
local size = fs.size(path)
|
||||
if not size then return nil end
|
||||
|
||||
local units = {"B", "KB", "MB", "GB", "TB"}
|
||||
local unit_index = 1
|
||||
local size_float = size
|
||||
|
||||
while size_float >= 1024 and unit_index < #units do
|
||||
size_float = size_float / 1024
|
||||
unit_index = unit_index + 1
|
||||
end
|
||||
|
||||
if unit_index == 1 then
|
||||
return string.format("%d %s", size_float, units[unit_index])
|
||||
else
|
||||
return string.format("%.1f %s", size_float, units[unit_index])
|
||||
end
|
||||
end
|
||||
|
||||
function fs.is_safe_path(path)
|
||||
path = fs.clean(path)
|
||||
|
||||
if path:match("%.%.") then return false end
|
||||
if path:match("^[/\\]") then return false end
|
||||
if path:match("^~") then return false end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
-- Aliases for convenience
|
||||
fs.makedirs = fs.mkdir
|
||||
fs.removedirs = fs.rmdir
|
||||
|
||||
function fs.copytree(src, dst)
|
||||
if not fs.exists(src) then
|
||||
error("Source directory does not exist: " .. src)
|
||||
end
|
||||
|
||||
if not fs.is_dir(src) then
|
||||
error("Source is not a directory: " .. src)
|
||||
end
|
||||
|
||||
fs.mkdir(dst)
|
||||
|
||||
local entries = fs.list(src)
|
||||
for _, entry in ipairs(entries) do
|
||||
local src_path = fs.join(src, entry.name)
|
||||
local dst_path = fs.join(dst, entry.name)
|
||||
|
||||
if entry.is_dir then
|
||||
fs.copytree(src_path, dst_path)
|
||||
else
|
||||
fs.copy(src_path, dst_path)
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
function fs.find(root, pattern, recursive)
|
||||
recursive = recursive ~= false
|
||||
local results = {}
|
||||
|
||||
local function search(dir)
|
||||
local success, entries = pcall(fs.list, dir)
|
||||
if not success then return end
|
||||
|
||||
for _, entry in ipairs(entries) do
|
||||
local full_path = fs.join(dir, entry.name)
|
||||
|
||||
if not entry.is_dir and entry.name:match(pattern) then
|
||||
results[#results + 1] = full_path
|
||||
elseif entry.is_dir and recursive then
|
||||
search(full_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
search(root)
|
||||
return results
|
||||
end
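-- Usage sketch: finding files by Lua pattern under a root. "routes" is a
-- hypothetical directory; the pattern is matched against file names, not
-- full paths, and the search is recursive by default.
local function _example_find()
    return fs.find("routes", "%.lua$")
end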
|
||||
|
||||
function fs.tree(root, max_depth)
|
||||
max_depth = max_depth or 10
|
||||
|
||||
local function build_tree(path, depth)
|
||||
if depth > max_depth then return nil end
|
||||
|
||||
if not fs.exists(path) then return nil end
|
||||
|
||||
local node = {
|
||||
name = fs.basename(path),
|
||||
path = path,
|
||||
is_dir = fs.is_dir(path)
|
||||
}
|
||||
|
||||
if node.is_dir then
|
||||
node.children = {}
|
||||
local success, entries = pcall(fs.list, path)
|
||||
if success then
|
||||
for _, entry in ipairs(entries) do
|
||||
local child_path = fs.join(path, entry.name)
|
||||
local child = build_tree(child_path, depth + 1)
|
||||
if child then
|
||||
node.children[#node.children + 1] = child
|
||||
end
|
||||
end
|
||||
end
|
||||
else
|
||||
node.size = fs.size(path)
|
||||
node.mtime = fs.mtime(path)
|
||||
end
|
||||
|
||||
return node
|
||||
end
|
||||
|
||||
return build_tree(root, 1)
|
||||
end
|
||||
|
||||
return fs
|
||||
@ -1,358 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"Moonshark/metadata"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var (
|
||||
globalServer *fasthttp.Server
|
||||
globalWorkerPool *WorkerPool
|
||||
globalStateCreator StateCreator
|
||||
globalMu sync.RWMutex
|
||||
serverRunning bool
|
||||
staticHandlers = make(map[string]*fasthttp.FS)
|
||||
staticMu sync.RWMutex
|
||||
)
|
||||
|
||||
func SetStateCreator(creator StateCreator) {
|
||||
globalStateCreator = creator
|
||||
}
|
||||
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"http_create_server": http_create_server,
|
||||
"http_spawn_workers": http_spawn_workers,
|
||||
"http_listen": http_listen,
|
||||
"http_close_server": http_close_server,
|
||||
"http_has_servers": http_has_servers,
|
||||
"http_register_static": http_register_static,
|
||||
}
|
||||
}
|
||||
|
||||
func http_create_server(s *luajit.State) int {
|
||||
globalMu.Lock()
|
||||
defer globalMu.Unlock()
|
||||
|
||||
if globalServer != nil {
|
||||
s.PushBoolean(true) // Already created
|
||||
return 1
|
||||
}
|
||||
|
||||
globalServer = &fasthttp.Server{
|
||||
Name: "Moonshark/" + metadata.Version,
|
||||
Handler: handleRequest,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func http_spawn_workers(s *luajit.State) int {
|
||||
globalMu.Lock()
|
||||
defer globalMu.Unlock()
|
||||
|
||||
if globalWorkerPool != nil {
|
||||
s.PushBoolean(true) // Already spawned
|
||||
return 1
|
||||
}
|
||||
|
||||
if globalStateCreator == nil {
|
||||
s.PushBoolean(false)
|
||||
s.PushString("state creator not set")
|
||||
return 2
|
||||
}
|
||||
|
||||
workerCount := max(runtime.NumCPU(), 2)
|
||||
|
||||
pool, err := NewWorkerPool(workerCount, s, globalStateCreator)
|
||||
if err != nil {
|
||||
s.PushBoolean(false)
|
||||
s.PushString(fmt.Sprintf("failed to create worker pool: %v", err))
|
||||
return 2
|
||||
}
|
||||
globalWorkerPool = pool
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func http_listen(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
return s.PushError("http_listen: %v", err)
|
||||
}
|
||||
|
||||
addr := s.ToString(1)
|
||||
|
||||
globalMu.RLock()
|
||||
server := globalServer
|
||||
globalMu.RUnlock()
|
||||
|
||||
if server == nil {
|
||||
s.PushBoolean(false)
|
||||
s.PushString("no server created")
|
||||
return 2
|
||||
}
|
||||
|
||||
globalMu.Lock()
|
||||
if serverRunning {
|
||||
globalMu.Unlock()
|
||||
s.PushBoolean(true) // Already running
|
||||
return 1
|
||||
}
|
||||
serverRunning = true
|
||||
globalMu.Unlock()
|
||||
|
||||
go func() {
|
||||
if err := server.ListenAndServe(addr); err != nil {
|
||||
fmt.Printf("HTTP server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
globalMu.Lock()
|
||||
serverRunning = false
|
||||
globalMu.Unlock()
|
||||
s.PushBoolean(false)
|
||||
s.PushString(fmt.Sprintf("failed to start server: %v", err))
|
||||
return 2
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func http_close_server(s *luajit.State) int {
|
||||
StopAllServers()
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func http_has_servers(s *luajit.State) int {
|
||||
globalMu.RLock()
|
||||
running := serverRunning
|
||||
globalMu.RUnlock()
|
||||
|
||||
s.PushBoolean(running)
|
||||
return 1
|
||||
}
|
||||
|
||||
func http_register_static(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(2); err != nil {
|
||||
return s.PushError("http_register_static: %v", err)
|
||||
}
|
||||
|
||||
urlPrefix := s.ToString(1)
|
||||
rootPath := s.ToString(2)
|
||||
noCache := s.ToBoolean(3)
|
||||
|
||||
// Ensure prefix starts with /
|
||||
if !strings.HasPrefix(urlPrefix, "/") {
|
||||
urlPrefix = "/" + urlPrefix
|
||||
}
|
||||
|
||||
// Convert to absolute path
|
||||
absPath, err := filepath.Abs(rootPath)
|
||||
if err != nil {
|
||||
s.PushBoolean(false)
|
||||
s.PushString(fmt.Sprintf("invalid path: %v", err))
|
||||
return 2
|
||||
}
|
||||
|
||||
RegisterStaticHandler(urlPrefix, absPath, noCache)
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func HasActiveServers() bool {
|
||||
globalMu.RLock()
|
||||
defer globalMu.RUnlock()
|
||||
return serverRunning
|
||||
}
|
||||
|
||||
func WaitForServers() {
|
||||
for HasActiveServers() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func handleRequest(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Fast path for likely static files (has extension)
|
||||
if isLikelyStaticFile(path) && tryStaticHandler(ctx, path) {
|
||||
return
|
||||
}
|
||||
|
||||
// Try Lua routing
|
||||
globalMu.RLock()
|
||||
pool := globalWorkerPool
|
||||
globalMu.RUnlock()
|
||||
|
||||
if pool != nil {
|
||||
worker := pool.Get()
|
||||
if worker != nil {
|
||||
defer pool.Put(worker)
|
||||
|
||||
req := GetRequest()
|
||||
defer PutRequest(req)
|
||||
|
||||
resp := GetResponse()
|
||||
defer PutResponse(resp)
|
||||
|
||||
// Populate request
|
||||
req.Method = string(ctx.Method())
|
||||
req.Path = path
|
||||
req.Body = string(ctx.Request.Body())
|
||||
|
||||
for key, value := range ctx.QueryArgs().All() {
|
||||
req.Query[string(key)] = string(value)
|
||||
}
|
||||
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
req.Headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
err := worker.HandleRequest(req, resp)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString(fmt.Sprintf("Internal Server Error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// If Lua found a route, use it
|
||||
if resp.StatusCode != 404 {
|
||||
ctx.SetStatusCode(resp.StatusCode)
|
||||
for key, value := range resp.Headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
if resp.Body != "" {
|
||||
ctx.SetBodyString(resp.Body)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isLikelyStaticFile(path) && tryStaticHandler(ctx, path) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("Not Found")
|
||||
}
|
||||
|
||||
func tryStaticHandler(ctx *fasthttp.RequestCtx, path string) bool {
|
||||
staticMu.RLock()
|
||||
defer staticMu.RUnlock()
|
||||
|
||||
for prefix, fs := range staticHandlers {
|
||||
if after, ok := strings.CutPrefix(path, prefix); ok {
|
||||
ctx.Request.URI().SetPath(after)
|
||||
fs.NewRequestHandler()(ctx)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isLikelyStaticFile(path string) bool {
|
||||
// Check for file extension
|
||||
lastSlash := strings.LastIndex(path, "/")
|
||||
lastDot := strings.LastIndex(path, ".")
|
||||
return lastDot > lastSlash && lastDot != -1
|
||||
}
|
||||
|
||||
// RegisterStaticHandler adds a static file handler
|
||||
func RegisterStaticHandler(urlPrefix, rootPath string, noCache bool) {
|
||||
staticMu.Lock()
|
||||
defer staticMu.Unlock()
|
||||
|
||||
var cacheDuration time.Duration
|
||||
var compress bool
|
||||
if noCache {
|
||||
cacheDuration = 0
|
||||
compress = false
|
||||
} else {
|
||||
cacheDuration = 3600 * time.Second
|
||||
compress = true
|
||||
}
|
||||
|
||||
fs := &fasthttp.FS{
|
||||
Root: rootPath,
|
||||
CompressRoot: rootPath + "/.cache",
|
||||
IndexNames: []string{"index.html"},
|
||||
GenerateIndexPages: false,
|
||||
Compress: compress,
|
||||
CompressBrotli: compress,
|
||||
CompressZstd: compress,
|
||||
CacheDuration: cacheDuration,
|
||||
AcceptByteRange: true,
|
||||
PathNotFound: func(ctx *fasthttp.RequestCtx) {
|
||||
path := ctx.Path()
|
||||
fmt.Printf("404 not found: %s\n", path)
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("404 not found")
|
||||
},
|
||||
}
|
||||
|
||||
staticHandlers[urlPrefix] = fs
|
||||
}
|
||||
|
||||
func StopAllServers() {
|
||||
globalMu.Lock()
|
||||
defer globalMu.Unlock()
|
||||
|
||||
// Start shutting down both in parallel
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Close worker pool in parallel
|
||||
if globalWorkerPool != nil {
|
||||
wg.Add(1)
|
||||
pool := globalWorkerPool
|
||||
globalWorkerPool = nil
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
// Shutdown server with 100ms timeout
|
||||
if globalServer != nil {
|
||||
wg.Add(1)
|
||||
server := globalServer
|
||||
globalServer = nil
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := server.ShutdownWithContext(ctx); err != nil {
|
||||
// Force close if graceful shutdown times out
|
||||
server.CloseOnShutdown = true
|
||||
_ = server.Shutdown()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
serverRunning = false
|
||||
|
||||
// Wait for both to complete
|
||||
wg.Wait()
|
||||
}
|
||||
File diff suppressed because it is too large
@ -1,346 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
type StateCreator func() (*luajit.State, error)
|
||||
|
||||
type Request struct {
|
||||
Method string
|
||||
Path string
|
||||
Query map[string]string
|
||||
Headers map[string]string
|
||||
Body string
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Body string
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
state *luajit.State
|
||||
id int
|
||||
}
|
||||
|
||||
type WorkerPool struct {
|
||||
workers chan *Worker
|
||||
masterState *luajit.State
|
||||
stateCreator StateCreator
|
||||
workerCount int
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
requestPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Request{
|
||||
Query: make(map[string]string),
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
responsePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Response{
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func GetRequest() *Request {
|
||||
req := requestPool.Get().(*Request)
|
||||
for k := range req.Query {
|
||||
delete(req.Query, k)
|
||||
}
|
||||
for k := range req.Headers {
|
||||
delete(req.Headers, k)
|
||||
}
|
||||
req.Method = ""
|
||||
req.Path = ""
|
||||
req.Body = ""
|
||||
return req
|
||||
}
|
||||
|
||||
func PutRequest(req *Request) {
|
||||
requestPool.Put(req)
|
||||
}
|
||||
|
||||
func GetResponse() *Response {
|
||||
resp := responsePool.Get().(*Response)
|
||||
for k := range resp.Headers {
|
||||
delete(resp.Headers, k)
|
||||
}
|
||||
resp.StatusCode = 200
|
||||
resp.Body = ""
|
||||
return resp
|
||||
}
|
||||
|
||||
func PutResponse(resp *Response) {
|
||||
responsePool.Put(resp)
|
||||
}
|
||||
|
||||
func NewWorkerPool(size int, masterState *luajit.State, stateCreator StateCreator) (*WorkerPool, error) {
|
||||
pool := &WorkerPool{
|
||||
workers: make(chan *Worker, size),
|
||||
masterState: masterState,
|
||||
stateCreator: stateCreator,
|
||||
workerCount: size,
|
||||
}
|
||||
|
||||
for i := range size {
|
||||
worker, err := pool.createWorker(i)
|
||||
if err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to create worker %d: %w", i, err)
|
||||
}
|
||||
pool.workers <- worker
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func (p *WorkerPool) createWorker(id int) (*Worker, error) {
|
||||
workerState, err := p.stateCreator()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create worker state: %w", err)
|
||||
}
|
||||
|
||||
return &Worker{
|
||||
state: workerState,
|
||||
id: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *WorkerPool) Get() *Worker {
|
||||
p.mu.RLock()
|
||||
if p.closed {
|
||||
p.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
select {
|
||||
case worker := <-p.workers:
|
||||
return worker
|
||||
default:
|
||||
worker, err := p.createWorker(-1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return worker
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WorkerPool) Put(worker *Worker) {
|
||||
if worker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.closed {
|
||||
worker.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if worker.id == -1 {
|
||||
worker.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case p.workers <- worker:
|
||||
default:
|
||||
worker.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WorkerPool) SyncRoutes(routesData any) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// Sync routes to all workers
|
||||
workers := make([]*Worker, 0, p.workerCount)
|
||||
|
||||
// Collect all workers
|
||||
for {
|
||||
select {
|
||||
case worker := <-p.workers:
|
||||
workers = append(workers, worker)
|
||||
default:
|
||||
goto syncWorkers
|
||||
}
|
||||
}
|
||||
|
||||
syncWorkers:
|
||||
// Sync and return workers
|
||||
for _, worker := range workers {
|
||||
if worker.state != nil {
|
||||
worker.state.PushValue(routesData)
|
||||
worker.state.SetGlobal("_http_routes_data")
|
||||
|
||||
worker.state.GetGlobal("_http_sync_worker_routes")
|
||||
if worker.state.IsFunction(-1) {
|
||||
worker.state.Call(0, 0)
|
||||
} else {
|
||||
worker.state.Pop(1)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case p.workers <- worker:
|
||||
default:
|
||||
worker.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WorkerPool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
// Collect all workers first
|
||||
workers := make([]*Worker, 0, len(p.workers))
|
||||
close(p.workers)
|
||||
for worker := range p.workers {
|
||||
workers = append(workers, worker)
|
||||
}
|
||||
|
||||
// Close all workers in parallel
|
||||
var wg sync.WaitGroup
|
||||
for _, worker := range workers {
|
||||
wg.Add(1)
|
||||
go func(w *Worker) {
|
||||
defer wg.Done()
|
||||
w.Close()
|
||||
}(worker)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (w *Worker) Close() {
|
||||
if w.state != nil {
|
||||
w.state.Close()
|
||||
w.state = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) HandleRequest(req *Request, resp *Response) error {
|
||||
if w.state == nil {
|
||||
return fmt.Errorf("worker state is nil")
|
||||
}
|
||||
|
||||
// Create request table
|
||||
w.state.NewTable()
|
||||
w.state.PushString("method")
|
||||
w.state.PushString(req.Method)
|
||||
w.state.SetTable(-3)
|
||||
|
||||
w.state.PushString("path")
|
||||
w.state.PushString(req.Path)
|
||||
w.state.SetTable(-3)
|
||||
|
||||
w.state.PushString("body")
|
||||
w.state.PushString(req.Body)
|
||||
w.state.SetTable(-3)
|
||||
|
||||
// Query params
|
||||
w.state.PushString("query")
|
||||
w.state.NewTable()
|
||||
for k, v := range req.Query {
|
||||
w.state.PushString(k)
|
||||
w.state.PushString(v)
|
||||
w.state.SetTable(-3)
|
||||
}
|
||||
w.state.SetTable(-3)
|
||||
|
||||
// Headers
|
||||
w.state.PushString("headers")
|
||||
w.state.NewTable()
|
||||
for k, v := range req.Headers {
|
||||
w.state.PushString(k)
|
||||
w.state.PushString(v)
|
||||
w.state.SetTable(-3)
|
||||
}
|
||||
w.state.SetTable(-3)
|
||||
|
||||
// Create response table
|
||||
w.state.NewTable()
|
||||
w.state.PushString("status")
|
||||
w.state.PushNumber(200)
|
||||
w.state.SetTable(-3)
|
||||
|
||||
w.state.PushString("body")
|
||||
w.state.PushString("")
|
||||
w.state.SetTable(-3)
|
||||
|
||||
w.state.PushString("headers")
|
||||
w.state.NewTable()
|
||||
w.state.SetTable(-3)
|
||||
|
||||
// Call _http_handle_request(req, res) - pure Lua routing
|
||||
w.state.GetGlobal("_http_handle_request")
|
||||
if !w.state.IsFunction(-1) {
|
||||
w.state.Pop(3)
|
||||
resp.StatusCode = 500
|
||||
resp.Body = "HTTP handler not initialized"
|
||||
return nil
|
||||
}
|
||||
|
||||
w.state.PushCopy(-3) // request
|
||||
w.state.PushCopy(-3) // response
|
||||
|
||||
if err := w.state.Call(2, 0); err != nil {
|
||||
w.state.Pop(2)
|
||||
resp.StatusCode = 500
|
||||
resp.Body = fmt.Sprintf("Handler error: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract response
|
||||
w.state.GetField(-1, "status")
|
||||
if w.state.IsNumber(-1) {
|
||||
resp.StatusCode = int(w.state.ToNumber(-1))
|
||||
}
|
||||
w.state.Pop(1)
|
||||
|
||||
w.state.GetField(-1, "body")
|
||||
if w.state.IsString(-1) {
|
||||
resp.Body = w.state.ToString(-1)
|
||||
}
|
||||
w.state.Pop(1)
|
||||
|
||||
w.state.GetField(-1, "headers")
|
||||
if w.state.IsTable(-1) {
|
||||
w.state.PushNil()
|
||||
for w.state.Next(-2) {
|
||||
if w.state.IsString(-2) && w.state.IsString(-1) {
|
||||
resp.Headers[w.state.ToString(-2)] = w.state.ToString(-1)
|
||||
}
|
||||
w.state.Pop(1)
|
||||
}
|
||||
}
|
||||
w.state.Pop(1)
|
||||
|
||||
w.state.Pop(2) // Clean up request and response tables
|
||||
return nil
|
||||
}
|
||||
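Taken together, the worker above builds a request table (`method`, `path`, `body`, `query`, `headers`) and a response table (`status`, `body`, `headers`), then calls the Lua globals `_http_sync_worker_routes` and `_http_handle_request`. A minimal sketch of that Lua side, assuming a trivial method/path lookup (the real router is not shown in this diff):

```lua
local routes = {}

-- Called by WorkerPool.SyncRoutes after Go sets the _http_routes_data global.
function _http_sync_worker_routes()
    routes = _http_routes_data or {}
end

-- Called by Worker.HandleRequest with the request and response tables; no return values.
function _http_handle_request(req, res)
    local handler = routes[req.method] and routes[req.method][req.path]
    if not handler then
        res.status = 404
        res.body = "not found"
        return
    end
    res.status = 200
    res.headers["Content-Type"] = "text/plain"
    res.body = handler(req)
end
```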
@ -1,598 +0,0 @@
|
||||
json = {}
|
||||
|
||||
function json.encode(value)
|
||||
local buffer = {}
|
||||
local pos = 1
|
||||
|
||||
local function encode_string(s)
|
||||
buffer[pos] = '"'
|
||||
pos = pos + 1
|
||||
|
||||
local start = 1
|
||||
for i = 1, #s do
|
||||
local c = s:byte(i)
|
||||
if c == 34 then -- "
|
||||
if i > start then
|
||||
buffer[pos] = s:sub(start, i - 1)
|
||||
pos = pos + 1
|
||||
end
|
||||
buffer[pos] = '\\"'
|
||||
pos = pos + 1
|
||||
start = i + 1
|
||||
elseif c == 92 then -- \
|
||||
if i > start then
|
||||
buffer[pos] = s:sub(start, i - 1)
|
||||
pos = pos + 1
|
||||
end
|
||||
buffer[pos] = '\\\\'
|
||||
pos = pos + 1
|
||||
start = i + 1
|
||||
elseif c < 32 then
|
||||
if i > start then
|
||||
buffer[pos] = s:sub(start, i - 1)
|
||||
pos = pos + 1
|
||||
end
|
||||
if c == 8 then
|
||||
buffer[pos] = '\\b'
|
||||
elseif c == 9 then
|
||||
buffer[pos] = '\\t'
|
||||
elseif c == 10 then
|
||||
buffer[pos] = '\\n'
|
||||
elseif c == 12 then
|
||||
buffer[pos] = '\\f'
|
||||
elseif c == 13 then
|
||||
buffer[pos] = '\\r'
|
||||
else
|
||||
buffer[pos] = ('\\u%04x'):format(c)
|
||||
end
|
||||
pos = pos + 1
|
||||
start = i + 1
|
||||
end
|
||||
end
|
||||
|
||||
if start <= #s then
|
||||
buffer[pos] = s:sub(start)
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
buffer[pos] = '"'
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
local function encode_value(v, depth)
|
||||
local t = type(v)
|
||||
|
||||
if t == 'string' then
|
||||
encode_string(v)
|
||||
elseif t == 'number' then
|
||||
if v ~= v then -- NaN
|
||||
buffer[pos] = 'null'
|
||||
elseif v == 1/0 or v == -1/0 then -- Infinity
|
||||
buffer[pos] = 'null'
|
||||
else
|
||||
buffer[pos] = tostring(v)
|
||||
end
|
||||
pos = pos + 1
|
||||
elseif t == 'boolean' then
|
||||
buffer[pos] = v and 'true' or 'false'
|
||||
pos = pos + 1
|
||||
elseif t == 'table' then
|
||||
if depth > 100 then error('circular reference') end
|
||||
|
||||
local is_array = true
|
||||
local max_index = 0
|
||||
local count = 0
|
||||
|
||||
for k, _ in pairs(v) do
|
||||
count = count + 1
|
||||
if type(k) ~= 'number' or k <= 0 or k % 1 ~= 0 then
|
||||
is_array = false
|
||||
break
|
||||
end
|
||||
if k > max_index then max_index = k end
|
||||
end
|
||||
|
||||
if is_array and count == max_index then
|
||||
buffer[pos] = '['
|
||||
pos = pos + 1
|
||||
|
||||
for i = 1, max_index do
|
||||
if i > 1 then
|
||||
buffer[pos] = ','
|
||||
pos = pos + 1
|
||||
end
|
||||
encode_value(v[i], depth + 1)
|
||||
end
|
||||
|
||||
buffer[pos] = ']'
|
||||
pos = pos + 1
|
||||
else
|
||||
buffer[pos] = '{'
|
||||
pos = pos + 1
|
||||
|
||||
local first = true
|
||||
for k, val in pairs(v) do
|
||||
if not first then
|
||||
buffer[pos] = ','
|
||||
pos = pos + 1
|
||||
end
|
||||
first = false
|
||||
|
||||
encode_string(tostring(k))
|
||||
buffer[pos] = ':'
|
||||
pos = pos + 1
|
||||
encode_value(val, depth + 1)
|
||||
end
|
||||
|
||||
buffer[pos] = '}'
|
||||
pos = pos + 1
|
||||
end
|
||||
else
|
||||
buffer[pos] = 'null'
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
encode_value(value, 0)
|
||||
return table.concat(buffer)
|
||||
end
|
||||
|
||||
function json.decode(str)
|
||||
local pos = 1
|
||||
local len = #str
|
||||
|
||||
local function skip_whitespace()
|
||||
while pos <= len do
|
||||
local c = str:byte(pos)
|
||||
if c ~= 32 and c ~= 9 and c ~= 10 and c ~= 13 then break end
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
local function decode_string()
|
||||
local start = pos + 1
|
||||
pos = pos + 1
|
||||
|
||||
while pos <= len do
|
||||
local c = str:byte(pos)
|
||||
if c == 34 then -- "
|
||||
local result = str:sub(start, pos - 1)
|
||||
pos = pos + 1
|
||||
|
||||
if result:find('\\') then
|
||||
result = result:gsub('\\(.)', {
|
||||
['"'] = '"',
|
||||
['\\'] = '\\',
|
||||
['/'] = '/',
|
||||
['b'] = '\b',
|
||||
['f'] = '\f',
|
||||
['n'] = '\n',
|
||||
['r'] = '\r',
|
||||
['t'] = '\t'
|
||||
})
|
||||
result = result:gsub('\\u(%x%x%x%x)', function(hex)
|
||||
return string.char(tonumber(hex, 16))
|
||||
end)
|
||||
end
|
||||
|
||||
return result
|
||||
elseif c == 92 then -- \
|
||||
pos = pos + 2
|
||||
else
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
error('unterminated string')
|
||||
end
|
||||
|
||||
local function decode_number()
|
||||
local start = pos
|
||||
local c = str:byte(pos)
|
||||
|
||||
if c == 45 then pos = pos + 1 end -- -
|
||||
|
||||
c = str:byte(pos)
|
||||
if not c or c < 48 or c > 57 then error('invalid number') end
|
||||
|
||||
if c == 48 then
|
||||
pos = pos + 1
|
||||
else
|
||||
while pos <= len do
|
||||
c = str:byte(pos)
|
||||
if c < 48 or c > 57 then break end
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
if pos <= len and str:byte(pos) == 46 then -- .
|
||||
pos = pos + 1
|
||||
local found_digit = false
|
||||
while pos <= len do
|
||||
c = str:byte(pos)
|
||||
if c < 48 or c > 57 then break end
|
||||
found_digit = true
|
||||
pos = pos + 1
|
||||
end
|
||||
if not found_digit then error('invalid number') end
|
||||
end
|
||||
|
||||
if pos <= len then
|
||||
c = str:byte(pos)
|
||||
if c == 101 or c == 69 then -- e or E
|
||||
pos = pos + 1
|
||||
if pos <= len then
|
||||
c = str:byte(pos)
|
||||
if c == 43 or c == 45 then pos = pos + 1 end -- + or -
|
||||
end
|
||||
local found_digit = false
|
||||
while pos <= len do
|
||||
c = str:byte(pos)
|
||||
if c < 48 or c > 57 then break end
|
||||
found_digit = true
|
||||
pos = pos + 1
|
||||
end
|
||||
if not found_digit then error('invalid number') end
|
||||
end
|
||||
end
|
||||
|
||||
return tonumber(str:sub(start, pos - 1))
|
||||
end
|
||||
|
||||
local function decode_value()
|
||||
skip_whitespace()
|
||||
if pos > len then error('unexpected end') end
|
||||
|
||||
local c = str:byte(pos)
|
||||
|
||||
if c == 34 then -- "
|
||||
return decode_string()
|
||||
elseif c == 123 then -- {
|
||||
local result = {}
|
||||
pos = pos + 1
|
||||
skip_whitespace()
|
||||
|
||||
if pos <= len and str:byte(pos) == 125 then -- }
|
||||
pos = pos + 1
|
||||
return result
|
||||
end
|
||||
|
||||
while true do
|
||||
skip_whitespace()
|
||||
if pos > len or str:byte(pos) ~= 34 then error('expected string key') end
|
||||
|
||||
local key = decode_string()
|
||||
skip_whitespace()
|
||||
|
||||
if pos > len or str:byte(pos) ~= 58 then error('expected :') end
|
||||
pos = pos + 1
|
||||
|
||||
result[key] = decode_value()
|
||||
skip_whitespace()
|
||||
|
||||
if pos > len then error('unexpected end') end
|
||||
c = str:byte(pos)
|
||||
|
||||
if c == 125 then -- }
|
||||
pos = pos + 1
|
||||
return result
|
||||
elseif c == 44 then -- ,
|
||||
pos = pos + 1
|
||||
else
|
||||
error('expected , or }')
|
||||
end
|
||||
end
|
||||
|
||||
elseif c == 91 then -- [
|
||||
local result = {}
|
||||
local index = 1
|
||||
pos = pos + 1
|
||||
skip_whitespace()
|
||||
|
||||
if pos <= len and str:byte(pos) == 93 then -- ]
|
||||
pos = pos + 1
|
||||
return result
|
||||
end
|
||||
|
||||
while true do
|
||||
result[index] = decode_value()
|
||||
index = index + 1
|
||||
skip_whitespace()
|
||||
|
||||
if pos > len then error('unexpected end') end
|
||||
c = str:byte(pos)
|
||||
|
||||
if c == 93 then -- ]
|
||||
pos = pos + 1
|
||||
return result
|
||||
elseif c == 44 then -- ,
|
||||
pos = pos + 1
|
||||
else
|
||||
error('expected , or ]')
|
||||
end
|
||||
end
|
||||
|
||||
elseif c == 116 then -- true
|
||||
if str:sub(pos, pos + 3) == 'true' then
|
||||
pos = pos + 4
|
||||
return true
|
||||
end
|
||||
error('invalid literal')
|
||||
|
||||
elseif c == 102 then -- false
|
||||
if str:sub(pos, pos + 4) == 'false' then
|
||||
pos = pos + 5
|
||||
return false
|
||||
end
|
||||
error('invalid literal')
|
||||
|
||||
elseif c == 110 then -- null
|
||||
if str:sub(pos, pos + 3) == 'null' then
|
||||
pos = pos + 4
|
||||
return nil
|
||||
end
|
||||
error('invalid literal')
|
||||
|
||||
elseif (c >= 48 and c <= 57) or c == 45 then -- 0-9 or -
|
||||
return decode_number()
|
||||
|
||||
else
|
||||
error('unexpected character')
|
||||
end
|
||||
end
|
||||
|
||||
local result = decode_value()
|
||||
skip_whitespace()
|
||||
if pos <= len then error('unexpected content after JSON') end
|
||||
return result
|
||||
end
|
||||
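For reference, a quick round trip through the encoder and decoder above (the table contents here are arbitrary):

```lua
local original = { name = "moonshark", tags = { "lua", "http" }, retries = 3 }
local text = json.encode(original)   -- e.g. {"retries":3,"name":"moonshark","tags":["lua","http"]}
local copy = json.decode(text)
assert(copy.name == "moonshark")
assert(copy.tags[2] == "http")
assert(copy.retries == 3)
```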
|
||||
function json.load_file(filename)
|
||||
local file = io.open(filename, "r")
|
||||
if not file then
|
||||
error("Cannot open file: " .. filename)
|
||||
end
|
||||
|
||||
local content = file:read("*all")
|
||||
file:close()
|
||||
|
||||
return json.decode(content)
|
||||
end
|
||||
|
||||
function json.save_file(filename, data)
|
||||
local file = io.open(filename, "w")
|
||||
if not file then
|
||||
error("Cannot write to file: " .. filename)
|
||||
end
|
||||
|
||||
file:write(json.encode(data))
|
||||
file:close()
|
||||
end
|
||||
|
||||
function json.merge(...)
|
||||
local result = {}
|
||||
local n = select("#", ...)
|
||||
for i = 1, n do
|
||||
local obj = select(i, ...)
|
||||
if type(obj) == "table" then
|
||||
for k, v in pairs(obj) do
|
||||
result[k] = v
|
||||
end
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
function json.extract(data, path)
|
||||
local current = data
|
||||
local start = 1
|
||||
local len = #path
|
||||
|
||||
while start <= len do
|
||||
local dot_pos = path:find(".", start, true)
|
||||
local part = dot_pos and path:sub(start, dot_pos - 1) or path:sub(start)
|
||||
|
||||
if type(current) ~= "table" then
|
||||
return nil
|
||||
end
|
||||
|
||||
local bracket_start, bracket_end = part:find("^%[(%d+)%]$")
|
||||
if bracket_start then
|
||||
local index = tonumber(part:sub(2, -2)) + 1
|
||||
current = current[index]
|
||||
else
|
||||
current = current[part]
|
||||
end
|
||||
|
||||
if current == nil then
|
||||
return nil
|
||||
end
|
||||
|
||||
start = dot_pos and dot_pos + 1 or len + 1
|
||||
end
|
||||
|
||||
return current
|
||||
end
|
||||
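Note that `json.extract` expects dot-separated paths where array segments are written as their own `[n]` part with zero-based indices (the code adds 1 before indexing). For example:

```lua
local doc = json.decode('{"users":[{"name":"sky"},{"name":"kai"}]}')
print(json.extract(doc, "users.[0].name"))  --> sky
print(json.extract(doc, "users.[1].name"))  --> kai
print(json.extract(doc, "users.[2].name"))  --> nil (missing segments return nil)
```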
|
||||
function json.pretty(value, indent)
|
||||
local buffer = {}
|
||||
local pos = 1
|
||||
indent = indent or " "
|
||||
|
||||
local function encode_string(s)
|
||||
buffer[pos] = '"'
|
||||
pos = pos + 1
|
||||
|
||||
local start = 1
|
||||
for i = 1, #s do
|
||||
local c = s:byte(i)
|
||||
if c == 34 then -- "
|
||||
if i > start then
|
||||
buffer[pos] = s:sub(start, i - 1)
|
||||
pos = pos + 1
|
||||
end
|
||||
buffer[pos] = '\\"'
|
||||
pos = pos + 1
|
||||
start = i + 1
|
||||
elseif c == 92 then -- \
|
||||
if i > start then
|
||||
buffer[pos] = s:sub(start, i - 1)
|
||||
pos = pos + 1
|
||||
end
|
||||
buffer[pos] = '\\\\'
|
||||
pos = pos + 1
|
||||
start = i + 1
|
||||
elseif c < 32 then
|
||||
if i > start then
|
||||
buffer[pos] = s:sub(start, i - 1)
|
||||
pos = pos + 1
|
||||
end
|
||||
if c == 8 then
|
||||
buffer[pos] = '\\b'
|
||||
elseif c == 9 then
|
||||
buffer[pos] = '\\t'
|
||||
elseif c == 10 then
|
||||
buffer[pos] = '\\n'
|
||||
elseif c == 12 then
|
||||
buffer[pos] = '\\f'
|
||||
elseif c == 13 then
|
||||
buffer[pos] = '\\r'
|
||||
else
|
||||
buffer[pos] = ('\\u%04x'):format(c)
|
||||
end
|
||||
pos = pos + 1
|
||||
start = i + 1
|
||||
end
|
||||
end
|
||||
|
||||
if start <= #s then
|
||||
buffer[pos] = s:sub(start)
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
buffer[pos] = '"'
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
local function encode_value(v, depth)
|
||||
local t = type(v)
|
||||
local current_indent = string.rep(indent, depth)
|
||||
local next_indent = string.rep(indent, depth + 1)
|
||||
|
||||
if t == 'string' then
|
||||
encode_string(v)
|
||||
elseif t == 'number' then
|
||||
if v ~= v then -- NaN
|
||||
buffer[pos] = 'null'
|
||||
elseif v == 1/0 or v == -1/0 then -- Infinity
|
||||
buffer[pos] = 'null'
|
||||
else
|
||||
buffer[pos] = tostring(v)
|
||||
end
|
||||
pos = pos + 1
|
||||
elseif t == 'boolean' then
|
||||
buffer[pos] = v and 'true' or 'false'
|
||||
pos = pos + 1
|
||||
elseif t == 'table' then
|
||||
if depth > 100 then error('circular reference') end
|
||||
|
||||
local is_array = true
|
||||
local max_index = 0
|
||||
local count = 0
|
||||
|
||||
for k, _ in pairs(v) do
|
||||
count = count + 1
|
||||
if type(k) ~= 'number' or k <= 0 or k % 1 ~= 0 then
|
||||
is_array = false
|
||||
break
|
||||
end
|
||||
if k > max_index then max_index = k end
|
||||
end
|
||||
|
||||
if is_array and count == max_index then
|
||||
buffer[pos] = '[\n'
|
||||
pos = pos + 1
|
||||
|
||||
for i = 1, max_index do
|
||||
buffer[pos] = next_indent
|
||||
pos = pos + 1
|
||||
encode_value(v[i], depth + 1)
|
||||
if i < max_index then
|
||||
buffer[pos] = ','
|
||||
pos = pos + 1
|
||||
end
|
||||
buffer[pos] = '\n'
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
buffer[pos] = current_indent .. ']'
|
||||
pos = pos + 1
|
||||
else
|
||||
buffer[pos] = '{\n'
|
||||
pos = pos + 1
|
||||
|
||||
local keys = {}
|
||||
for k in pairs(v) do
|
||||
keys[#keys + 1] = k
|
||||
end
|
||||
|
||||
for i, k in ipairs(keys) do
|
||||
buffer[pos] = next_indent
|
||||
pos = pos + 1
|
||||
encode_string(tostring(k))
|
||||
buffer[pos] = ': '
|
||||
pos = pos + 1
|
||||
encode_value(v[k], depth + 1)
|
||||
if i < #keys then
|
||||
buffer[pos] = ','
|
||||
pos = pos + 1
|
||||
end
|
||||
buffer[pos] = '\n'
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
buffer[pos] = current_indent .. '}'
|
||||
pos = pos + 1
|
||||
end
|
||||
else
|
||||
buffer[pos] = 'null'
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
encode_value(value, 0)
|
||||
return table.concat(buffer)
|
||||
end
|
||||
|
||||
function json.validate(data, schema)
|
||||
local function validate_value(value, schema_value)
|
||||
local value_type = type(value)
|
||||
local schema_type = schema_value.type
|
||||
|
||||
if schema_type and value_type ~= schema_type then
|
||||
return false, "Expected " .. schema_type .. ", got " .. value_type
|
||||
end
|
||||
|
||||
if schema_type == "table" and schema_value.properties then
|
||||
local required = schema_value.required
|
||||
for prop, prop_schema in pairs(schema_value.properties) do
|
||||
local prop_value = value[prop]
|
||||
|
||||
if required and required[prop] and prop_value == nil then
|
||||
return false, "Missing required property: " .. prop
|
||||
end
|
||||
|
||||
if prop_value ~= nil then
|
||||
local valid, err = validate_value(prop_value, prop_schema)
|
||||
if not valid then
|
||||
return false, "Property " .. prop .. ": " .. err
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
return validate_value(data, schema)
|
||||
end
|
||||
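The validator takes a schema table whose `required` field is keyed by property name (not an array). A small usage sketch:

```lua
local schema = {
    type = "table",
    required = { name = true },
    properties = {
        name = { type = "string" },
        port = { type = "number" },
    },
}

local ok, err = json.validate({ name = "moonshark", port = 3117 }, schema)
assert(ok)

ok, err = json.validate({ port = "3117" }, schema)
print(ok, err)  -- false, plus a message for the first failing property
```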
505
modules/kv/kv.go
505
modules/kv/kv.go
@ -1,505 +0,0 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
var (
|
||||
stores = make(map[string]*Store)
|
||||
mutex sync.RWMutex
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
data map[string]string
|
||||
expires map[string]int64
|
||||
filename string
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"kv_open": kv_open,
|
||||
"kv_get": kv_get,
|
||||
"kv_set": kv_set,
|
||||
"kv_delete": kv_delete,
|
||||
"kv_clear": kv_clear,
|
||||
"kv_has": kv_has,
|
||||
"kv_size": kv_size,
|
||||
"kv_keys": kv_keys,
|
||||
"kv_values": kv_values,
|
||||
"kv_save": kv_save,
|
||||
"kv_close": kv_close,
|
||||
"kv_increment": kv_increment,
|
||||
"kv_append": kv_append,
|
||||
"kv_expire": kv_expire,
|
||||
"kv_cleanup_expired": kv_cleanup_expired,
|
||||
}
|
||||
}
|
||||
|
||||
// kv_open(name, filename) -> boolean
|
||||
func kv_open(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
filename := s.ToString(2)
|
||||
|
||||
if name == "" {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
if store, exists := stores[name]; exists {
|
||||
if filename != "" && store.filename != filename {
|
||||
store.filename = filename
|
||||
}
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
store := &Store{
|
||||
data: make(map[string]string),
|
||||
expires: make(map[string]int64),
|
||||
filename: filename,
|
||||
}
|
||||
|
||||
if filename != "" {
|
||||
store.load()
|
||||
}
|
||||
|
||||
stores[name] = store
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_get(name, key, default) -> value or default
|
||||
func kv_get(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
hasDefault := s.GetTop() >= 3
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if hasDefault {
|
||||
s.PushCopy(3)
|
||||
} else {
|
||||
s.PushNil()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.RLock()
|
||||
value, found := store.data[key]
|
||||
store.mutex.RUnlock()
|
||||
|
||||
if found {
|
||||
s.PushString(value)
|
||||
} else if hasDefault {
|
||||
s.PushCopy(3)
|
||||
} else {
|
||||
s.PushNil()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_set(name, key, value) -> boolean
|
||||
func kv_set(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
value := s.ToString(3)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.Lock()
|
||||
store.data[key] = value
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_delete(name, key) -> boolean
|
||||
func kv_delete(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.Lock()
|
||||
_, existed := store.data[key]
|
||||
delete(store.data, key)
|
||||
delete(store.expires, key)
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushBoolean(existed)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_clear(name) -> boolean
|
||||
func kv_clear(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.Lock()
|
||||
store.data = make(map[string]string)
|
||||
store.expires = make(map[string]int64)
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_has(name, key) -> boolean
|
||||
func kv_has(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.RLock()
|
||||
_, found := store.data[key]
|
||||
store.mutex.RUnlock()
|
||||
|
||||
s.PushBoolean(found)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_size(name) -> number
|
||||
func kv_size(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushNumber(0)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.RLock()
|
||||
size := len(store.data)
|
||||
store.mutex.RUnlock()
|
||||
|
||||
s.PushNumber(float64(size))
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_keys(name) -> table
|
||||
func kv_keys(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.NewTable()
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.RLock()
|
||||
keys := make([]string, 0, len(store.data))
|
||||
for k := range store.data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
store.mutex.RUnlock()
|
||||
|
||||
s.PushValue(keys)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_values(name) -> table
|
||||
func kv_values(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.NewTable()
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.RLock()
|
||||
values := make([]string, 0, len(store.data))
|
||||
for _, v := range store.data {
|
||||
values = append(values, v)
|
||||
}
|
||||
store.mutex.RUnlock()
|
||||
|
||||
s.PushValue(values)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_save(name) -> boolean
|
||||
func kv_save(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists || store.filename == "" {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
err := store.save()
|
||||
s.PushBoolean(err == nil)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_close(name) -> boolean
|
||||
func kv_close(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
store, exists := stores[name]
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
if store.filename != "" {
|
||||
store.save()
|
||||
}
|
||||
|
||||
delete(stores, name)
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_increment(name, key, delta) -> number
|
||||
func kv_increment(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
delta := 1.0
|
||||
if s.GetTop() >= 3 {
|
||||
delta = s.ToNumber(3)
|
||||
}
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushNumber(0)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.Lock()
|
||||
current, _ := strconv.ParseFloat(store.data[key], 64)
|
||||
newValue := current + delta
|
||||
store.data[key] = strconv.FormatFloat(newValue, 'g', -1, 64)
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushNumber(newValue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_append(name, key, value, separator) -> boolean
|
||||
func kv_append(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
value := s.ToString(3)
|
||||
separator := ""
|
||||
if s.GetTop() >= 4 {
|
||||
separator = s.ToString(4)
|
||||
}
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.Lock()
|
||||
current := store.data[key]
|
||||
if current == "" {
|
||||
store.data[key] = value
|
||||
} else {
|
||||
store.data[key] = current + separator + value
|
||||
}
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_expire(name, key, ttl) -> boolean
|
||||
func kv_expire(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
key := s.ToString(2)
|
||||
ttl := s.ToNumber(3)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
store.mutex.Lock()
|
||||
store.expires[key] = time.Now().Unix() + int64(ttl)
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// kv_cleanup_expired(name) -> number
|
||||
func kv_cleanup_expired(s *luajit.State) int {
|
||||
name := s.ToString(1)
|
||||
|
||||
mutex.RLock()
|
||||
store, exists := stores[name]
|
||||
mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
s.PushNumber(0)
|
||||
return 1
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
deleted := 0
|
||||
|
||||
store.mutex.Lock()
|
||||
for key, expireTime := range store.expires {
|
||||
if currentTime >= expireTime {
|
||||
delete(store.data, key)
|
||||
delete(store.expires, key)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
store.mutex.Unlock()
|
||||
|
||||
s.PushNumber(float64(deleted))
|
||||
return 1
|
||||
}
|
||||
|
||||
func (store *Store) load() error {
|
||||
if store.filename == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Open(store.filename)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if strings.HasSuffix(store.filename, ".json") {
|
||||
decoder := json.NewDecoder(file)
|
||||
return decoder.Decode(&store.data)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
store.data[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (store *Store) save() error {
|
||||
if store.filename == "" {
|
||||
return fmt.Errorf("no filename specified")
|
||||
}
|
||||
|
||||
file, err := os.Create(store.filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
store.mutex.RLock()
|
||||
defer store.mutex.RUnlock()
|
||||
|
||||
if strings.HasSuffix(store.filename, ".json") {
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", "\t")
|
||||
return encoder.Encode(store.data)
|
||||
}
|
||||
|
||||
for key, value := range store.data {
|
||||
if _, err := fmt.Fprintf(file, "%s=%s\n", key, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseAllStores saves and closes all open stores
|
||||
func CloseAllStores() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
for name, store := range stores {
|
||||
if store.filename != "" {
|
||||
store.save()
|
||||
}
|
||||
delete(stores, name)
|
||||
}
|
||||
}
|
||||
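These functions are registered for Lua via `GetFunctionList` and are callable through the `moonshark` table (the wrapper module in the next file adds type checks on top). The store name and file path below are illustrative; values are plain strings and TTLs are in seconds:

```lua
moonshark.kv_open("sessions", "data/sessions.json")
moonshark.kv_set("sessions", "user:42", "sky")
moonshark.kv_expire("sessions", "user:42", 3600)         -- marks the key to expire in one hour
print(moonshark.kv_get("sessions", "user:42", "anon"))   -- "sky", or the default if absent
print(moonshark.kv_increment("sessions", "hits", 1))     -- counter kept as a stringified number
moonshark.kv_cleanup_expired("sessions")                 -- removes keys whose TTL has passed
moonshark.kv_save("sessions")
```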
@ -1,196 +0,0 @@
|
||||
local kv = {}
|
||||
|
||||
-- ======================================================================
|
||||
-- BASIC KEY-VALUE OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
function kv.open(name, filename)
|
||||
if type(name) ~= "string" then error("kv.open: store name must be a string", 2) end
|
||||
if filename ~= nil and type(filename) ~= "string" then error("kv.open: filename must be a string", 2) end
|
||||
|
||||
filename = filename or ""
|
||||
return moonshark.kv_open(name, filename)
|
||||
end
|
||||
|
||||
function kv.get(name, key, default)
|
||||
if type(name) ~= "string" then error("kv.get: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.get: key must be a string", 2) end
|
||||
|
||||
if default ~= nil then
|
||||
return moonshark.kv_get(name, key, default)
|
||||
else
|
||||
return moonshark.kv_get(name, key)
|
||||
end
|
||||
end
|
||||
|
||||
function kv.set(name, key, value)
|
||||
if type(name) ~= "string" then error("kv.set: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.set: key must be a string", 2) end
|
||||
if type(value) ~= "string" then error("kv.set: value must be a string", 2) end
|
||||
|
||||
return moonshark.kv_set(name, key, value)
|
||||
end
|
||||
|
||||
function kv.delete(name, key)
|
||||
if type(name) ~= "string" then error("kv.delete: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.delete: key must be a string", 2) end
|
||||
|
||||
return moonshark.kv_delete(name, key)
|
||||
end
|
||||
|
||||
function kv.clear(name)
|
||||
if type(name) ~= "string" then error("kv.clear: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_clear(name)
|
||||
end
|
||||
|
||||
function kv.has(name, key)
|
||||
if type(name) ~= "string" then error("kv.has: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.has: key must be a string", 2) end
|
||||
|
||||
return moonshark.kv_has(name, key)
|
||||
end
|
||||
|
||||
function kv.size(name)
|
||||
if type(name) ~= "string" then error("kv.size: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_size(name)
|
||||
end
|
||||
|
||||
function kv.keys(name)
|
||||
if type(name) ~= "string" then error("kv.keys: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_keys(name)
|
||||
end
|
||||
|
||||
function kv.values(name)
|
||||
if type(name) ~= "string" then error("kv.values: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_values(name)
|
||||
end
|
||||
|
||||
function kv.save(name)
|
||||
if type(name) ~= "string" then error("kv.save: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_save(name)
|
||||
end
|
||||
|
||||
function kv.close(name)
|
||||
if type(name) ~= "string" then error("kv.close: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_close(name)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function kv.increment(name, key, delta)
|
||||
if type(name) ~= "string" then error("kv.increment: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.increment: key must be a string", 2) end
|
||||
delta = delta or 1
|
||||
if type(delta) ~= "number" then error("kv.increment: delta must be a number", 2) end
|
||||
|
||||
return moonshark.kv_increment(name, key, delta)
|
||||
end
|
||||
|
||||
function kv.append(name, key, value, separator)
|
||||
if type(name) ~= "string" then error("kv.append: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.append: key must be a string", 2) end
|
||||
if type(value) ~= "string" then error("kv.append: value must be a string", 2) end
|
||||
separator = separator or ""
|
||||
if type(separator) ~= "string" then error("kv.append: separator must be a string", 2) end
|
||||
|
||||
return moonshark.kv_append(name, key, value, separator)
|
||||
end
|
||||
|
||||
function kv.expire(name, key, ttl)
|
||||
if type(name) ~= "string" then error("kv.expire: store name must be a string", 2) end
|
||||
if type(key) ~= "string" then error("kv.expire: key must be a string", 2) end
|
||||
if type(ttl) ~= "number" or ttl <= 0 then error("kv.expire: TTL must be a positive number", 2) end
|
||||
|
||||
return moonshark.kv_expire(name, key, ttl)
|
||||
end
|
||||
|
||||
function kv.cleanup_expired(name)
|
||||
if type(name) ~= "string" then error("kv.cleanup_expired: store name must be a string", 2) end
|
||||
|
||||
return moonshark.kv_cleanup_expired(name)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- OBJECT-ORIENTED INTERFACE
|
||||
-- ======================================================================
|
||||
|
||||
local Store = {}
|
||||
Store.__index = Store
|
||||
|
||||
function kv.create(name, filename)
|
||||
if type(name) ~= "string" then error("kv.create: store name must be a string", 2) end
|
||||
if filename ~= nil and type(filename) ~= "string" then error("kv.create: filename must be a string", 2) end
|
||||
|
||||
local success = kv.open(name, filename)
|
||||
if not success then
|
||||
error("kv.create: failed to open store '" .. name .. "'", 2)
|
||||
end
|
||||
|
||||
return setmetatable({name = name}, Store)
|
||||
end
|
||||
|
||||
function Store:get(key, default)
|
||||
return kv.get(self.name, key, default)
|
||||
end
|
||||
|
||||
function Store:set(key, value)
|
||||
return kv.set(self.name, key, value)
|
||||
end
|
||||
|
||||
function Store:delete(key)
|
||||
return kv.delete(self.name, key)
|
||||
end
|
||||
|
||||
function Store:clear()
|
||||
return kv.clear(self.name)
|
||||
end
|
||||
|
||||
function Store:has(key)
|
||||
return kv.has(self.name, key)
|
||||
end
|
||||
|
||||
function Store:size()
|
||||
return kv.size(self.name)
|
||||
end
|
||||
|
||||
function Store:keys()
|
||||
return kv.keys(self.name)
|
||||
end
|
||||
|
||||
function Store:values()
|
||||
return kv.values(self.name)
|
||||
end
|
||||
|
||||
function Store:save()
|
||||
return kv.save(self.name)
|
||||
end
|
||||
|
||||
function Store:close()
|
||||
return kv.close(self.name)
|
||||
end
|
||||
|
||||
function Store:increment(key, delta)
|
||||
return kv.increment(self.name, key, delta)
|
||||
end
|
||||
|
||||
function Store:append(key, value, separator)
|
||||
return kv.append(self.name, key, value, separator)
|
||||
end
|
||||
|
||||
function Store:expire(key, ttl)
|
||||
return kv.expire(self.name, key, ttl)
|
||||
end
|
||||
|
||||
function Store:cleanup_expired()
|
||||
return kv.cleanup_expired(self.name)
|
||||
end
|
||||
|
||||
return kv
|
||||
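And the object-oriented interface from the same module in use, assuming it is loadable as `kv` from the configured libs path (store name and filename are illustrative):

```lua
local kv = require("kv")

local store = kv.create("cache", "data/cache.json")
store:set("greeting", "hello")
print(store:get("greeting"))        --> hello
print(store:get("missing", "n/a"))  --> n/a
store:increment("visits", 1)
store:expire("greeting", 60)        -- seconds
store:save()
store:close()
```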
@ -1,732 +0,0 @@
|
||||
-- ======================================================================
|
||||
-- ENHANCED CONSTANTS (higher precision)
|
||||
-- ======================================================================
|
||||
|
||||
math.pi = 3.14159265358979323846 -- Replace with higher precision
|
||||
math.tau = 6.28318530717958647693 -- 2*pi, useful for full rotations
|
||||
math.e = 2.71828182845904523536
|
||||
math.phi = 1.61803398874989484820 -- Golden ratio (1 + sqrt(5)) / 2
|
||||
math.sqrt2 = 1.41421356237309504880
|
||||
math.sqrt3 = 1.73205080756887729353
|
||||
math.ln2 = 0.69314718055994530942 -- Natural log of 2
|
||||
math.ln10 = 2.30258509299404568402 -- Natural log of 10
|
||||
math.infinity = 1/0
|
||||
math.nan = 0/0
|
||||
|
||||
-- ======================================================================
|
||||
-- EXTENDED FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Cube root that handles negative numbers correctly
|
||||
function math.cbrt(x)
|
||||
return x < 0 and -(-x)^(1/3) or x^(1/3)
|
||||
end
|
||||
|
||||
-- Euclidean distance of (x, y) from the origin
|
||||
function math.hypot(x, y)
|
||||
return math.sqrt(x * x + y * y)
|
||||
end
|
||||
|
||||
-- IEEE 754 NaN check
|
||||
function math.isnan(x)
|
||||
return x ~= x
|
||||
end
|
||||
|
||||
-- Check if number is finite
|
||||
function math.isfinite(x)
|
||||
return x > -math.infinity and x < math.infinity
|
||||
end
|
||||
|
||||
-- Mathematical sign function
|
||||
function math.sign(x)
|
||||
return x > 0 and 1 or (x < 0 and -1 or 0)
|
||||
end
|
||||
|
||||
-- Constrain value to range
|
||||
function math.clamp(x, min, max)
|
||||
return x < min and min or (x > max and max or x)
|
||||
end
|
||||
|
||||
-- Linear interpolation
|
||||
function math.lerp(a, b, t)
|
||||
return a + (b - a) * t
|
||||
end
|
||||
|
||||
-- Smooth interpolation using Hermite polynomial
|
||||
function math.smoothstep(a, b, t)
|
||||
t = math.clamp((t - a) / (b - a), 0, 1)
|
||||
return t * t * (3 - 2 * t)
|
||||
end
|
||||
|
||||
-- Map value from input range to output range
|
||||
function math.map(x, in_min, in_max, out_min, out_max)
|
||||
return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min
|
||||
end
|
||||
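A few quick calls showing the range helpers above:

```lua
print(math.clamp(12, 0, 10))         --> 10
print(math.lerp(0, 100, 0.25))       --> 25
print(math.map(0.5, 0, 1, -10, 10))  --> 0
print(math.map(75, 0, 255, 0, 1))    --> ~0.294 (rescales a byte to 0..1)
```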
|
||||
-- Round to nearest integer
|
||||
function math.round(x)
|
||||
return x >= 0 and math.floor(x + 0.5) or math.ceil(x - 0.5)
|
||||
end
|
||||
|
||||
-- Round to specified decimal places
|
||||
function math.roundto(x, decimals)
|
||||
local mult = 10 ^ (decimals or 0)
|
||||
return math.floor(x * mult + 0.5) / mult
|
||||
end
|
||||
|
||||
-- Normalize angle to [-π, π] range
|
||||
function math.normalize_angle(angle)
|
||||
return angle - 2 * math.pi * math.floor((angle + math.pi) / (2 * math.pi))
|
||||
end
|
||||
|
||||
-- 2D Euclidean distance
|
||||
function math.distance(x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
end
|
||||
|
||||
-- Factorial with bounds checking
|
||||
function math.factorial(n)
|
||||
if n < 0 or n ~= math.floor(n) or n > 170 then
|
||||
return nil
|
||||
end
|
||||
local result = 1
|
||||
for i = 2, n do
|
||||
result = result * i
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Greatest common divisor using Euclidean algorithm
|
||||
function math.gcd(a, b)
|
||||
a, b = math.floor(math.abs(a)), math.floor(math.abs(b))
|
||||
while b ~= 0 do
|
||||
a, b = b, a % b
|
||||
end
|
||||
return a
|
||||
end
|
||||
|
||||
-- Least common multiple
|
||||
function math.lcm(a, b)
|
||||
if a == 0 or b == 0 then return 0 end
|
||||
return math.abs(a * b) / math.gcd(a, b)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- ENHANCED RANDOM FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Random float in range
|
||||
function math.randomf(min, max)
|
||||
if not min and not max then
|
||||
return math.random()
|
||||
elseif not max then
|
||||
max = min
|
||||
min = 0
|
||||
end
|
||||
return min + math.random() * (max - min)
|
||||
end
|
||||
|
||||
-- Random integer in range
|
||||
function math.randint(min, max)
|
||||
if not max then
|
||||
max = min
|
||||
min = 1
|
||||
end
|
||||
return math.floor(math.random() * (max - min + 1) + min)
|
||||
end
|
||||
|
||||
-- Random boolean with probability
|
||||
function math.randboolean(p)
|
||||
p = p or 0.5
|
||||
return math.random() < p
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- STATISTICS FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function math.sum(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
local sum = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
sum = sum + t[i]
|
||||
end
|
||||
end
|
||||
return sum
|
||||
end
|
||||
|
||||
function math.mean(t)
|
||||
if type(t) ~= "table" or #t == 0 then return 0 end
|
||||
local sum = 0
|
||||
local count = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
sum = sum + t[i]
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
return count > 0 and sum / count or 0
|
||||
end
|
||||
|
||||
function math.median(t)
|
||||
if type(t) ~= "table" or #t == 0 then return 0 end
|
||||
local nums = {}
|
||||
local count = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
count = count + 1
|
||||
nums[count] = t[i]
|
||||
end
|
||||
end
|
||||
if count == 0 then return 0 end
|
||||
table.sort(nums)
|
||||
if count % 2 == 0 then
|
||||
return (nums[count/2] + nums[count/2 + 1]) / 2
|
||||
else
|
||||
return nums[math.ceil(count/2)]
|
||||
end
|
||||
end
|
||||
|
||||
function math.variance(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
local count = 0
|
||||
local m = math.mean(t)
|
||||
local sum = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
local dev = t[i] - m
|
||||
sum = sum + dev * dev
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
return count > 1 and sum / (count - 1) or 0 -- sample variance (n - 1); pvariance below divides by n
|
||||
end
|
||||
|
||||
function math.stdev(t)
|
||||
return math.sqrt(math.variance(t))
|
||||
end
|
||||
|
||||
function math.pvariance(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
local count = 0
|
||||
local m = math.mean(t)
|
||||
local sum = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
local dev = t[i] - m
|
||||
sum = sum + dev * dev
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
return count > 0 and sum / count or 0
|
||||
end
|
||||
|
||||
function math.pstdev(t)
|
||||
return math.sqrt(math.pvariance(t))
|
||||
end
|
||||
|
||||
function math.mode(t)
|
||||
if type(t) ~= "table" or #t == 0 then return nil end
|
||||
local counts = {}
|
||||
local most_frequent = nil
|
||||
local max_count = 0
|
||||
for i=1, #t do
|
||||
local v = t[i]
|
||||
counts[v] = (counts[v] or 0) + 1
|
||||
if counts[v] > max_count then
|
||||
max_count = counts[v]
|
||||
most_frequent = v
|
||||
end
|
||||
end
|
||||
return most_frequent
|
||||
end
|
||||
|
||||
function math.minmax(t)
|
||||
if type(t) ~= "table" or #t == 0 then return nil, nil end
|
||||
local min, max
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
min = t[i]
|
||||
max = t[i]
|
||||
break
|
||||
end
|
||||
end
|
||||
if min == nil then return nil, nil end
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
if t[i] < min then min = t[i] end
|
||||
if t[i] > max then max = t[i] end
|
||||
end
|
||||
end
|
||||
return min, max
|
||||
end
|
||||
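The statistics helpers operate on array-style tables and skip non-numeric entries, e.g.:

```lua
local samples = { 4, 8, 15, 16, 23, 42 }
print(math.sum(samples))     --> 108
print(math.mean(samples))    --> 18
print(math.median(samples))  --> 15.5
print(math.stdev(samples))   -- spread of the samples (see variance above)
print(math.minmax(samples))  --> 4   42
```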
|
||||
-- ======================================================================
|
||||
-- 2D VECTOR OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
math.vec2 = {
|
||||
new = function(x, y)
|
||||
return {x = x or 0, y = y or 0}
|
||||
end,
|
||||
|
||||
copy = function(v)
|
||||
return {x = v.x, y = v.y}
|
||||
end,
|
||||
|
||||
add = function(a, b)
|
||||
return {x = a.x + b.x, y = a.y + b.y}
|
||||
end,
|
||||
|
||||
sub = function(a, b)
|
||||
return {x = a.x - b.x, y = a.y - b.y}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
if type(b) == "number" then
|
||||
return {x = a.x * b, y = a.y * b}
|
||||
end
|
||||
return {x = a.x * b.x, y = a.y * b.y}
|
||||
end,
|
||||
|
||||
div = function(a, b)
|
||||
if type(b) == "number" then
|
||||
local inv = 1 / b
|
||||
return {x = a.x * inv, y = a.y * inv}
|
||||
end
|
||||
return {x = a.x / b.x, y = a.y / b.y}
|
||||
end,
|
||||
|
||||
dot = function(a, b)
|
||||
return a.x * b.x + a.y * b.y
|
||||
end,
|
||||
|
||||
length = function(v)
|
||||
return math.sqrt(v.x * v.x + v.y * v.y)
|
||||
end,
|
||||
|
||||
length_squared = function(v)
|
||||
return v.x * v.x + v.y * v.y
|
||||
end,
|
||||
|
||||
distance = function(a, b)
|
||||
local dx, dy = b.x - a.x, b.y - a.y
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
end,
|
||||
|
||||
distance_squared = function(a, b)
|
||||
local dx, dy = b.x - a.x, b.y - a.y
|
||||
return dx * dx + dy * dy
|
||||
end,
|
||||
|
||||
normalize = function(v)
|
||||
local len = math.sqrt(v.x * v.x + v.y * v.y)
|
||||
if len > 1e-10 then
|
||||
local inv_len = 1 / len
|
||||
return {x = v.x * inv_len, y = v.y * inv_len}
|
||||
end
|
||||
return {x = 0, y = 0}
|
||||
end,
|
||||
|
||||
rotate = function(v, angle)
|
||||
local c, s = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
x = v.x * c - v.y * s,
|
||||
y = v.x * s + v.y * c
|
||||
}
|
||||
end,
|
||||
|
||||
angle = function(v)
|
||||
return math.atan2(v.y, v.x)
|
||||
end,
|
||||
|
||||
lerp = function(a, b, t)
|
||||
t = math.clamp(t, 0, 1)
|
||||
return {
|
||||
x = a.x + (b.x - a.x) * t,
|
||||
y = a.y + (b.y - a.y) * t
|
||||
}
|
||||
end,
|
||||
|
||||
reflect = function(v, normal)
|
||||
local dot = v.x * normal.x + v.y * normal.y
|
||||
return {
|
||||
x = v.x - 2 * dot * normal.x,
|
||||
y = v.y - 2 * dot * normal.y
|
||||
}
|
||||
end
|
||||
}
|
||||
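Vectors are plain `{x, y}` tables and angles are in radians; for instance:

```lua
local v2 = math.vec2
local a = v2.new(3, 4)
print(v2.length(a))                     --> 5
local unit = v2.normalize(a)            -- {x = 0.6, y = 0.8}
local spun = v2.rotate(a, math.pi / 2)  -- roughly {x = -4, y = 3}
print(v2.dot(a, v2.new(1, 0)))          --> 3
print(v2.distance(a, v2.new(0, 0)))     --> 5
```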
|
||||
-- ======================================================================
|
||||
-- 3D VECTOR OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
math.vec3 = {
|
||||
new = function(x, y, z)
|
||||
return {x = x or 0, y = y or 0, z = z or 0}
|
||||
end,
|
||||
|
||||
copy = function(v)
|
||||
return {x = v.x, y = v.y, z = v.z}
|
||||
end,
|
||||
|
||||
add = function(a, b)
|
||||
return {x = a.x + b.x, y = a.y + b.y, z = a.z + b.z}
|
||||
end,
|
||||
|
||||
sub = function(a, b)
|
||||
return {x = a.x - b.x, y = a.y - b.y, z = a.z - b.z}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
if type(b) == "number" then
|
||||
return {x = a.x * b, y = a.y * b, z = a.z * b}
|
||||
end
|
||||
return {x = a.x * b.x, y = a.y * b.y, z = a.z * b.z}
|
||||
end,
|
||||
|
||||
div = function(a, b)
|
||||
if type(b) == "number" then
|
||||
local inv = 1 / b
|
||||
return {x = a.x * inv, y = a.y * inv, z = a.z * inv}
|
||||
end
|
||||
return {x = a.x / b.x, y = a.y / b.y, z = a.z / b.z}
|
||||
end,
|
||||
|
||||
dot = function(a, b)
|
||||
return a.x * b.x + a.y * b.y + a.z * b.z
|
||||
end,
|
||||
|
||||
cross = function(a, b)
|
||||
return {
|
||||
x = a.y * b.z - a.z * b.y,
|
||||
y = a.z * b.x - a.x * b.z,
|
||||
z = a.x * b.y - a.y * b.x
|
||||
}
|
||||
end,
|
||||
|
||||
length = function(v)
|
||||
return math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
|
||||
end,
|
||||
|
||||
length_squared = function(v)
|
||||
return v.x * v.x + v.y * v.y + v.z * v.z
|
||||
end,
|
||||
|
||||
distance = function(a, b)
|
||||
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
|
||||
return math.sqrt(dx * dx + dy * dy + dz * dz)
|
||||
end,
|
||||
|
||||
distance_squared = function(a, b)
|
||||
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
|
||||
return dx * dx + dy * dy + dz * dz
|
||||
end,
|
||||
|
||||
normalize = function(v)
|
||||
local len = math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
|
||||
if len > 1e-10 then
|
||||
local inv_len = 1 / len
|
||||
return {x = v.x * inv_len, y = v.y * inv_len, z = v.z * inv_len}
|
||||
end
|
||||
return {x = 0, y = 0, z = 0}
|
||||
end,
|
||||
|
||||
lerp = function(a, b, t)
|
||||
t = math.clamp(t, 0, 1)
|
||||
return {
|
||||
x = a.x + (b.x - a.x) * t,
|
||||
y = a.y + (b.y - a.y) * t,
|
||||
z = a.z + (b.z - a.z) * t
|
||||
}
|
||||
end,
|
||||
|
||||
reflect = function(v, normal)
|
||||
local dot = v.x * normal.x + v.y * normal.y + v.z * normal.z
|
||||
return {
|
||||
x = v.x - 2 * dot * normal.x,
|
||||
y = v.y - 2 * dot * normal.y,
|
||||
z = v.z - 2 * dot * normal.z
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- MATRIX OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
math.mat2 = {
|
||||
new = function(a, b, c, d)
|
||||
return {
|
||||
{a or 1, b or 0},
|
||||
{c or 0, d or 1}
|
||||
}
|
||||
end,
|
||||
|
||||
identity = function()
|
||||
return {{1, 0}, {0, 1}}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
return {
|
||||
{
|
||||
a[1][1] * b[1][1] + a[1][2] * b[2][1],
|
||||
a[1][1] * b[1][2] + a[1][2] * b[2][2]
|
||||
},
|
||||
{
|
||||
a[2][1] * b[1][1] + a[2][2] * b[2][1],
|
||||
a[2][1] * b[1][2] + a[2][2] * b[2][2]
|
||||
}
|
||||
}
|
||||
end,
|
||||
|
||||
det = function(m)
|
||||
return m[1][1] * m[2][2] - m[1][2] * m[2][1]
|
||||
end,
|
||||
|
||||
inverse = function(m)
|
||||
local det = m[1][1] * m[2][2] - m[1][2] * m[2][1]
|
||||
if math.abs(det) < 1e-10 then
|
||||
return nil
|
||||
end
|
||||
local inv_det = 1 / det
|
||||
return {
|
||||
{m[2][2] * inv_det, -m[1][2] * inv_det},
|
||||
{-m[2][1] * inv_det, m[1][1] * inv_det}
|
||||
}
|
||||
end,
|
||||
|
||||
rotation = function(angle)
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos, -sin},
|
||||
{sin, cos}
|
||||
}
|
||||
end,
|
||||
|
||||
transform = function(m, v)
|
||||
return {
|
||||
x = m[1][1] * v.x + m[1][2] * v.y,
|
||||
y = m[2][1] * v.x + m[2][2] * v.y
|
||||
}
|
||||
end,
|
||||
|
||||
scale = function(sx, sy)
|
||||
sy = sy or sx
|
||||
return {
|
||||
{sx, 0},
|
||||
{0, sy}
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
math.mat3 = {
|
||||
identity = function()
|
||||
return {
|
||||
{1, 0, 0},
|
||||
{0, 1, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
transform = function(x, y, angle, sx, sy)
|
||||
sx = sx or 1
|
||||
sy = sy or sx
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos * sx, -sin * sy, x},
|
||||
{sin * sx, cos * sy, y},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
local result = {
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}
|
||||
}
|
||||
for i = 1, 3 do
|
||||
for j = 1, 3 do
|
||||
for k = 1, 3 do
|
||||
result[i][j] = result[i][j] + a[i][k] * b[k][j]
|
||||
end
|
||||
end
|
||||
end
|
||||
return result
|
||||
end,
|
||||
|
||||
transform_point = function(m, v)
|
||||
local x = m[1][1] * v.x + m[1][2] * v.y + m[1][3]
|
||||
local y = m[2][1] * v.x + m[2][2] * v.y + m[2][3]
|
||||
local w = m[3][1] * v.x + m[3][2] * v.y + m[3][3]
|
||||
if math.abs(w) < 1e-10 then
|
||||
return {x = 0, y = 0}
|
||||
end
|
||||
return {x = x / w, y = y / w}
|
||||
end,
|
||||
|
||||
translation = function(x, y)
|
||||
return {
|
||||
{1, 0, x},
|
||||
{0, 1, y},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
rotation = function(angle)
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos, -sin, 0},
|
||||
{sin, cos, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
scale = function(sx, sy)
|
||||
sy = sy or sx
|
||||
return {
|
||||
{sx, 0, 0},
|
||||
{0, sy, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
det = function(m)
|
||||
return m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2]) -
|
||||
m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1]) +
|
||||
m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1])
|
||||
end
|
||||
}
|
||||
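The 3x3 matrices are row-major nested tables describing 2D affine transforms; `math.mat3.transform(x, y, angle, sx, sy)` builds a combined matrix in one call, and the same result can be composed by hand:

```lua
local m3 = math.mat3
-- Rotate 90 degrees, then translate by (10, 5); equivalent to m3.transform(10, 5, math.pi / 2).
local t = m3.mul(m3.translation(10, 5), m3.rotation(math.pi / 2))
local p = m3.transform_point(t, { x = 1, y = 0 })
print(p.x, p.y)  --> roughly 10, 6
```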
|
||||
-- ======================================================================
|
||||
-- GEOMETRY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
math.geometry = {
|
||||
point_line_distance = function(px, py, x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
local len_sq = dx * dx + dy * dy
|
||||
if len_sq < 1e-10 then
|
||||
return math.distance(px, py, x1, y1)
|
||||
end
|
||||
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
|
||||
t = math.clamp(t, 0, 1)
|
||||
local nearestX = x1 + t * dx
|
||||
local nearestY = y1 + t * dy
|
||||
return math.distance(px, py, nearestX, nearestY)
|
||||
end,
|
||||
|
||||
point_in_polygon = function(px, py, vertices)
|
||||
local inside = false
|
||||
local n = #vertices / 2
|
||||
for i = 1, n do
|
||||
local x1, y1 = vertices[i*2-1], vertices[i*2]
|
||||
local x2, y2
|
||||
if i == n then
|
||||
x2, y2 = vertices[1], vertices[2]
|
||||
else
|
||||
x2, y2 = vertices[i*2+1], vertices[i*2+2]
|
||||
end
|
||||
if ((y1 > py) ~= (y2 > py)) and
|
||||
(px < (x2 - x1) * (py - y1) / (y2 - y1) + x1) then
|
||||
inside = not inside
|
||||
end
|
||||
end
|
||||
return inside
|
||||
end,
|
||||
|
||||
triangle_area = function(x1, y1, x2, y2, x3, y3)
|
||||
return math.abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2)
|
||||
end,
|
||||
|
||||
point_in_triangle = function(px, py, x1, y1, x2, y2, x3, y3)
|
||||
local area = math.geometry.triangle_area(x1, y1, x2, y2, x3, y3)
|
||||
local area1 = math.geometry.triangle_area(px, py, x2, y2, x3, y3)
|
||||
local area2 = math.geometry.triangle_area(x1, y1, px, py, x3, y3)
|
||||
local area3 = math.geometry.triangle_area(x1, y1, x2, y2, px, py)
|
||||
return math.abs(area - (area1 + area2 + area3)) < 1e-10
|
||||
end,
|
||||
|
||||
line_intersect = function(x1, y1, x2, y2, x3, y3, x4, y4)
|
||||
local d = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
|
||||
if math.abs(d) < 1e-10 then
|
||||
return false, nil, nil
|
||||
end
|
||||
local ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / d
|
||||
local ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / d
|
||||
if ua >= 0 and ua <= 1 and ub >= 0 and ub <= 1 then
|
||||
local x = x1 + ua * (x2 - x1)
|
||||
local y = y1 + ua * (y2 - y1)
|
||||
return true, x, y
|
||||
end
|
||||
return false, nil, nil
|
||||
end,
|
||||
|
||||
closest_point_on_segment = function(px, py, x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
local len_sq = dx * dx + dy * dy
|
||||
if len_sq < 1e-10 then
|
||||
return x1, y1
|
||||
end
|
||||
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
|
||||
t = math.clamp(t, 0, 1)
|
||||
return x1 + t * dx, y1 + t * dy
|
||||
end
|
||||
}
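-- Illustrative usage sketch (not part of the original file): point-in-polygon
-- and segment intersection with made-up coordinates. Vertices are passed as a
-- flat {x1, y1, x2, y2, ...} list, matching point_in_polygon above.
-- local square = {0, 0, 10, 0, 10, 10, 0, 10}
-- math.geometry.point_in_polygon(5, 5, square)       --> true
-- math.geometry.point_in_polygon(15, 5, square)      --> false
-- local hit, x, y = math.geometry.line_intersect(0, 0, 10, 10, 0, 10, 10, 0)
-- hit is true; x and y are both 5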

-- ======================================================================
-- INTERPOLATION FUNCTIONS
-- ======================================================================

math.interpolation = {
    bezier = function(t, p0, p1, p2, p3)
        t = math.clamp(t, 0, 1)
        local t2 = t * t
        local t3 = t2 * t
        local mt = 1 - t
        local mt2 = mt * mt
        local mt3 = mt2 * mt
        return p0 * mt3 + 3 * p1 * mt2 * t + 3 * p2 * mt * t2 + p3 * t3
    end,

    catmull_rom = function(t, p0, p1, p2, p3)
        t = math.clamp(t, 0, 1)
        local t2 = t * t
        local t3 = t2 * t
        return 0.5 * (
            (2 * p1) +
            (-p0 + p2) * t +
            (2 * p0 - 5 * p1 + 4 * p2 - p3) * t2 +
            (-p0 + 3 * p1 - 3 * p2 + p3) * t3
        )
    end,

    hermite = function(t, p0, p1, m0, m1)
        t = math.clamp(t, 0, 1)
        local t2 = t * t
        local t3 = t2 * t
        local h00 = 2 * t3 - 3 * t2 + 1
        local h10 = t3 - 2 * t2 + t
        local h01 = -2 * t3 + 3 * t2
        local h11 = t3 - t2
        return h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
    end,

    quadratic_bezier = function(t, p0, p1, p2)
        t = math.clamp(t, 0, 1)
        local mt = 1 - t
        return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2
    end,

    step = function(t, edge, x)
        return t < edge and 0 or x
    end,

    smootherstep = function(edge0, edge1, x)
        local t = math.clamp((x - edge0) / (edge1 - edge0), 0, 1)
        return t * t * t * (t * (t * 6 - 15) + 10)
    end
}
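-- Illustrative usage sketch (not part of the original file): sampling a cubic
-- Bezier curve and smootherstep with made-up scalar control values.
-- for i = 0, 10 do
--     local t = i / 10
--     local v = math.interpolation.bezier(t, 0, 0.2, 0.8, 1)
--     -- v eases from 0 at t = 0 to 1 at t = 1
-- end
-- math.interpolation.smootherstep(0, 100, 50)   --> 0.5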
@ -1,814 +0,0 @@
|
||||
local mysql = {}
|
||||
|
||||
local Connection = {}
|
||||
Connection.__index = Connection
|
||||
|
||||
function Connection:close()
|
||||
if self._id then
|
||||
local ok = moonshark.sql_close(self._id)
|
||||
self._id = nil
|
||||
return ok
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:ping()
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_ping(self._id)
|
||||
end
|
||||
|
||||
function Connection:query(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_query(self._id, query_str:normalize_whitespace(), ...)
|
||||
end
|
||||
|
||||
function Connection:exec(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_exec(self._id, query_str:normalize_whitespace(), ...)
|
||||
end
|
||||
|
||||
function Connection:query_row(query_str, ...)
|
||||
local results = self:query(query_str, ...)
|
||||
return results and #results > 0 and results[1] or nil
|
||||
end
|
||||
|
||||
function Connection:query_value(query_str, ...)
|
||||
local row = self:query_row(query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function Connection:begin()
|
||||
local result = self:exec("BEGIN")
|
||||
if result then
|
||||
return {
|
||||
conn = self,
|
||||
active = true,
|
||||
commit = function(tx)
|
||||
if tx.active then
|
||||
tx.active = false
|
||||
return tx.conn:exec("COMMIT")
|
||||
end
|
||||
return false
|
||||
end,
|
||||
rollback = function(tx)
|
||||
if tx.active then
|
||||
tx.active = false
|
||||
return tx.conn:exec("ROLLBACK")
|
||||
end
|
||||
return false
|
||||
end,
|
||||
savepoint = function(tx, name)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
if name:is_blank() then error("Savepoint name cannot be empty") end
|
||||
return tx.conn:exec("SAVEPOINT {{name}}":parse({name = name}))
|
||||
end,
|
||||
rollback_to = function(tx, name)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
if name:is_blank() then error("Savepoint name cannot be empty") end
|
||||
return tx.conn:exec("ROLLBACK TO SAVEPOINT {{name}}":parse({name = name}))
|
||||
end,
|
||||
query = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query(query_str, ...)
|
||||
end,
|
||||
exec = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:exec(query_str, ...)
|
||||
end,
|
||||
query_row = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query_row(query_str, ...)
|
||||
end,
|
||||
query_value = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query_value(query_str, ...)
|
||||
end
|
||||
}
|
||||
end
|
||||
return nil
|
||||
end
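-- Illustrative usage sketch (not part of the original file): running two
-- statements atomically with the transaction object returned by begin().
-- The DSN and table are hypothetical.
-- local conn = mysql.connect("user:pass@tcp(127.0.0.1:3306)/app")
-- local tx = conn:begin()
-- if tx then
--     tx:exec("INSERT INTO accounts (name) VALUES (?)", "alice")
--     tx:exec("INSERT INTO accounts (name) VALUES (?)", "bob")
--     tx:commit()
-- end
-- conn:close()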
|
||||
|
||||
function Connection:insert(table_name, data)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
|
||||
|
||||
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = placeholders
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:upsert(table_name, data, update_data)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
|
||||
|
||||
-- Use update_data if provided, otherwise update with same data
|
||||
local update_source = update_data or data
|
||||
local updates = table.map(table.keys(update_source), function(key)
|
||||
return key .. " = VALUES(" .. key .. ")"
|
||||
end)
|
||||
|
||||
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}}) ON DUPLICATE KEY UPDATE {{updates}}":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = placeholders,
|
||||
updates = updates:join(", ")
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:replace(table_name, data)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
|
||||
|
||||
local query = "REPLACE INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = placeholders
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:update(table_name, data, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if where_clause:is_blank() then
|
||||
error("WHERE clause cannot be empty for UPDATE")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local sets = table.map(keys, function(key) return key .. " = ?" end)
|
||||
|
||||
local query = "UPDATE {{table}} SET {{sets}} WHERE {{where}}":parse({
|
||||
table = table_name,
|
||||
sets = sets:join(", "),
|
||||
where = where_clause
|
||||
})
|
||||
|
||||
table.extend(values, {...})
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:delete(table_name, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if where_clause:is_blank() then
|
||||
error("WHERE clause cannot be empty for DELETE")
|
||||
end
|
||||
|
||||
local query = "DELETE FROM {{table}} WHERE {{where}}":parse({
|
||||
table = table_name,
|
||||
where = where_clause
|
||||
})
|
||||
return self:exec(query, ...)
|
||||
end
|
||||
|
||||
function Connection:select(table_name, columns, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
columns = columns or "*"
|
||||
if type(columns) == "table" then
|
||||
columns = table.concat(columns, ", ")
|
||||
end
|
||||
|
||||
if where_clause and not where_clause:is_blank() then
|
||||
local query = "SELECT {{columns}} FROM {{table}} WHERE {{where}}":parse({
|
||||
columns = columns,
|
||||
table = table_name,
|
||||
where = where_clause
|
||||
})
|
||||
return self:query(query, ...)
|
||||
else
|
||||
local query = "SELECT {{columns}} FROM {{table}}":parse({
|
||||
columns = columns,
|
||||
table = table_name
|
||||
})
|
||||
return self:query(query)
|
||||
end
|
||||
end
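-- Illustrative usage sketch (not part of the original file): the basic CRUD
-- helpers above. Table, columns, and DSN are hypothetical.
-- local conn = mysql.connect("user:pass@tcp(127.0.0.1:3306)/app")
-- conn:insert("users", {name = "Ada", email = "ada@example.com"})
-- local rows = conn:select("users", {"id", "name"}, "email = ?", "ada@example.com")
-- conn:update("users", {name = "Ada Lovelace"}, "id = ?", rows[1].id)
-- conn:delete("users", "id = ?", rows[1].id)
-- conn:close()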
|
||||
|
||||
-- MySQL schema helpers
|
||||
function Connection:database_exists(database_name)
|
||||
if database_name:is_blank() then return false end
|
||||
return self:query_value("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?",
|
||||
database_name:trim()) ~= nil
|
||||
end
|
||||
|
||||
function Connection:table_exists(table_name, database_name)
|
||||
if table_name:is_blank() then return false end
|
||||
database_name = database_name or self:current_database()
|
||||
if not database_name then return false end
|
||||
|
||||
return self:query_value("SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?",
|
||||
database_name:trim(), table_name:trim()) ~= nil
|
||||
end
|
||||
|
||||
function Connection:column_exists(table_name, column_name, database_name)
|
||||
if table_name:is_blank() or column_name:is_blank() then return false end
|
||||
database_name = database_name or self:current_database()
|
||||
if not database_name then return false end
|
||||
|
||||
return self:query_value([[
|
||||
SELECT COLUMN_NAME FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?
|
||||
]], database_name:trim(), table_name:trim(), column_name:trim()) ~= nil
|
||||
end
|
||||
|
||||
function Connection:create_database(database_name, charset, collation)
|
||||
if database_name:is_blank() then
|
||||
error("Database name cannot be empty")
|
||||
end
|
||||
|
||||
local charset_clause = charset and " CHARACTER SET " .. charset or ""
|
||||
local collation_clause = collation and " COLLATE " .. collation or ""
|
||||
|
||||
return self:exec("CREATE DATABASE IF NOT EXISTS {{database}}{{charset}}{{collation}}":parse({
|
||||
database = database_name,
|
||||
charset = charset_clause,
|
||||
collation = collation_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_database(database_name)
|
||||
if database_name:is_blank() then
|
||||
error("Database name cannot be empty")
|
||||
end
|
||||
return self:exec("DROP DATABASE IF EXISTS {{database}}":parse({database = database_name}))
|
||||
end
|
||||
|
||||
function Connection:create_table(table_name, schema, engine, charset)
|
||||
if table_name:is_blank() or schema:is_blank() then
|
||||
error("Table name and schema cannot be empty")
|
||||
end
|
||||
|
||||
local engine_clause = engine and " ENGINE=" .. engine:upper() or ""
|
||||
local charset_clause = charset and " CHARACTER SET " .. charset or ""
|
||||
|
||||
return self:exec("CREATE TABLE IF NOT EXISTS {{table}} ({{schema}}){{engine}}{{charset}}":parse({
|
||||
table = table_name,
|
||||
schema = schema:trim(),
|
||||
engine = engine_clause,
|
||||
charset = charset_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_table(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
return self:exec("DROP TABLE IF EXISTS {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
function Connection:add_column(table_name, column_def, position)
|
||||
if table_name:is_blank() or column_def:is_blank() then
|
||||
error("Table name and column definition cannot be empty")
|
||||
end
|
||||
|
||||
local position_clause = position and " " .. position or ""
|
||||
return self:exec("ALTER TABLE {{table}} ADD COLUMN {{column}}{{position}}":parse({
|
||||
table = table_name,
|
||||
column = column_def:trim(),
|
||||
position = position_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_column(table_name, column_name)
|
||||
if table_name:is_blank() or column_name:is_blank() then
|
||||
error("Table name and column name cannot be empty")
|
||||
end
|
||||
return self:exec("ALTER TABLE {{table}} DROP COLUMN {{column}}":parse({
|
||||
table = table_name,
|
||||
column = column_name
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:modify_column(table_name, column_def)
|
||||
if table_name:is_blank() or column_def:is_blank() then
|
||||
error("Table name and column definition cannot be empty")
|
||||
end
|
||||
return self:exec("ALTER TABLE {{table}} MODIFY COLUMN {{column}}":parse({
|
||||
table = table_name,
|
||||
column = column_def:trim()
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:rename_table(old_name, new_name)
|
||||
if old_name:is_blank() or new_name:is_blank() then
|
||||
error("Old and new table names cannot be empty")
|
||||
end
|
||||
return self:exec("RENAME TABLE {{old}} TO {{new}}":parse({old = old_name, new = new_name}))
|
||||
end
|
||||
|
||||
function Connection:create_index(index_name, table_name, columns, unique, index_type)
|
||||
if index_name:is_blank() or table_name:is_blank() then
|
||||
error("Index name and table name cannot be empty")
|
||||
end
|
||||
|
||||
local unique_clause = unique and "UNIQUE " or ""
|
||||
local type_clause = index_type and " USING " .. index_type:upper() or ""
|
||||
local columns_str = type(columns) == "table" and table.concat(columns, ", ") or tostring(columns)
|
||||
|
||||
return self:exec("CREATE {{unique}}INDEX {{index}} ON {{table}} ({{columns}}){{type}}":parse({
|
||||
unique = unique_clause,
|
||||
index = index_name,
|
||||
table = table_name,
|
||||
columns = columns_str,
|
||||
type = type_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_index(index_name, table_name)
|
||||
if index_name:is_blank() or table_name:is_blank() then
|
||||
error("Index name and table name cannot be empty")
|
||||
end
|
||||
return self:exec("DROP INDEX {{index}} ON {{table}}":parse({
|
||||
index = index_name,
|
||||
table = table_name
|
||||
}))
|
||||
end
|
||||
|
||||
-- MySQL maintenance functions
|
||||
function Connection:optimize(table_name)
|
||||
local table_clause = table_name and " " .. table_name or ""
|
||||
return self:query("OPTIMIZE TABLE{{table}}":parse({table = table_clause}))
|
||||
end
|
||||
|
||||
function Connection:repair(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty for REPAIR")
|
||||
end
|
||||
return self:query("REPAIR TABLE {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
function Connection:check_table(table_name, options)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty for CHECK")
|
||||
end
|
||||
|
||||
local options_clause = ""
|
||||
if options then
|
||||
local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"}
|
||||
local options_upper = options:upper()
|
||||
|
||||
if table.contains(valid_options, options_upper) then
|
||||
options_clause = " " .. options_upper
|
||||
end
|
||||
end
|
||||
|
||||
return self:query("CHECK TABLE {{table}}{{options}}":parse({
|
||||
table = table_name,
|
||||
options = options_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:analyze_table(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty for ANALYZE")
|
||||
end
|
||||
return self:query("ANALYZE TABLE {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
-- MySQL settings and introspection
|
||||
function Connection:show(what)
|
||||
if what:is_blank() then
|
||||
error("SHOW parameter cannot be empty")
|
||||
end
|
||||
return self:query("SHOW {{what}}":parse({what = what:upper()}))
|
||||
end
|
||||
|
||||
function Connection:current_database()
|
||||
return self:query_value("SELECT DATABASE() AS db")
|
||||
end
|
||||
|
||||
function Connection:version()
|
||||
return self:query_value("SELECT VERSION() AS version")
|
||||
end
|
||||
|
||||
function Connection:connection_id()
|
||||
return self:query_value("SELECT CONNECTION_ID()")
|
||||
end
|
||||
|
||||
function Connection:list_databases()
|
||||
return self:query("SHOW DATABASES")
|
||||
end
|
||||
|
||||
function Connection:list_tables(database_name)
|
||||
if database_name and not database_name:is_blank() then
|
||||
return self:query("SHOW TABLES FROM {{database}}":parse({database = database_name}))
|
||||
else
|
||||
return self:query("SHOW TABLES")
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:describe_table(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
return self:query("DESCRIBE {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
function Connection:show_create_table(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
return self:query("SHOW CREATE TABLE {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
function Connection:show_indexes(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
return self:query("SHOW INDEXES FROM {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
function Connection:show_table_status(table_name)
|
||||
if table_name and not table_name:is_blank() then
|
||||
return self:query("SHOW TABLE STATUS LIKE ?", table_name)
|
||||
else
|
||||
return self:query("SHOW TABLE STATUS")
|
||||
end
|
||||
end
|
||||
|
||||
-- MySQL user and privilege management
|
||||
function Connection:create_user(username, password, host)
|
||||
if username:is_blank() or password:is_blank() then
|
||||
error("Username and password cannot be empty")
|
||||
end
|
||||
|
||||
host = host or "%"
|
||||
return self:exec("CREATE USER '{{username}}'@'{{host}}' IDENTIFIED BY ?":parse({
|
||||
username = username,
|
||||
host = host
|
||||
}), password)
|
||||
end
|
||||
|
||||
function Connection:drop_user(username, host)
|
||||
if username:is_blank() then
|
||||
error("Username cannot be empty")
|
||||
end
|
||||
|
||||
host = host or "%"
|
||||
return self:exec("DROP USER IF EXISTS '{{username}}'@'{{host}}'":parse({
|
||||
username = username,
|
||||
host = host
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:grant(privileges, database, table_name, username, host)
|
||||
if privileges:is_blank() or database:is_blank() or username:is_blank() then
|
||||
error("Privileges, database, and username cannot be empty")
|
||||
end
|
||||
|
||||
host = host or "%"
|
||||
table_name = table_name or "*"
|
||||
local object = database .. "." .. table_name
|
||||
|
||||
return self:exec("GRANT {{privileges}} ON {{object}} TO '{{username}}'@'{{host}}'":parse({
|
||||
privileges = privileges:upper(),
|
||||
object = object,
|
||||
username = username,
|
||||
host = host
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:revoke(privileges, database, table_name, username, host)
|
||||
if privileges:is_blank() or database:is_blank() or username:is_blank() then
|
||||
error("Privileges, database, and username cannot be empty")
|
||||
end
|
||||
|
||||
host = host or "%"
|
||||
table_name = table_name or "*"
|
||||
local object = database .. "." .. table_name
|
||||
|
||||
return self:exec("REVOKE {{privileges}} ON {{object}} FROM '{{username}}'@'{{host}}'":parse({
|
||||
privileges = privileges:upper(),
|
||||
object = object,
|
||||
username = username,
|
||||
host = host
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:flush_privileges()
|
||||
return self:exec("FLUSH PRIVILEGES")
|
||||
end
|
||||
|
||||
-- MySQL variables and configuration
|
||||
function Connection:set_variable(name, value, global)
|
||||
if name:is_blank() then
|
||||
error("Variable name cannot be empty")
|
||||
end
|
||||
|
||||
local scope = global and "GLOBAL " or "SESSION "
|
||||
return self:exec("SET {{scope}}{{name}} = ?":parse({scope = scope, name = name}), value)
|
||||
end
|
||||
|
||||
function Connection:get_variable(name, global)
|
||||
if name:is_blank() then
|
||||
error("Variable name cannot be empty")
|
||||
end
|
||||
|
||||
local scope = global and "global." or "session."
|
||||
return self:query_value("SELECT @@{{scope}}{{name}}":parse({scope = scope, name = name}))
|
||||
end
|
||||
|
||||
function Connection:show_variables(pattern)
|
||||
if pattern and not pattern:is_blank() then
|
||||
return self:query("SHOW VARIABLES LIKE ?", pattern)
|
||||
else
|
||||
return self:query("SHOW VARIABLES")
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:show_status(pattern)
|
||||
if pattern and not pattern:is_blank() then
|
||||
return self:query("SHOW STATUS LIKE ?", pattern)
|
||||
else
|
||||
return self:query("SHOW STATUS")
|
||||
end
|
||||
end
|
||||
|
||||
-- Connection management
|
||||
function mysql.connect(dsn)
|
||||
if dsn:is_blank() then
|
||||
error("DSN cannot be empty")
|
||||
end
|
||||
|
||||
local conn_id = moonshark.sql_connect("mysql", dsn:trim())
|
||||
if conn_id then
|
||||
return setmetatable({_id = conn_id}, Connection)
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
mysql.open = mysql.connect
|
||||
|
||||
-- Quick execution functions
|
||||
function mysql.query(dsn, query_str, ...)
|
||||
local conn = mysql.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to MySQL database")
|
||||
end
|
||||
|
||||
local results = conn:query(query_str, ...)
|
||||
conn:close()
|
||||
return results
|
||||
end
|
||||
|
||||
function mysql.exec(dsn, query_str, ...)
|
||||
local conn = mysql.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to MySQL database")
|
||||
end
|
||||
|
||||
local result = conn:exec(query_str, ...)
|
||||
conn:close()
|
||||
return result
|
||||
end
|
||||
|
||||
function mysql.query_row(dsn, query_str, ...)
|
||||
local results = mysql.query(dsn, query_str, ...)
|
||||
return results and #results > 0 and results[1] or nil
|
||||
end
|
||||
|
||||
function mysql.query_value(dsn, query_str, ...)
|
||||
local row = mysql.query_row(dsn, query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
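-- Illustrative usage sketch (not part of the original file): the one-shot
-- helpers above open a connection, run a single statement, and close it.
-- The DSN and queries are hypothetical.
-- local count = mysql.query_value("user:pass@tcp(127.0.0.1:3306)/app",
--     "SELECT COUNT(*) FROM users")
-- mysql.exec("user:pass@tcp(127.0.0.1:3306)/app",
--     "DELETE FROM sessions WHERE expires_at < NOW()")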
|
||||
|
||||
-- Migration helpers
|
||||
function mysql.migrate(dsn, migrations, database_name)
|
||||
local conn = mysql.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to MySQL database for migration")
|
||||
end
|
||||
|
||||
-- Use specified database if provided
|
||||
if database_name and not database_name:is_blank() then
|
||||
conn:exec("USE {{database}}":parse({database = database_name}))
|
||||
end
|
||||
|
||||
-- Create migrations table
|
||||
conn:create_table("_migrations", "id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) UNIQUE NOT NULL, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP")
|
||||
|
||||
local tx = conn:begin()
|
||||
if not tx then
|
||||
conn:close()
|
||||
error("Failed to begin migration transaction")
|
||||
end
|
||||
|
||||
for _, migration in ipairs(migrations) do
|
||||
if not migration.name or migration.name:is_blank() then
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error("Migration must have a non-empty name")
|
||||
end
|
||||
|
||||
-- Check if migration already applied
|
||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", migration.name:trim())
|
||||
if not existing then
|
||||
local ok, err = pcall(function()
|
||||
if type(migration.up) == "string" then
|
||||
conn:exec(migration.up)
|
||||
elseif type(migration.up) == "function" then
|
||||
migration.up(conn)
|
||||
else
|
||||
error("Migration 'up' must be string or function")
|
||||
end
|
||||
end)
|
||||
|
||||
if ok then
|
||||
conn:exec("INSERT INTO _migrations (name) VALUES (?)", migration.name:trim())
|
||||
print("Applied migration: {{name}}":parse({name = migration.name}))
|
||||
else
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error("Migration '{{name}}' failed: {{error}}":parse({
|
||||
name = migration.name,
|
||||
error = err or "unknown error"
|
||||
}))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
tx:commit()
|
||||
conn:close()
|
||||
return true
|
||||
end
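-- Illustrative usage sketch (not part of the original file): a two-step
-- migration list. Names and SQL are hypothetical; 'up' may be a SQL string
-- or a function that receives the open connection.
-- mysql.migrate("user:pass@tcp(127.0.0.1:3306)/app", {
--     {name = "001_create_users",
--      up = "CREATE TABLE users (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255))"},
--     {name = "002_seed_admin",
--      up = function(conn) conn:insert("users", {name = "admin"}) end}
-- })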
|
||||
|
||||
-- Result processing utilities
|
||||
function mysql.to_array(results, column_name)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if column_name:is_blank() then error("Column name cannot be empty") end
|
||||
return table.map(results, function(row) return row[column_name] end)
|
||||
end
|
||||
|
||||
function mysql.to_map(results, key_column, value_column)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if key_column:is_blank() then error("Key column name cannot be empty") end
|
||||
|
||||
local map = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[key_column]
|
||||
map[key] = value_column and row[value_column] or row
|
||||
end
|
||||
return map
|
||||
end
|
||||
|
||||
function mysql.group_by(results, column_name)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if column_name:is_blank() then error("Column name cannot be empty") end
|
||||
return table.group_by(results, function(row) return row[column_name] end)
|
||||
end
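-- Illustrative usage sketch (not part of the original file): reshaping a
-- result set with the helpers above. Column names are hypothetical.
-- local rows = conn:query("SELECT id, name, status FROM users")
-- local ids        = mysql.to_array(rows, "id")          -- {1, 2, 3, ...}
-- local name_by_id = mysql.to_map(rows, "id", "name")    -- {[1] = "Ada", ...}
-- local by_status  = mysql.group_by(rows, "status")      -- {active = {...}, ...}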
|
||||
|
||||
function mysql.print_results(results)
|
||||
if not results or table.is_empty(results) then
|
||||
print("No results")
|
||||
return
|
||||
end
|
||||
|
||||
local columns = table.keys(results[1])
|
||||
table.sort(columns)
|
||||
|
||||
-- Calculate column widths
|
||||
local widths = {}
|
||||
for _, col in ipairs(columns) do
|
||||
widths[col] = col:length()
|
||||
for _, row in ipairs(results) do
|
||||
local value = tostring(row[col] or "")
|
||||
widths[col] = math.max(widths[col], value:length())
|
||||
end
|
||||
end
|
||||
|
||||
-- Print header and separator
|
||||
local header_parts = table.map(columns, function(col) return col:pad_right(widths[col]) end)
|
||||
local separator_parts = table.map(columns, function(col) return string.repeat_("-", widths[col]) end)
|
||||
|
||||
print(table.concat(header_parts, " | "))
|
||||
print(table.concat(separator_parts, "-+-"))
|
||||
|
||||
-- Print rows
|
||||
for _, row in ipairs(results) do
|
||||
local value_parts = table.map(columns, function(col)
|
||||
local value = tostring(row[col] or "")
|
||||
return value:pad_right(widths[col])
|
||||
end)
|
||||
print(table.concat(value_parts, " | "))
|
||||
end
|
||||
end
|
||||
|
||||
-- MySQL-specific utilities
|
||||
function mysql.escape_string(str_val)
|
||||
if type(str_val) ~= "string" then
|
||||
return tostring(str_val)
|
||||
end
|
||||
return str_val:replace("\\", "\\\\"):replace("'", "\\'")
|
||||
end
|
||||
|
||||
function mysql.escape_identifier(name)
|
||||
if name:is_blank() then
|
||||
error("Identifier name cannot be empty")
|
||||
end
|
||||
return "`{{name}}`":parse({name = name:replace("`", "``")})
|
||||
end
|
||||
|
||||
-- DSN builder helper
|
||||
function mysql.build_dsn(options)
|
||||
if type(options) ~= "table" then
|
||||
error("Options must be a table")
|
||||
end
|
||||
|
||||
local parts = {}
|
||||
|
||||
if options.username and not options.username:is_blank() then
|
||||
table.insert(parts, options.username)
|
||||
if options.password and not options.password:is_blank() then
|
||||
parts[#parts] = parts[#parts] .. ":" .. options.password
|
||||
end
|
||||
parts[#parts] = parts[#parts] .. "@"
|
||||
end
|
||||
|
||||
if options.protocol and not options.protocol:is_blank() then
|
||||
local host_part = options.protocol .. "("
|
||||
if options.host and not options.host:is_blank() then
|
||||
host_part = host_part .. options.host
|
||||
if options.port then
|
||||
host_part = host_part .. ":" .. tostring(options.port)
|
||||
end
|
||||
end
|
||||
table.insert(parts, host_part .. ")")
|
||||
elseif options.host and not options.host:is_blank() then
|
||||
local host_part = "tcp(" .. options.host
|
||||
if options.port then
|
||||
host_part = host_part .. ":" .. tostring(options.port)
|
||||
end
|
||||
table.insert(parts, host_part .. ")")
|
||||
end
|
||||
|
||||
if options.database and not options.database:is_blank() then
|
||||
table.insert(parts, "/" .. options.database)
|
||||
end
|
||||
|
||||
-- Add parameters
|
||||
local params = {}
|
||||
if options.charset and not options.charset:is_blank() then
|
||||
table.insert(params, "charset=" .. options.charset)
|
||||
end
|
||||
if options.parseTime ~= nil then
|
||||
table.insert(params, "parseTime=" .. tostring(options.parseTime))
|
||||
end
|
||||
if options.timeout and not options.timeout:is_blank() then
|
||||
table.insert(params, "timeout=" .. options.timeout)
|
||||
end
|
||||
if options.tls and not options.tls:is_blank() then
|
||||
table.insert(params, "tls=" .. options.tls)
|
||||
end
|
||||
|
||||
if #params > 0 then
|
||||
table.insert(parts, "?" .. table.concat(params, "&"))
|
||||
end
|
||||
|
||||
return table.concat(parts, "")
|
||||
end
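-- Illustrative usage sketch (not part of the original file): building a
-- go-sql-driver style DSN from an options table. Credentials are hypothetical.
-- local dsn = mysql.build_dsn({
--     username = "app", password = "secret",
--     host = "127.0.0.1", port = 3306, database = "app",
--     charset = "utf8mb4", parseTime = true
-- })
-- dsn --> "app:secret@tcp(127.0.0.1:3306)/app?charset=utf8mb4&parseTime=true"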
|
||||
|
||||
return mysql
|
||||
@ -1,688 +0,0 @@
|
||||
local postgres = {}
|
||||
|
||||
local Connection = {}
|
||||
Connection.__index = Connection
|
||||
|
||||
function Connection:close()
|
||||
if self._id then
|
||||
local ok = moonshark.sql_close(self._id)
|
||||
self._id = nil
|
||||
return ok
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:ping()
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_ping(self._id)
|
||||
end
|
||||
|
||||
function Connection:query(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_query(self._id, query_str:normalize_whitespace(), ...)
|
||||
end
|
||||
|
||||
function Connection:exec(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_exec(self._id, query_str:normalize_whitespace(), ...)
|
||||
end
|
||||
|
||||
function Connection:query_row(query_str, ...)
|
||||
local results = self:query(query_str, ...)
|
||||
return results and #results > 0 and results[1] or nil
|
||||
end
|
||||
|
||||
function Connection:query_value(query_str, ...)
|
||||
local row = self:query_row(query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function Connection:begin()
|
||||
local result = self:exec("BEGIN")
|
||||
if result then
|
||||
return {
|
||||
conn = self,
|
||||
active = true,
|
||||
commit = function(tx)
|
||||
if tx.active then
|
||||
tx.active = false
|
||||
return tx.conn:exec("COMMIT")
|
||||
end
|
||||
return false
|
||||
end,
|
||||
rollback = function(tx)
|
||||
if tx.active then
|
||||
tx.active = false
|
||||
return tx.conn:exec("ROLLBACK")
|
||||
end
|
||||
return false
|
||||
end,
|
||||
savepoint = function(tx, name)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
if name:is_blank() then error("Savepoint name cannot be empty") end
|
||||
return tx.conn:exec("SAVEPOINT {{name}}":parse({name = name}))
|
||||
end,
|
||||
rollback_to = function(tx, name)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
if name:is_blank() then error("Savepoint name cannot be empty") end
|
||||
return tx.conn:exec("ROLLBACK TO SAVEPOINT {{name}}":parse({name = name}))
|
||||
end,
|
||||
query = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query(query_str, ...)
|
||||
end,
|
||||
exec = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:exec(query_str, ...)
|
||||
end,
|
||||
query_row = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query_row(query_str, ...)
|
||||
end,
|
||||
query_value = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query_value(query_str, ...)
|
||||
end
|
||||
}
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Build PostgreSQL parameters ($1, $2, etc.)
|
||||
local function build_postgres_params(data)
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local placeholders = {}
|
||||
|
||||
for i = 1, #keys do
|
||||
placeholders[i] = "$" .. i
|
||||
end
|
||||
|
||||
return keys, values, placeholders
|
||||
end
|
||||
|
||||
function Connection:insert(table_name, data, returning)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys, values, placeholders = build_postgres_params(data)
|
||||
|
||||
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = table.concat(placeholders, ", ")
|
||||
})
|
||||
|
||||
if returning and not returning:is_blank() then
|
||||
query = query .. " RETURNING " .. returning
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
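-- Illustrative usage sketch (not part of the original file): INSERT with a
-- RETURNING clause to get the generated id back. Table and DSN are hypothetical.
-- local conn = postgres.connect("host=localhost dbname=app user=app password=secret")
-- local rows = conn:insert("users", {name = "Ada"}, "id")
-- local new_id = rows and rows[1] and rows[1].id
-- conn:close()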
|
||||
|
||||
function Connection:upsert(table_name, data, conflict_columns, returning)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys, values, placeholders = build_postgres_params(data)
|
||||
local updates = table.map(keys, function(key) return key .. " = EXCLUDED." .. key end)
|
||||
|
||||
local conflict_clause = ""
|
||||
if conflict_columns then
|
||||
if type(conflict_columns) == "string" then
|
||||
conflict_clause = "(" .. conflict_columns .. ")"
|
||||
else
|
||||
conflict_clause = "(" .. table.concat(conflict_columns, ", ") .. ")"
|
||||
end
|
||||
end
|
||||
|
||||
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}}) ON CONFLICT {{conflict}} DO UPDATE SET {{updates}}":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = table.concat(placeholders, ", "),
|
||||
conflict = conflict_clause,
|
||||
updates = updates:join(", ")
|
||||
})
|
||||
|
||||
if returning and not returning:is_blank() then
|
||||
query = query .. " RETURNING " .. returning
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:update(table_name, data, where_clause, returning, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if where_clause:is_blank() then
|
||||
error("WHERE clause cannot be empty for UPDATE")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local param_count = #keys
|
||||
|
||||
-- Build SET clause with numbered parameters
|
||||
local sets = {}
|
||||
for i, key in ipairs(keys) do
|
||||
sets[i] = key .. " = $" .. i
|
||||
end
|
||||
|
||||
-- Handle WHERE parameters
|
||||
local where_args = {...}
|
||||
local where_clause_final = where_clause
|
||||
for i = 1, #where_args do
|
||||
param_count = param_count + 1
|
||||
values[#values + 1] = where_args[i]
|
||||
where_clause_final = where_clause_final:replace("?", "$" .. param_count, 1)
|
||||
end
|
||||
|
||||
local query = "UPDATE {{table}} SET {{sets}} WHERE {{where}}":parse({
|
||||
table = table_name,
|
||||
sets = table.concat(sets, ", "),
|
||||
where = where_clause_final
|
||||
})
|
||||
|
||||
if returning and not returning:is_blank() then
|
||||
query = query .. " RETURNING " .. returning
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:delete(table_name, where_clause, returning, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if where_clause:is_blank() then
|
||||
error("WHERE clause cannot be empty for DELETE")
|
||||
end
|
||||
|
||||
local where_args = {...}
|
||||
local values = {}
|
||||
local where_clause_final = where_clause
|
||||
|
||||
for i = 1, #where_args do
|
||||
values[i] = where_args[i]
|
||||
where_clause_final = where_clause_final:replace("?", "$" .. i, 1)
|
||||
end
|
||||
|
||||
local query = "DELETE FROM {{table}} WHERE {{where}}":parse({
|
||||
table = table_name,
|
||||
where = where_clause_final
|
||||
})
|
||||
|
||||
if returning and not returning:is_blank() then
|
||||
query = query .. " RETURNING " .. returning
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:select(table_name, columns, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
columns = columns or "*"
|
||||
if type(columns) == "table" then
|
||||
columns = table.concat(columns, ", ")
|
||||
end
|
||||
|
||||
if where_clause and not where_clause:is_blank() then
|
||||
local where_args = {...}
|
||||
local values = {}
|
||||
local where_clause_final = where_clause
|
||||
|
||||
for i = 1, #where_args do
|
||||
values[i] = where_args[i]
|
||||
where_clause_final = where_clause_final:replace("?", "$" .. i, 1)
|
||||
end
|
||||
|
||||
local query = "SELECT {{columns}} FROM {{table}} WHERE {{where}}":parse({
|
||||
columns = columns,
|
||||
table = table_name,
|
||||
where = where_clause_final
|
||||
})
|
||||
return self:query(query, unpack(values))
|
||||
else
|
||||
local query = "SELECT {{columns}} FROM {{table}}":parse({
|
||||
columns = columns,
|
||||
table = table_name
|
||||
})
|
||||
return self:query(query)
|
||||
end
|
||||
end
|
||||
|
||||
-- Schema helpers
|
||||
function Connection:table_exists(table_name, schema_name)
|
||||
if table_name:is_blank() then return false end
|
||||
schema_name = schema_name or "public"
|
||||
return self:query_value("SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2",
|
||||
schema_name:trim(), table_name:trim()) ~= nil
|
||||
end
|
||||
|
||||
function Connection:column_exists(table_name, column_name, schema_name)
|
||||
if table_name:is_blank() or column_name:is_blank() then return false end
|
||||
schema_name = schema_name or "public"
|
||||
return self:query_value([[
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
|
||||
]], schema_name:trim(), table_name:trim(), column_name:trim()) ~= nil
|
||||
end
|
||||
|
||||
function Connection:create_table(table_name, schema)
|
||||
if table_name:is_blank() or schema:is_blank() then
|
||||
error("Table name and schema cannot be empty")
|
||||
end
|
||||
return self:exec("CREATE TABLE IF NOT EXISTS {{table}} ({{schema}})":parse({
|
||||
table = table_name,
|
||||
schema = schema:trim()
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_table(table_name, cascade)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
local cascade_clause = cascade and " CASCADE" or ""
|
||||
return self:exec("DROP TABLE IF EXISTS {{table}}{{cascade}}":parse({
|
||||
table = table_name,
|
||||
cascade = cascade_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:add_column(table_name, column_def)
|
||||
if table_name:is_blank() or column_def:is_blank() then
|
||||
error("Table name and column definition cannot be empty")
|
||||
end
|
||||
return self:exec("ALTER TABLE {{table}} ADD COLUMN IF NOT EXISTS {{column}}":parse({
|
||||
table = table_name,
|
||||
column = column_def:trim()
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_column(table_name, column_name, cascade)
|
||||
if table_name:is_blank() or column_name:is_blank() then
|
||||
error("Table name and column name cannot be empty")
|
||||
end
|
||||
local cascade_clause = cascade and " CASCADE" or ""
|
||||
return self:exec("ALTER TABLE {{table}} DROP COLUMN IF EXISTS {{column}}{{cascade}}":parse({
|
||||
table = table_name,
|
||||
column = column_name,
|
||||
cascade = cascade_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:create_index(index_name, table_name, columns, unique, method)
|
||||
if index_name:is_blank() or table_name:is_blank() then
|
||||
error("Index name and table name cannot be empty")
|
||||
end
|
||||
|
||||
local unique_clause = unique and "UNIQUE " or ""
|
||||
local method_clause = method and " USING " .. method:upper() or ""
|
||||
local columns_str = type(columns) == "table" and table.concat(columns, ", ") or tostring(columns)
|
||||
|
||||
return self:exec("CREATE {{unique}}INDEX IF NOT EXISTS {{index}} ON {{table}}{{method}} ({{columns}})":parse({
|
||||
unique = unique_clause,
|
||||
index = index_name,
|
||||
table = table_name,
|
||||
method = method_clause,
|
||||
columns = columns_str
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:drop_index(index_name, cascade)
|
||||
if index_name:is_blank() then
|
||||
error("Index name cannot be empty")
|
||||
end
|
||||
local cascade_clause = cascade and " CASCADE" or ""
|
||||
return self:exec("DROP INDEX IF EXISTS {{index}}{{cascade}}":parse({
|
||||
index = index_name,
|
||||
cascade = cascade_clause
|
||||
}))
|
||||
end
|
||||
|
||||
-- PostgreSQL-specific functions
|
||||
function Connection:vacuum(table_name, analyze)
|
||||
local analyze_clause = analyze and " ANALYZE" or ""
|
||||
local table_clause = table_name and " " .. table_name or ""
|
||||
return self:exec("VACUUM{{analyze}}{{table}}":parse({
|
||||
analyze = analyze_clause,
|
||||
table = table_clause
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:analyze(table_name)
|
||||
local table_clause = table_name and " " .. table_name or ""
|
||||
return self:exec("ANALYZE{{table}}":parse({table = table_clause}))
|
||||
end
|
||||
|
||||
function Connection:reindex(name, type)
|
||||
if name:is_blank() then
|
||||
error("Name cannot be empty for REINDEX")
|
||||
end
|
||||
|
||||
type = (type or "INDEX"):upper()
|
||||
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
|
||||
|
||||
if not table.contains(valid_types, type) then
|
||||
error("Invalid REINDEX type: " .. type)
|
||||
end
|
||||
|
||||
return self:exec("REINDEX {{type}} {{name}}":parse({type = type, name = name}))
|
||||
end
|
||||
|
||||
function Connection:show(setting)
|
||||
if setting:is_blank() then
|
||||
error("Setting name cannot be empty")
|
||||
end
|
||||
return self:query_value("SHOW {{setting}}":parse({setting = setting}))
|
||||
end
|
||||
|
||||
function Connection:set(setting, value)
|
||||
if setting:is_blank() then
|
||||
error("Setting name cannot be empty")
|
||||
end
|
||||
return self:exec("SET {{setting}} = {{value}}":parse({
|
||||
setting = setting,
|
||||
value = tostring(value)
|
||||
}))
|
||||
end
|
||||
|
||||
function Connection:current_database()
|
||||
return self:query_value("SELECT current_database()")
|
||||
end
|
||||
|
||||
function Connection:current_schema()
|
||||
return self:query_value("SELECT current_schema()")
|
||||
end
|
||||
|
||||
function Connection:version()
|
||||
return self:query_value("SELECT version()")
|
||||
end
|
||||
|
||||
function Connection:list_schemas()
|
||||
return self:query("SELECT schema_name FROM information_schema.schemata ORDER BY schema_name")
|
||||
end
|
||||
|
||||
function Connection:list_tables(schema_name)
|
||||
schema_name = schema_name or "public"
|
||||
return self:query("SELECT tablename FROM pg_tables WHERE schemaname = $1 ORDER BY tablename", schema_name:trim())
|
||||
end
|
||||
|
||||
function Connection:describe_table(table_name, schema_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
schema_name = schema_name or "public"
|
||||
return self:query([[
|
||||
SELECT column_name, data_type, is_nullable, column_default
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
ORDER BY ordinal_position
|
||||
]], schema_name:trim(), table_name:trim())
|
||||
end
|
||||
|
||||
-- JSON/JSONB helpers
|
||||
function Connection:json_extract(column, path)
|
||||
if column:is_blank() or path:is_blank() then
|
||||
error("Column and path cannot be empty")
|
||||
end
|
||||
return "{{column}}->'{{path}}'":parse({column = column, path = path})
|
||||
end
|
||||
|
||||
function Connection:json_extract_text(column, path)
|
||||
if column:is_blank() or path:is_blank() then
|
||||
error("Column and path cannot be empty")
|
||||
end
|
||||
return "{{column}}->>'{{path}}'":parse({column = column, path = path})
|
||||
end
|
||||
|
||||
function Connection:jsonb_contains(column, value)
|
||||
if column:is_blank() or value:is_blank() then
|
||||
error("Column and value cannot be empty")
|
||||
end
|
||||
return "{{column}} @> '{{value}}'":parse({column = column, value = value})
|
||||
end
|
||||
|
||||
function Connection:jsonb_contained_by(column, value)
|
||||
if column:is_blank() or value:is_blank() then
|
||||
error("Column and value cannot be empty")
|
||||
end
|
||||
return "{{column}} <@ '{{value}}'":parse({column = column, value = value})
|
||||
end
|
||||
|
||||
-- Array helpers
|
||||
function Connection:array_contains(column, value)
|
||||
if column:is_blank() then
|
||||
error("Column cannot be empty")
|
||||
end
|
||||
return "$1 = ANY({{column}})":parse({column = column})
|
||||
end
|
||||
|
||||
function Connection:array_length(column)
|
||||
if column:is_blank() then
|
||||
error("Column cannot be empty")
|
||||
end
|
||||
return "array_length({{column}}, 1)":parse({column = column})
|
||||
end
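-- Illustrative usage sketch (not part of the original file): the helpers above
-- return SQL fragments that can be spliced into a larger query. Column names
-- and JSON values are hypothetical.
-- local email  = conn:json_extract_text("payload", "email")      -- payload->>'email'
-- local filter = conn:jsonb_contains("payload", '{"kind":"signup"}')
-- local rows   = conn:query("SELECT id, " .. email .. " AS email FROM events WHERE " .. filter)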
|
||||
|
||||
-- Connection management
|
||||
function postgres.parse_dsn(dsn)
|
||||
if dsn:is_blank() then
|
||||
return nil, "DSN cannot be empty"
|
||||
end
|
||||
|
||||
local parts = {}
|
||||
for pair in dsn:trim():gmatch("[^%s]+") do
|
||||
local key, value = pair:match("([^=]+)=(.+)")
|
||||
if key and value then
|
||||
parts[key:trim()] = value:trim()
|
||||
end
|
||||
end
|
||||
|
||||
return parts
|
||||
end
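-- Illustrative usage sketch (not part of the original file): parse_dsn splits a
-- key=value DSN string into a table. Values are hypothetical.
-- local opts = postgres.parse_dsn("host=localhost port=5432 dbname=app user=app")
-- opts.host   --> "localhost"
-- opts.dbname --> "app"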
|
||||
|
||||
function postgres.connect(dsn)
|
||||
if dsn:is_blank() then
|
||||
error("DSN cannot be empty")
|
||||
end
|
||||
|
||||
local conn_id = moonshark.sql_connect("postgres", dsn:trim())
|
||||
if conn_id then
|
||||
return setmetatable({_id = conn_id}, Connection)
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
postgres.open = postgres.connect
|
||||
|
||||
-- Quick execution functions
|
||||
function postgres.query(dsn, query_str, ...)
|
||||
local conn = postgres.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to PostgreSQL database")
|
||||
end
|
||||
|
||||
local results = conn:query(query_str, ...)
|
||||
conn:close()
|
||||
return results
|
||||
end
|
||||
|
||||
function postgres.exec(dsn, query_str, ...)
|
||||
local conn = postgres.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to PostgreSQL database")
|
||||
end
|
||||
|
||||
local result = conn:exec(query_str, ...)
|
||||
conn:close()
|
||||
return result
|
||||
end
|
||||
|
||||
function postgres.query_row(dsn, query_str, ...)
|
||||
local results = postgres.query(dsn, query_str, ...)
|
||||
return results and #results > 0 and results[1] or nil
|
||||
end
|
||||
|
||||
function postgres.query_value(dsn, query_str, ...)
|
||||
local row = postgres.query_row(dsn, query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Migration helpers
|
||||
function postgres.migrate(dsn, migrations, schema)
|
||||
schema = schema or "public"
|
||||
local conn = postgres.connect(dsn)
|
||||
if not conn then
|
||||
error("Failed to connect to PostgreSQL database for migration")
|
||||
end
|
||||
|
||||
conn:create_table("_migrations", "id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()")
|
||||
|
||||
local tx = conn:begin()
|
||||
if not tx then
|
||||
conn:close()
|
||||
error("Failed to begin migration transaction")
|
||||
end
|
||||
|
||||
for _, migration in ipairs(migrations) do
|
||||
if not migration.name or migration.name:is_blank() then
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error("Migration must have a non-empty name")
|
||||
end
|
||||
|
||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1", migration.name:trim())
|
||||
if not existing then
|
||||
local ok, err = pcall(function()
|
||||
if type(migration.up) == "string" then
|
||||
conn:exec(migration.up)
|
||||
elseif type(migration.up) == "function" then
|
||||
migration.up(conn)
|
||||
else
|
||||
error("Migration 'up' must be string or function")
|
||||
end
|
||||
end)
|
||||
|
||||
if ok then
|
||||
conn:exec("INSERT INTO _migrations (name) VALUES ($1)", migration.name:trim())
|
||||
print("Applied migration: {{name}}":parse({name = migration.name}))
|
||||
else
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error("Migration '{{name}}' failed: {{error}}":parse({
|
||||
name = migration.name,
|
||||
error = err or "unknown error"
|
||||
}))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
tx:commit()
|
||||
conn:close()
|
||||
return true
|
||||
end
|
||||
|
||||
-- Result processing utilities
|
||||
function postgres.to_array(results, column_name)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if column_name:is_blank() then error("Column name cannot be empty") end
|
||||
return table.map(results, function(row) return row[column_name] end)
|
||||
end
|
||||
|
||||
function postgres.to_map(results, key_column, value_column)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if key_column:is_blank() then error("Key column name cannot be empty") end
|
||||
|
||||
local map = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[key_column]
|
||||
map[key] = value_column and row[value_column] or row
|
||||
end
|
||||
return map
|
||||
end
|
||||
|
||||
function postgres.group_by(results, column_name)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if column_name:is_blank() then error("Column name cannot be empty") end
|
||||
return table.group_by(results, function(row) return row[column_name] end)
|
||||
end
|
||||
|
||||
function postgres.print_results(results)
|
||||
if not results or table.is_empty(results) then
|
||||
print("No results")
|
||||
return
|
||||
end
|
||||
|
||||
local columns = table.keys(results[1])
|
||||
table.sort(columns)
|
||||
|
||||
-- Calculate column widths
|
||||
local widths = {}
|
||||
for _, col in ipairs(columns) do
|
||||
widths[col] = col:length()
|
||||
for _, row in ipairs(results) do
|
||||
local value = tostring(row[col] or "")
|
||||
widths[col] = math.max(widths[col], value:length())
|
||||
end
|
||||
end
|
||||
|
||||
-- Print header and separator
|
||||
local header_parts = table.map(columns, function(col) return col:pad_right(widths[col]) end)
|
||||
local separator_parts = table.map(columns, function(col) return string.repeat_("-", widths[col]) end)
|
||||
|
||||
print(table.concat(header_parts, " | "))
|
||||
print(table.concat(separator_parts, "-+-"))
|
||||
|
||||
-- Print rows
|
||||
for _, row in ipairs(results) do
|
||||
local value_parts = table.map(columns, function(col)
|
||||
local value = tostring(row[col] or "")
|
||||
return value:pad_right(widths[col])
|
||||
end)
|
||||
print(table.concat(value_parts, " | "))
|
||||
end
|
||||
end
|
||||
|
||||
function postgres.escape_identifier(name)
|
||||
if name:is_blank() then
|
||||
error("Identifier name cannot be empty")
|
||||
end
|
||||
return '"{{name}}"':parse({name = name:replace('"', '""')})
|
||||
end
|
||||
|
||||
function postgres.escape_literal(value)
|
||||
if type(value) == "string" then
|
||||
return "'{{value}}'":parse({value = value:replace("'", "''")})
|
||||
end
|
||||
return tostring(value)
|
||||
end
|
||||
|
||||
return postgres
|
||||
@ -1,155 +0,0 @@
|
||||
package modules
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"maps"
|
||||
"strings"
|
||||
|
||||
"Moonshark/modules/crypto"
|
||||
"Moonshark/modules/fs"
|
||||
"Moonshark/modules/http"
|
||||
"Moonshark/modules/kv"
|
||||
"Moonshark/modules/sql"
|
||||
lua_string "Moonshark/modules/string+"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var Global *Registry
|
||||
|
||||
//go:embed **/*.lua
|
||||
var embeddedModules embed.FS
|
||||
|
||||
type Registry struct {
|
||||
modules map[string]string
|
||||
globalModules map[string]string // globalName -> moduleSource
|
||||
goFuncs map[string]luajit.GoFunction
|
||||
}
|
||||
|
||||
func New() *Registry {
|
||||
r := &Registry{
|
||||
modules: make(map[string]string),
|
||||
globalModules: make(map[string]string),
|
||||
goFuncs: make(map[string]luajit.GoFunction),
|
||||
}
|
||||
|
||||
maps.Copy(r.goFuncs, lua_string.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, crypto.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, fs.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, http.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, sql.GetFunctionList())
|
||||
maps.Copy(r.goFuncs, kv.GetFunctionList())
|
||||
|
||||
r.loadEmbeddedModules()
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Registry) loadEmbeddedModules() {
|
||||
dirs, _ := embeddedModules.ReadDir(".")
|
||||
|
||||
for _, dir := range dirs {
|
||||
if !dir.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
dirName := dir.Name()
|
||||
isGlobal := strings.HasSuffix(dirName, "+")
|
||||
|
||||
var moduleName, globalName string
|
||||
if isGlobal {
|
||||
moduleName = strings.TrimSuffix(dirName, "+")
|
||||
globalName = moduleName
|
||||
} else {
|
||||
moduleName = dirName
|
||||
}
|
||||
|
||||
modulePath := fmt.Sprintf("%s/%s.lua", dirName, moduleName)
|
||||
if source, err := embeddedModules.ReadFile(modulePath); err == nil {
|
||||
r.modules[moduleName] = string(source)
|
||||
if isGlobal {
|
||||
r.globalModules[globalName] = string(source)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) InstallInState(state *luajit.State) error {
|
||||
// Create moonshark global table with Go functions
|
||||
state.NewTable()
|
||||
for name, fn := range r.goFuncs {
|
||||
if err := state.PushGoFunction(fn); err != nil {
|
||||
return fmt.Errorf("failed to register Go function '%s': %w", name, err)
|
||||
}
|
||||
state.SetField(-2, name)
|
||||
}
|
||||
state.SetGlobal("moonshark")
|
||||
|
||||
// Auto-enhance all global modules
|
||||
for globalName, source := range r.globalModules {
|
||||
if err := r.enhanceGlobal(state, globalName, source); err != nil {
|
||||
return fmt.Errorf("failed to enhance %s global: %w", globalName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Backup original require and install custom one
|
||||
state.GetGlobal("require")
|
||||
state.SetGlobal("_require_original")
|
||||
|
||||
return state.RegisterGoFunction("require", func(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(1); err != nil {
|
||||
return s.PushError("require: %v", err)
|
||||
}
|
||||
|
||||
moduleName, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("require: module name must be a string")
|
||||
}
|
||||
|
||||
// Return global if this module enhances a global
|
||||
if _, isGlobal := r.globalModules[moduleName]; isGlobal {
|
||||
s.GetGlobal(moduleName)
|
||||
return 1
|
||||
}
|
||||
|
||||
// Check built-in modules
|
||||
if source, exists := r.modules[moduleName]; exists {
|
||||
if err := s.LoadString(source); err != nil {
|
||||
return s.PushError("require: failed to load module '%s': %v", moduleName, err)
|
||||
}
|
||||
if err := s.Call(0, 1); err != nil {
|
||||
return s.PushError("require: failed to execute module '%s': %v", moduleName, err)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// Fall back to original require
|
||||
s.GetGlobal("_require_original")
|
||||
if s.IsFunction(-1) {
|
||||
s.PushString(moduleName)
|
||||
if err := s.Call(1, 1); err != nil {
|
||||
return s.PushError("require: %v", err)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
return s.PushError("require: module '%s' not found", moduleName)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Registry) enhanceGlobal(state *luajit.State, globalName, source string) error {
|
||||
// Execute the module - it directly modifies the global
|
||||
if err := state.LoadString(source); err != nil {
|
||||
return fmt.Errorf("failed to load %s module: %w", globalName, err)
|
||||
}
|
||||
if err := state.Call(0, 0); err != nil { // 0 results expected
|
||||
return fmt.Errorf("failed to execute %s module: %w", globalName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Initialize() error {
|
||||
Global = New()
|
||||
return nil
|
||||
}
|
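A minimal sketch of how the require override in the deleted registry behaves from a script's point of view. The module names are assumptions for illustration: it presumes an embedded kv/kv.lua exists, that string+ is installed as a global-enhancing module, and that "config" is a hypothetical on-disk module which falls through to the original require.

```lua
local kv = require("kv")        -- served from the embedded kv/kv.lua source (assumed to exist)
local str = require("string")   -- global-enhancing module: returns the enhanced "string" global
assert(str == string)

local cfg = require("config")   -- not embedded, so this falls back to _require_original
```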
||||
@ -1,205 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// MySQLDriver implements the Driver interface for MySQL
|
||||
type MySQLDriver struct{}
|
||||
|
||||
func (d *MySQLDriver) Name() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (d *MySQLDriver) Open(dsn string) (Connection, error) {
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Test the connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("mysql: failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &MySQLConnection{db: db}, nil
|
||||
}
|
||||
|
||||
// MySQLConnection implements the Connection interface
|
||||
type MySQLConnection struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Ping(ctx context.Context) error {
|
||||
return c.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Begin(ctx context.Context) (Transaction, error) {
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to begin transaction: %w", err)
|
||||
}
|
||||
return &MySQLTransaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := c.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: query failed: %w", err)
|
||||
}
|
||||
return &MySQLRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := c.db.QueryRowContext(ctx, query, args...)
|
||||
return &MySQLRow{row: row}
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
result, err := c.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: exec failed: %w", err)
|
||||
}
|
||||
return &MySQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (c *MySQLConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
stmt, err := c.db.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to prepare statement: %w", err)
|
||||
}
|
||||
return &MySQLStatement{stmt: stmt}, nil
|
||||
}
|
||||
|
||||
// MySQLTransaction implements the Transaction interface
|
||||
type MySQLTransaction struct {
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Rollback() error {
|
||||
return t.tx.Rollback()
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := t.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: transaction query failed: %w", err)
|
||||
}
|
||||
return &MySQLRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := t.tx.QueryRowContext(ctx, query, args...)
|
||||
return &MySQLRow{row: row}
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
result, err := t.tx.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: transaction exec failed: %w", err)
|
||||
}
|
||||
return &MySQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (t *MySQLTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
stmt, err := t.tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: failed to prepare transaction statement: %w", err)
|
||||
}
|
||||
return &MySQLStatement{stmt: stmt}, nil
|
||||
}
|
||||
|
||||
// MySQLRows implements the Rows interface
|
||||
type MySQLRows struct {
|
||||
rows *sql.Rows
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Next() bool {
|
||||
return r.rows.Next()
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Scan(dest ...any) error {
|
||||
return r.rows.Scan(dest...)
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Columns() ([]string, error) {
|
||||
return r.rows.Columns()
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Close() error {
|
||||
return r.rows.Close()
|
||||
}
|
||||
|
||||
func (r *MySQLRows) Err() error {
|
||||
return r.rows.Err()
|
||||
}
|
||||
|
||||
// MySQLRow implements the Row interface
|
||||
type MySQLRow struct {
|
||||
row *sql.Row
|
||||
}
|
||||
|
||||
func (r *MySQLRow) Scan(dest ...any) error {
|
||||
return r.row.Scan(dest...)
|
||||
}
|
||||
|
||||
// MySQLResult implements the Result interface
|
||||
type MySQLResult struct {
|
||||
result sql.Result
|
||||
}
|
||||
|
||||
func (r *MySQLResult) LastInsertId() (int64, error) {
|
||||
return r.result.LastInsertId()
|
||||
}
|
||||
|
||||
func (r *MySQLResult) RowsAffected() (int64, error) {
|
||||
return r.result.RowsAffected()
|
||||
}
|
||||
|
||||
// MySQLStatement implements the Statement interface
|
||||
type MySQLStatement struct {
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
||||
rows, err := s.stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: statement query failed: %w", err)
|
||||
}
|
||||
return &MySQLRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) QueryRow(ctx context.Context, args ...any) Row {
|
||||
row := s.stmt.QueryRowContext(ctx, args...)
|
||||
return &MySQLRow{row: row}
|
||||
}
|
||||
|
||||
func (s *MySQLStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
||||
result, err := s.stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mysql: statement exec failed: %w", err)
|
||||
}
|
||||
return &MySQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register MySQL driver on import
|
||||
RegisterDriver("mysql", &MySQLDriver{})
|
||||
}
|
||||
@ -1,234 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// PostgresDriver implements the Driver interface for PostgreSQL
|
||||
type PostgresDriver struct{}
|
||||
|
||||
func (d *PostgresDriver) Name() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (d *PostgresDriver) Open(dsn string) (Connection, error) {
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to parse config: %w", err)
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to create pool: %w", err)
|
||||
}
|
||||
|
||||
return &PostgresConnection{pool: pool}, nil
|
||||
}
|
||||
|
||||
// PostgresConnection implements the Connection interface
|
||||
type PostgresConnection struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Close() error {
|
||||
c.pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Ping(ctx context.Context) error {
|
||||
return c.pool.Ping(ctx)
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Begin(ctx context.Context) (Transaction, error) {
|
||||
tx, err := c.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to begin transaction: %w", err)
|
||||
}
|
||||
return &PostgresTransaction{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := c.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: query failed: %w", err)
|
||||
}
|
||||
return &PostgresRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := c.pool.QueryRow(ctx, query, args...)
|
||||
return &PostgresRow{row: row}
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
tag, err := c.pool.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: exec failed: %w", err)
|
||||
}
|
||||
return &PostgresResult{tag: tag}, nil
|
||||
}
|
||||
|
||||
func (c *PostgresConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
// pgx doesn't have explicit prepared statements like database/sql
|
||||
// We'll store the query and use it with the pool
|
||||
return &PostgresStatement{pool: c.pool, query: query}, nil
|
||||
}
|
||||
|
||||
// PostgresTransaction implements the Transaction interface
|
||||
type PostgresTransaction struct {
|
||||
tx pgx.Tx
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Commit() error {
|
||||
return t.tx.Commit(context.Background())
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Rollback() error {
|
||||
return t.tx.Rollback(context.Background())
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
rows, err := t.tx.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: transaction query failed: %w", err)
|
||||
}
|
||||
return &PostgresRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
row := t.tx.QueryRow(ctx, query, args...)
|
||||
return &PostgresRow{row: row}
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
tag, err := t.tx.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: transaction exec failed: %w", err)
|
||||
}
|
||||
return &PostgresResult{tag: tag}, nil
|
||||
}
|
||||
|
||||
func (t *PostgresTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
return &PostgresStatement{tx: t.tx, query: query}, nil
|
||||
}
|
||||
|
||||
// PostgresRows implements the Rows interface
|
||||
type PostgresRows struct {
|
||||
rows pgx.Rows
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Next() bool {
|
||||
return r.rows.Next()
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Scan(dest ...any) error {
|
||||
return r.rows.Scan(dest...)
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Columns() ([]string, error) {
|
||||
fields := r.rows.FieldDescriptions()
|
||||
columns := make([]string, len(fields))
|
||||
for i, field := range fields {
|
||||
columns[i] = field.Name
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Close() error {
|
||||
r.rows.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PostgresRows) Err() error {
|
||||
return r.rows.Err()
|
||||
}
|
||||
|
||||
// PostgresRow implements the Row interface
|
||||
type PostgresRow struct {
|
||||
row pgx.Row
|
||||
}
|
||||
|
||||
func (r *PostgresRow) Scan(dest ...any) error {
|
||||
return r.row.Scan(dest...)
|
||||
}
|
||||
|
||||
// PostgresResult implements the Result interface
|
||||
type PostgresResult struct {
|
||||
tag pgconn.CommandTag
|
||||
}
|
||||
|
||||
func (r *PostgresResult) LastInsertId() (int64, error) {
|
||||
// PostgreSQL doesn't have AUTO_INCREMENT like MySQL
|
||||
// Users should use RETURNING clause or sequences
|
||||
return 0, fmt.Errorf("postgres: LastInsertId not supported, use RETURNING clause")
|
||||
}
|
||||
|
||||
func (r *PostgresResult) RowsAffected() (int64, error) {
|
||||
return r.tag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// PostgresStatement implements the Statement interface
|
||||
type PostgresStatement struct {
|
||||
pool *pgxpool.Pool
|
||||
tx pgx.Tx
|
||||
query string
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) Close() error {
|
||||
// pgx doesn't require explicit statement cleanup
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
||||
var rows pgx.Rows
|
||||
var err error
|
||||
|
||||
if s.tx != nil {
|
||||
rows, err = s.tx.Query(ctx, s.query, args...)
|
||||
} else {
|
||||
rows, err = s.pool.Query(ctx, s.query, args...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: statement query failed: %w", err)
|
||||
}
|
||||
return &PostgresRows{rows: rows}, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) QueryRow(ctx context.Context, args ...any) Row {
|
||||
var row pgx.Row
|
||||
|
||||
if s.tx != nil {
|
||||
row = s.tx.QueryRow(ctx, s.query, args...)
|
||||
} else {
|
||||
row = s.pool.QueryRow(ctx, s.query, args...)
|
||||
}
|
||||
|
||||
return &PostgresRow{row: row}
|
||||
}
|
||||
|
||||
func (s *PostgresStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
||||
var tag pgconn.CommandTag
|
||||
var err error
|
||||
|
||||
if s.tx != nil {
|
||||
tag, err = s.tx.Exec(ctx, s.query, args...)
|
||||
} else {
|
||||
tag, err = s.pool.Exec(ctx, s.query, args...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: statement exec failed: %w", err)
|
||||
}
|
||||
return &PostgresResult{tag: tag}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register PostgreSQL driver on import
|
||||
RegisterDriver("postgres", &PostgresDriver{})
|
||||
}
|
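Because PostgresResult.LastInsertId deliberately returns an error, code that needs the generated key would use a RETURNING clause instead. A hedged Lua sketch using the sql_* bindings defined in this package; the DSN and table are made up, and pgx-style $1 placeholders are used:

```lua
-- Hypothetical DSN and schema, shown only to illustrate RETURNING with this driver.
local conn = moonshark.sql_connect("postgres", "postgres://app:secret@localhost:5432/app")
local rows = moonshark.sql_query(conn,
    "INSERT INTO users (name) VALUES ($1) RETURNING id", "ada")
print(rows[1] and rows[1].id)
moonshark.sql_close(conn)
```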
||||
@ -1,377 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Driver interface for SQL database implementations
|
||||
type Driver interface {
|
||||
Open(dsn string) (Connection, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Connection represents a database connection
|
||||
type Connection interface {
|
||||
Close() error
|
||||
Ping(ctx context.Context) error
|
||||
Begin(ctx context.Context) (Transaction, error)
|
||||
Query(ctx context.Context, query string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...any) Row
|
||||
Exec(ctx context.Context, query string, args ...any) (Result, error)
|
||||
Prepare(ctx context.Context, query string) (Statement, error)
|
||||
}
|
||||
|
||||
// Transaction represents a database transaction
|
||||
type Transaction interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
Query(ctx context.Context, query string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...any) Row
|
||||
Exec(ctx context.Context, query string, args ...any) (Result, error)
|
||||
Prepare(ctx context.Context, query string) (Statement, error)
|
||||
}
|
||||
|
||||
// Rows represents query result rows
|
||||
type Rows interface {
|
||||
Next() bool
|
||||
Scan(dest ...any) error
|
||||
Columns() ([]string, error)
|
||||
Close() error
|
||||
Err() error
|
||||
}
|
||||
|
||||
// Row represents a single query result row
|
||||
type Row interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// Result represents the result of an executed statement
|
||||
type Result interface {
|
||||
LastInsertId() (int64, error)
|
||||
RowsAffected() (int64, error)
|
||||
}
|
||||
|
||||
// Statement represents a prepared statement
|
||||
type Statement interface {
|
||||
Close() error
|
||||
Query(ctx context.Context, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, args ...any) Row
|
||||
Exec(ctx context.Context, args ...any) (Result, error)
|
||||
}
|
||||
|
||||
// Registry manages database drivers and connections
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
drivers map[string]Driver
|
||||
conns map[string]Connection
|
||||
nextID int
|
||||
}
|
||||
|
||||
var global = &Registry{
|
||||
drivers: make(map[string]Driver),
|
||||
conns: make(map[string]Connection),
|
||||
}
|
||||
|
||||
// RegisterDriver registers a database driver
|
||||
func RegisterDriver(name string, driver Driver) {
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
global.drivers[name] = driver
|
||||
}
|
||||
|
||||
// GetDriver returns a registered driver
|
||||
func GetDriver(name string) (Driver, bool) {
|
||||
global.mu.RLock()
|
||||
defer global.mu.RUnlock()
|
||||
driver, exists := global.drivers[name]
|
||||
return driver, exists
|
||||
}
|
||||
|
||||
// Connect opens a new database connection
|
||||
func Connect(driverName, dsn string) (string, error) {
|
||||
driver, exists := GetDriver(driverName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("unknown driver: %s", driverName)
|
||||
}
|
||||
|
||||
conn, err := driver.Open(dsn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
|
||||
id := fmt.Sprintf("%s_%d", driverName, global.nextID)
|
||||
global.nextID++
|
||||
global.conns[id] = conn
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetConnection retrieves a connection by ID
|
||||
func GetConnection(id string) (Connection, bool) {
|
||||
global.mu.RLock()
|
||||
defer global.mu.RUnlock()
|
||||
conn, exists := global.conns[id]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// CloseConnection closes and removes a connection
|
||||
func CloseConnection(id string) error {
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
|
||||
conn, exists := global.conns[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection not found: %s", id)
|
||||
}
|
||||
|
||||
err := conn.Close()
|
||||
delete(global.conns, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func CloseAllConnections() {
|
||||
global.mu.Lock()
|
||||
defer global.mu.Unlock()
|
||||
|
||||
for id, conn := range global.conns {
|
||||
conn.Close()
|
||||
delete(global.conns, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Lua function implementations
|
||||
|
||||
func luaConnect(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(2); err != nil {
|
||||
return s.PushError("connect: %v", err)
|
||||
}
|
||||
|
||||
driver, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("connect: driver must be a string")
|
||||
}
|
||||
|
||||
dsn, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
return s.PushError("connect: dsn must be a string")
|
||||
}
|
||||
|
||||
connID, err := Connect(driver, dsn)
|
||||
if err != nil {
|
||||
return s.PushError("connect: %v", err)
|
||||
}
|
||||
|
||||
s.PushString(connID)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaClose(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(1); err != nil {
|
||||
return s.PushError("close: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("close: connection id must be a string")
|
||||
}
|
||||
|
||||
if err := CloseConnection(connID); err != nil {
|
||||
return s.PushError("close: %v", err)
|
||||
}
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaPing(s *luajit.State) int {
|
||||
if err := s.CheckExactArgs(1); err != nil {
|
||||
return s.PushError("ping: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("ping: connection id must be a string")
|
||||
}
|
||||
|
||||
conn, exists := GetConnection(connID)
|
||||
if !exists {
|
||||
return s.PushError("ping: connection not found")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := conn.Ping(ctx); err != nil {
|
||||
return s.PushError("ping: %v", err)
|
||||
}
|
||||
|
||||
s.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaQuery(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(2); err != nil {
|
||||
return s.PushError("query: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("query: connection id must be a string")
|
||||
}
|
||||
|
||||
query, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
return s.PushError("query: query must be a string")
|
||||
}
|
||||
|
||||
conn, exists := GetConnection(connID)
|
||||
if !exists {
|
||||
return s.PushError("query: connection not found")
|
||||
}
|
||||
|
||||
// Collect arguments
|
||||
args := make([]any, s.GetTop()-2)
|
||||
for i := 3; i <= s.GetTop(); i++ {
|
||||
val, err := s.ToValue(i)
|
||||
if err != nil {
|
||||
args[i-3] = nil
|
||||
} else {
|
||||
args[i-3] = val
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
rows, err := conn.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return s.PushError("query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return s.PushError("query: failed to get columns: %v", err)
|
||||
}
|
||||
|
||||
// Build result array
|
||||
s.CreateTable(0, 0)
|
||||
rowIndex := 1
|
||||
|
||||
for rows.Next() {
|
||||
// Create values slice for scanning
|
||||
values := make([]any, len(columns))
|
||||
valuePtrs := make([]any, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return s.PushError("query: scan error: %v", err)
|
||||
}
|
||||
|
||||
// Create row table
|
||||
s.CreateTable(0, len(columns))
|
||||
for i, col := range columns {
|
||||
s.PushString(col)
|
||||
if err := s.PushValue(values[i]); err != nil {
|
||||
s.PushNil()
|
||||
}
|
||||
s.SetTable(-3)
|
||||
}
|
||||
|
||||
// Add to result array
|
||||
s.PushNumber(float64(rowIndex))
|
||||
s.PushCopy(-2)
|
||||
s.SetTable(-4)
|
||||
s.Pop(1) // Remove row table copy
|
||||
|
||||
rowIndex++
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return s.PushError("query: %v", err)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaExec(s *luajit.State) int {
|
||||
if err := s.CheckMinArgs(2); err != nil {
|
||||
return s.PushError("exec: %v", err)
|
||||
}
|
||||
|
||||
connID, err := s.SafeToString(1)
|
||||
if err != nil {
|
||||
return s.PushError("exec: connection id must be a string")
|
||||
}
|
||||
|
||||
query, err := s.SafeToString(2)
|
||||
if err != nil {
|
||||
return s.PushError("exec: query must be a string")
|
||||
}
|
||||
|
||||
conn, exists := GetConnection(connID)
|
||||
if !exists {
|
||||
return s.PushError("exec: connection not found")
|
||||
}
|
||||
|
||||
// Collect arguments
|
||||
args := make([]any, s.GetTop()-2)
|
||||
for i := 3; i <= s.GetTop(); i++ {
|
||||
val, err := s.ToValue(i)
|
||||
if err != nil {
|
||||
args[i-3] = nil
|
||||
} else {
|
||||
args[i-3] = val
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := conn.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return s.PushError("exec: %v", err)
|
||||
}
|
||||
|
||||
// Return result info
|
||||
s.CreateTable(0, 2)
|
||||
|
||||
lastID, _ := result.LastInsertId()
|
||||
s.PushString("last_insert_id")
|
||||
s.PushNumber(float64(lastID))
|
||||
s.SetTable(-3)
|
||||
|
||||
affected, _ := result.RowsAffected()
|
||||
s.PushString("rows_affected")
|
||||
s.PushNumber(float64(affected))
|
||||
s.SetTable(-3)
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// GetFunctionList returns all Lua-callable functions
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"sql_connect": luaConnect,
|
||||
"sql_close": luaClose,
|
||||
"sql_ping": luaPing,
|
||||
"sql_query": luaQuery,
|
||||
"sql_exec": luaExec,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Register SQLite driver on import
|
||||
RegisterDriver("sqlite", &SQLiteDriver{})
|
||||
}
|
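A short sketch of the sql_* functions registered above as seen from Lua, assuming the SQLite driver and an in-memory database; error handling is omitted for brevity:

```lua
local conn = moonshark.sql_connect("sqlite", ":memory:")

moonshark.sql_exec(conn, "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
local res = moonshark.sql_exec(conn, "INSERT INTO users (name) VALUES (?)", "ada")
print(res.last_insert_id, res.rows_affected)

-- sql_query returns an array of row tables keyed by column name
for _, row in ipairs(moonshark.sql_query(conn, "SELECT id, name FROM users")) do
    print(row.id, row.name)
end

moonshark.sql_close(conn)
```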
||||
@ -1,384 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
)
|
||||
|
||||
// SQLiteDriver implements the Driver interface for SQLite
|
||||
type SQLiteDriver struct{}
|
||||
|
||||
func (d *SQLiteDriver) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (d *SQLiteDriver) Open(dsn string) (Connection, error) {
|
||||
conn, err := sqlite.OpenConn(dsn, sqlite.OpenReadWrite|sqlite.OpenCreate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to open database: %w", err)
|
||||
}
|
||||
|
||||
return &SQLiteConnection{conn: conn}, nil
|
||||
}
|
||||
|
||||
// SQLiteConnection implements the Connection interface
|
||||
type SQLiteConnection struct {
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Ping(ctx context.Context) error {
|
||||
return sqlitex.ExecuteTransient(c.conn, "SELECT 1", nil)
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Begin(ctx context.Context) (Transaction, error) {
|
||||
if err := sqlitex.ExecuteTransient(c.conn, "BEGIN", nil); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to begin transaction: %w", err)
|
||||
}
|
||||
return &SQLiteTransaction{conn: c.conn}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to prepare query: %w", err)
|
||||
}
|
||||
|
||||
if err := c.bindArgs(stmt, args...); err != nil {
|
||||
stmt.Finalize()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SQLiteRows{stmt: stmt, hasNext: true}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
rows, err := c.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return &SQLiteRow{err: err}
|
||||
}
|
||||
return &SQLiteRow{rows: rows.(*SQLiteRows)}
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
|
||||
}
|
||||
defer stmt.Finalize()
|
||||
|
||||
if err := c.bindArgs(stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasRow, err := stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
|
||||
}
|
||||
|
||||
// Consume all rows if any
|
||||
for hasRow {
|
||||
hasRow, err = stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SQLiteResult{
|
||||
lastInsertID: c.conn.LastInsertRowID(),
|
||||
rowsAffected: c.conn.Changes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
stmt, err := c.conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to prepare statement: %w", err)
|
||||
}
|
||||
return &SQLiteStatement{stmt: stmt, conn: c.conn}, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConnection) bindArgs(stmt *sqlite.Stmt, args ...any) error {
|
||||
for i, arg := range args {
|
||||
paramIndex := i + 1
|
||||
|
||||
if arg == nil {
|
||||
stmt.BindNull(paramIndex)
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
stmt.BindInt64(paramIndex, int64(v))
|
||||
case int64:
|
||||
stmt.BindInt64(paramIndex, v)
|
||||
case float64:
|
||||
stmt.BindFloat(paramIndex, v)
|
||||
case string:
|
||||
stmt.BindText(paramIndex, v)
|
||||
case bool:
|
||||
if v {
|
||||
stmt.BindInt64(paramIndex, 1)
|
||||
} else {
|
||||
stmt.BindInt64(paramIndex, 0)
|
||||
}
|
||||
case []byte:
|
||||
stmt.BindBytes(paramIndex, v)
|
||||
default:
|
||||
return fmt.Errorf("sqlite: unsupported parameter type: %T", arg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SQLiteTransaction implements the Transaction interface
|
||||
type SQLiteTransaction struct {
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Commit() error {
|
||||
return sqlitex.ExecuteTransient(t.conn, "COMMIT", nil)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Rollback() error {
|
||||
return sqlitex.ExecuteTransient(t.conn, "ROLLBACK", nil)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.Query(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) QueryRow(ctx context.Context, query string, args ...any) Row {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.QueryRow(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Exec(ctx context.Context, query string, args ...any) (Result, error) {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.Exec(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (t *SQLiteTransaction) Prepare(ctx context.Context, query string) (Statement, error) {
|
||||
conn := &SQLiteConnection{conn: t.conn}
|
||||
return conn.Prepare(ctx, query)
|
||||
}
|
||||
|
||||
// SQLiteRows implements the Rows interface
|
||||
type SQLiteRows struct {
|
||||
stmt *sqlite.Stmt
|
||||
hasNext bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Next() bool {
|
||||
if r.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !r.hasNext {
|
||||
return false
|
||||
}
|
||||
|
||||
var err error
|
||||
r.hasNext, err = r.stmt.Step()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
return r.hasNext
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Scan(dest ...any) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
|
||||
for i, d := range dest {
|
||||
if i >= r.stmt.ColumnCount() {
|
||||
break
|
||||
}
|
||||
|
||||
switch ptr := d.(type) {
|
||||
case *any:
|
||||
*ptr = r.getValue(i)
|
||||
case *string:
|
||||
*ptr = r.stmt.ColumnText(i)
|
||||
case *int:
|
||||
*ptr = int(r.stmt.ColumnInt64(i))
|
||||
case *int64:
|
||||
*ptr = r.stmt.ColumnInt64(i)
|
||||
case *float64:
|
||||
*ptr = r.stmt.ColumnFloat(i)
|
||||
case *bool:
|
||||
*ptr = r.stmt.ColumnInt64(i) != 0
|
||||
case *[]byte:
|
||||
if r.stmt.ColumnType(i) == sqlite.TypeBlob {
|
||||
// Get blob size first
|
||||
size := r.stmt.ColumnBytes(i, nil)
|
||||
if size == 0 {
|
||||
*ptr = []byte{}
|
||||
} else {
|
||||
buf := make([]byte, size)
|
||||
r.stmt.ColumnBytes(i, buf)
|
||||
*ptr = buf
|
||||
}
|
||||
} else {
|
||||
// Convert text to bytes
|
||||
*ptr = []byte(r.stmt.ColumnText(i))
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("sqlite: unsupported scan destination type: %T", d)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) getValue(index int) any {
|
||||
switch r.stmt.ColumnType(index) {
|
||||
case sqlite.TypeInteger:
|
||||
return r.stmt.ColumnInt64(index)
|
||||
case sqlite.TypeFloat:
|
||||
return r.stmt.ColumnFloat(index)
|
||||
case sqlite.TypeText:
|
||||
return r.stmt.ColumnText(index)
|
||||
case sqlite.TypeBlob:
|
||||
// For blob columns, we need to handle this differently
|
||||
// First, get the size by calling with nil buffer
|
||||
size := r.stmt.ColumnBytes(index, nil)
|
||||
if size == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
// Now allocate buffer and get the actual data
|
||||
buf := make([]byte, size)
|
||||
r.stmt.ColumnBytes(index, buf)
|
||||
return buf
|
||||
case sqlite.TypeNull:
|
||||
return nil
|
||||
default:
|
||||
return r.stmt.ColumnText(index)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Columns() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
|
||||
columns := make([]string, r.stmt.ColumnCount())
|
||||
for i := range columns {
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Close() error {
|
||||
if r.stmt != nil {
|
||||
return r.stmt.Finalize()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SQLiteRows) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// SQLiteRow implements the Row interface
|
||||
type SQLiteRow struct {
|
||||
rows *SQLiteRows
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *SQLiteRow) Scan(dest ...any) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
|
||||
if r.rows == nil {
|
||||
return fmt.Errorf("sqlite: no rows available")
|
||||
}
|
||||
|
||||
if !r.rows.Next() {
|
||||
if r.rows.Err() != nil {
|
||||
return r.rows.Err()
|
||||
}
|
||||
return fmt.Errorf("sqlite: no rows in result set")
|
||||
}
|
||||
|
||||
return r.rows.Scan(dest...)
|
||||
}
|
||||
|
||||
// SQLiteResult implements the Result interface
|
||||
type SQLiteResult struct {
|
||||
lastInsertID int64
|
||||
rowsAffected int
|
||||
}
|
||||
|
||||
func (r *SQLiteResult) LastInsertId() (int64, error) {
|
||||
return r.lastInsertID, nil
|
||||
}
|
||||
|
||||
func (r *SQLiteResult) RowsAffected() (int64, error) {
|
||||
return int64(r.rowsAffected), nil
|
||||
}
|
||||
|
||||
// SQLiteStatement implements the Statement interface
|
||||
type SQLiteStatement struct {
|
||||
stmt *sqlite.Stmt
|
||||
conn *sqlite.Conn
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) Close() error {
|
||||
return s.stmt.Finalize()
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) Query(ctx context.Context, args ...any) (Rows, error) {
|
||||
conn := &SQLiteConnection{conn: s.conn}
|
||||
if err := conn.bindArgs(s.stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SQLiteRows{stmt: s.stmt, hasNext: true}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) QueryRow(ctx context.Context, args ...any) Row {
|
||||
rows, err := s.Query(ctx, args...)
|
||||
if err != nil {
|
||||
return &SQLiteRow{err: err}
|
||||
}
|
||||
return &SQLiteRow{rows: rows.(*SQLiteRows)}
|
||||
}
|
||||
|
||||
func (s *SQLiteStatement) Exec(ctx context.Context, args ...any) (Result, error) {
|
||||
conn := &SQLiteConnection{conn: s.conn}
|
||||
if err := conn.bindArgs(s.stmt, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasRow, err := s.stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: failed to execute statement: %w", err)
|
||||
}
|
||||
|
||||
// Consume all rows if any
|
||||
for hasRow {
|
||||
hasRow, err = s.stmt.Step()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite: error stepping through results: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SQLiteResult{
|
||||
lastInsertID: s.conn.LastInsertRowID(),
|
||||
rowsAffected: s.conn.Changes(),
|
||||
}, nil
|
||||
}
|
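bindArgs above only accepts nil, int, int64, float64, string, bool and []byte, so the Lua values that bind cleanly are nil, numbers, strings and booleans (numbers are assumed to reach Go as float64 through the binding layer). A small illustrative sketch with a throwaway database file:

```lua
-- booleans bind as 0/1, nil binds as NULL; other Lua types are rejected by bindArgs
local conn = moonshark.sql_connect("sqlite", "scratch.db")
moonshark.sql_exec(conn,
    "CREATE TABLE events (name TEXT, count REAL, active INTEGER, note TEXT)")
moonshark.sql_exec(conn,
    "INSERT INTO events (name, count, active, note) VALUES (?, ?, ?, ?)",
    "boot", 3, true, nil)
moonshark.sql_close(conn)
```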
||||
@ -1,502 +0,0 @@
|
||||
local sqlite = {}
|
||||
|
||||
local Connection = {}
|
||||
Connection.__index = Connection
|
||||
|
||||
function Connection:close()
|
||||
if self._id then
|
||||
local ok = moonshark.sql_close(self._id)
|
||||
self._id = nil
|
||||
return ok
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:ping()
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_ping(self._id)
|
||||
end
|
||||
|
||||
function Connection:query(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_query(self._id, query_str:normalize_whitespace(), ...)
|
||||
end
|
||||
|
||||
function Connection:exec(query_str, ...)
|
||||
if not self._id then
|
||||
error("Connection is closed")
|
||||
end
|
||||
return moonshark.sql_exec(self._id, query_str:normalize_whitespace(), ...)
|
||||
end
|
||||
|
||||
function Connection:query_row(query_str, ...)
|
||||
local results = self:query(query_str, ...)
|
||||
return results and #results > 0 and results[1] or nil
|
||||
end
|
||||
|
||||
function Connection:query_value(query_str, ...)
|
||||
local row = self:query_row(query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function Connection:begin()
|
||||
local result = self:exec("BEGIN")
|
||||
if result then
|
||||
return {
|
||||
conn = self,
|
||||
active = true,
|
||||
commit = function(tx)
|
||||
if tx.active then
|
||||
tx.active = false
|
||||
return tx.conn:exec("COMMIT")
|
||||
end
|
||||
return false
|
||||
end,
|
||||
rollback = function(tx)
|
||||
if tx.active then
|
||||
tx.active = false
|
||||
return tx.conn:exec("ROLLBACK")
|
||||
end
|
||||
return false
|
||||
end,
|
||||
query = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query(query_str, ...)
|
||||
end,
|
||||
exec = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:exec(query_str, ...)
|
||||
end,
|
||||
query_row = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query_row(query_str, ...)
|
||||
end,
|
||||
query_value = function(tx, query_str, ...)
|
||||
if not tx.active then error("Transaction is not active") end
|
||||
return tx.conn:query_value(query_str, ...)
|
||||
end
|
||||
}
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function Connection:insert(table_name, data)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
|
||||
|
||||
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}})":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = placeholders
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:upsert(table_name, data, conflict_columns)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local placeholders = string.repeat_("?, ", #keys):trim_right(", ")
|
||||
local updates = table.map(keys, function(key) return key .. " = excluded." .. key end):join(", ")
|
||||
|
||||
local conflict_clause = ""
|
||||
if conflict_columns then
|
||||
if type(conflict_columns) == "string" then
|
||||
conflict_clause = "(" .. conflict_columns .. ")"
|
||||
else
|
||||
conflict_clause = "(" .. table.concat(conflict_columns, ", ") .. ")"
|
||||
end
|
||||
end
|
||||
|
||||
local query = "INSERT INTO {{table}} ({{columns}}) VALUES ({{placeholders}}) ON CONFLICT {{conflict}} DO UPDATE SET {{updates}}":parse({
|
||||
table = table_name,
|
||||
columns = keys:join(", "),
|
||||
placeholders = placeholders,
|
||||
conflict = conflict_clause,
|
||||
updates = updates
|
||||
})
|
||||
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:update(table_name, data, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if where_clause:is_blank() then
|
||||
error("WHERE clause cannot be empty for UPDATE")
|
||||
end
|
||||
|
||||
local keys = table.keys(data)
|
||||
local values = table.values(data)
|
||||
local sets = table.map(keys, function(key) return key .. " = ?" end):join(", ")
|
||||
|
||||
local query = "UPDATE {{table}} SET {{sets}} WHERE {{where}}":parse({
|
||||
table = table_name,
|
||||
sets = sets,
|
||||
where = where_clause
|
||||
})
|
||||
|
||||
table.extend(values, {...})
|
||||
return self:exec(query, unpack(values))
|
||||
end
|
||||
|
||||
function Connection:delete(table_name, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
if where_clause:is_blank() then
|
||||
error("WHERE clause cannot be empty for DELETE")
|
||||
end
|
||||
|
||||
local query = "DELETE FROM {{table}} WHERE {{where}}":parse({
|
||||
table = table_name,
|
||||
where = where_clause
|
||||
})
|
||||
return self:exec(query, ...)
|
||||
end
|
||||
|
||||
function Connection:select(table_name, columns, where_clause, ...)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
|
||||
columns = columns or "*"
|
||||
if type(columns) == "table" then
|
||||
columns = table.concat(columns, ", ")
|
||||
end
|
||||
|
||||
if where_clause and not where_clause:is_blank() then
|
||||
local query = "SELECT {{columns}} FROM {{table}} WHERE {{where}}":parse({
|
||||
columns = columns,
|
||||
table = table_name,
|
||||
where = where_clause
|
||||
})
|
||||
return self:query(query, ...)
|
||||
else
|
||||
local query = "SELECT {{columns}} FROM {{table}}":parse({
|
||||
columns = columns,
|
||||
table = table_name
|
||||
})
|
||||
return self:query(query)
|
||||
end
|
||||
end
|
||||
|
||||
function Connection:table_exists(table_name)
|
||||
if table_name:is_blank() then return false end
|
||||
return self:query_value("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table_name:trim()) ~= nil
|
||||
end
|
||||
|
||||
function Connection:column_exists(table_name, column_name)
|
||||
if table_name:is_blank() or column_name:is_blank() then return false end
|
||||
|
||||
local result = self:query("PRAGMA table_info({{table}})":parse({table = table_name}))
|
||||
if result then
|
||||
return table.any(result, function(row)
|
||||
return row.name:iequals(column_name:trim())
|
||||
end)
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function Connection:create_table(table_name, schema)
|
||||
if table_name:is_blank() or schema:is_blank() then
|
||||
error("Table name and schema cannot be empty")
|
||||
end
|
||||
|
||||
local query = "CREATE TABLE IF NOT EXISTS {{table}} ({{schema}})":parse({
|
||||
table = table_name,
|
||||
schema = schema:trim()
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_table(table_name)
|
||||
if table_name:is_blank() then
|
||||
error("Table name cannot be empty")
|
||||
end
|
||||
return self:exec("DROP TABLE IF EXISTS {{table}}":parse({table = table_name}))
|
||||
end
|
||||
|
||||
function Connection:add_column(table_name, column_def)
|
||||
if table_name:is_blank() or column_def:is_blank() then
|
||||
error("Table name and column definition cannot be empty")
|
||||
end
|
||||
|
||||
local query = "ALTER TABLE {{table}} ADD COLUMN {{column}}":parse({
|
||||
table = table_name,
|
||||
column = column_def:trim()
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:create_index(index_name, table_name, columns, unique)
|
||||
if index_name:is_blank() or table_name:is_blank() then
|
||||
error("Index name and table name cannot be empty")
|
||||
end
|
||||
|
||||
local unique_clause = unique and "UNIQUE " or ""
|
||||
local columns_str = type(columns) == "table" and table.concat(columns, ", ") or tostring(columns)
|
||||
|
||||
local query = "CREATE {{unique}}INDEX IF NOT EXISTS {{index}} ON {{table}} ({{columns}})":parse({
|
||||
unique = unique_clause,
|
||||
index = index_name,
|
||||
table = table_name,
|
||||
columns = columns_str
|
||||
})
|
||||
return self:exec(query)
|
||||
end
|
||||
|
||||
function Connection:drop_index(index_name)
|
||||
if index_name:is_blank() then
|
||||
error("Index name cannot be empty")
|
||||
end
|
||||
return self:exec("DROP INDEX IF EXISTS {{index}}":parse({index = index_name}))
|
||||
end
|
||||
|
||||
-- SQLite-specific functions
|
||||
function Connection:vacuum()
|
||||
return self:exec("VACUUM")
|
||||
end
|
||||
|
||||
function Connection:analyze()
|
||||
return self:exec("ANALYZE")
|
||||
end
|
||||
|
||||
function Connection:integrity_check()
|
||||
return self:query("PRAGMA integrity_check")
|
||||
end
|
||||
|
||||
function Connection:foreign_keys(enabled)
|
||||
local value = enabled and "ON" or "OFF"
|
||||
return self:exec("PRAGMA foreign_keys = {{value}}":parse({value = value}))
|
||||
end
|
||||
|
||||
function Connection:journal_mode(mode)
|
||||
mode = (mode or "WAL"):upper()
|
||||
local valid_modes = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}
|
||||
|
||||
if not table.contains(valid_modes, mode) then
|
||||
error("Invalid journal mode: " .. mode)
|
||||
end
|
||||
|
||||
return self:query("PRAGMA journal_mode = {{mode}}":parse({mode = mode}))
|
||||
end
|
||||
|
||||
function Connection:synchronous(level)
|
||||
level = (level or "NORMAL"):upper()
|
||||
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
|
||||
|
||||
if not table.contains(valid_levels, level) then
|
||||
error("Invalid synchronous level: " .. level)
|
||||
end
|
||||
|
||||
return self:exec("PRAGMA synchronous = {{level}}":parse({level = level}))
|
||||
end
|
||||
|
||||
function Connection:cache_size(size)
|
||||
size = size or -64000
|
||||
if type(size) ~= "number" then
|
||||
error("Cache size must be a number")
|
||||
end
|
||||
return self:exec("PRAGMA cache_size = {{size}}":parse({size = tostring(size)}))
|
||||
end
|
||||
|
||||
function Connection:temp_store(mode)
|
||||
mode = (mode or "MEMORY"):upper()
|
||||
local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
|
||||
|
||||
if not table.contains(valid_modes, mode) then
|
||||
error("Invalid temp_store mode: " .. mode)
|
||||
end
|
||||
|
||||
return self:exec("PRAGMA temp_store = {{mode}}":parse({mode = mode}))
|
||||
end
|
||||
|
||||
-- Connection management
|
||||
function sqlite.open(database_path)
|
||||
database_path = database_path or ":memory:"
|
||||
if database_path ~= ":memory:" and database_path:is_blank() then
|
||||
database_path = ":memory:"
|
||||
end
|
||||
|
||||
local conn_id = moonshark.sql_connect("sqlite", database_path:trim())
|
||||
if conn_id then
|
||||
return setmetatable({_id = conn_id}, Connection)
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
sqlite.connect = sqlite.open
|
||||
|
||||
-- Quick execution functions
|
||||
function sqlite.query(database_path, query_str, ...)
|
||||
local conn = sqlite.open(database_path)
|
||||
if not conn then
|
||||
error("Failed to open SQLite database: {{path}}":parse({path = database_path or ":memory:"}))
|
||||
end
|
||||
|
||||
local results = conn:query(query_str, ...)
|
||||
conn:close()
|
||||
return results
|
||||
end
|
||||
|
||||
function sqlite.exec(database_path, query_str, ...)
|
||||
local conn = sqlite.open(database_path)
|
||||
if not conn then
|
||||
error("Failed to open SQLite database: {{path}}":parse({path = database_path or ":memory:"}))
|
||||
end
|
||||
|
||||
local result = conn:exec(query_str, ...)
|
||||
conn:close()
|
||||
return result
|
||||
end
|
||||
|
||||
function sqlite.query_row(database_path, query_str, ...)
|
||||
local results = sqlite.query(database_path, query_str, ...)
|
||||
return results and #results > 0 and results[1] or nil
|
||||
end
|
||||
|
||||
function sqlite.query_value(database_path, query_str, ...)
|
||||
local row = sqlite.query_row(database_path, query_str, ...)
|
||||
if row then
|
||||
for _, value in pairs(row) do
|
||||
return value
|
||||
end
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Migration helpers
|
||||
function sqlite.migrate(database_path, migrations)
|
||||
local conn = sqlite.open(database_path)
|
||||
if not conn then
|
||||
error("Failed to open SQLite database for migration")
|
||||
end
|
||||
|
||||
conn:create_table("_migrations", "id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP")
|
||||
|
||||
local tx = conn:begin()
|
||||
if not tx then
|
||||
conn:close()
|
||||
error("Failed to begin migration transaction")
|
||||
end
|
||||
|
||||
for _, migration in ipairs(migrations) do
|
||||
if not migration.name or migration.name:is_blank() then
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error("Migration must have a non-empty name")
|
||||
end
|
||||
|
||||
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", migration.name:trim())
|
||||
if not existing then
|
||||
local ok, err = pcall(function()
|
||||
if type(migration.up) == "string" then
|
||||
conn:exec(migration.up)
|
||||
elseif type(migration.up) == "function" then
|
||||
migration.up(conn)
|
||||
else
|
||||
error("Migration 'up' must be string or function")
|
||||
end
|
||||
end)
|
||||
|
||||
if ok then
|
||||
conn:exec("INSERT INTO _migrations (name) VALUES (?)", migration.name:trim())
|
||||
print("Applied migration: {{name}}":parse({name = migration.name}))
|
||||
else
|
||||
tx:rollback()
|
||||
conn:close()
|
||||
error("Migration '{{name}}' failed: {{error}}":parse({
|
||||
name = migration.name,
|
||||
error = err or "unknown error"
|
||||
}))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
tx:commit()
|
||||
conn:close()
|
||||
return true
|
||||
end
|
||||
|
||||
-- Result processing utilities
|
||||
function sqlite.to_array(results, column_name)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if column_name:is_blank() then error("Column name cannot be empty") end
|
||||
return table.map(results, function(row) return row[column_name] end)
|
||||
end
|
||||
|
||||
function sqlite.to_map(results, key_column, value_column)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if key_column:is_blank() then error("Key column name cannot be empty") end
|
||||
|
||||
local map = {}
|
||||
for _, row in ipairs(results) do
|
||||
local key = row[key_column]
|
||||
map[key] = value_column and row[value_column] or row
|
||||
end
|
||||
return map
|
||||
end
|
||||
|
||||
function sqlite.group_by(results, column_name)
|
||||
if not results or table.is_empty(results) then return {} end
|
||||
if column_name:is_blank() then error("Column name cannot be empty") end
|
||||
return table.group_by(results, function(row) return row[column_name] end)
|
||||
end
|
||||
|
||||
function sqlite.print_results(results)
|
||||
if not results or table.is_empty(results) then
|
||||
print("No results")
|
||||
return
|
||||
end
|
||||
|
||||
local columns = table.keys(results[1])
|
||||
table.sort(columns)
|
||||
|
||||
-- Calculate column widths
|
||||
local widths = {}
|
||||
for _, col in ipairs(columns) do
|
||||
widths[col] = col:length()
|
||||
for _, row in ipairs(results) do
|
||||
local value = tostring(row[col] or "")
|
||||
widths[col] = math.max(widths[col], value:length())
|
||||
end
|
||||
end
|
||||
|
||||
-- Print header and separator
|
||||
local header_parts = table.map(columns, function(col) return col:pad_right(widths[col]) end)
|
||||
local separator_parts = table.map(columns, function(col) return string.repeat_("-", widths[col]) end)
|
||||
|
||||
print(table.concat(header_parts, " | "))
|
||||
print(table.concat(separator_parts, "-+-"))
|
||||
|
||||
-- Print rows
|
||||
for _, row in ipairs(results) do
|
||||
local value_parts = table.map(columns, function(col)
|
||||
local value = tostring(row[col] or "")
|
||||
return value:pad_right(widths[col])
|
||||
end)
|
||||
print(table.concat(value_parts, " | "))
|
||||
end
|
||||
end
|
||||
|
||||
return sqlite
|
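A hedged usage sketch of the sqlite.lua wrapper above, assuming it is exposed to require() under the name "sqlite":

```lua
local sqlite = require("sqlite")

local db = sqlite.open(":memory:")
db:create_table("users", "id INTEGER PRIMARY KEY, name TEXT")
db:insert("users", { name = "ada" })
db:update("users", { name = "Ada" }, "name = ?", "ada")

local rows = db:select("users", { "id", "name" })
sqlite.print_results(rows)
db:close()
```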
||||
@ -1,113 +0,0 @@
|
||||
package string
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"unicode/utf8"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func GetFunctionList() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"string_slice": string_slice,
|
||||
"string_reverse": string_reverse,
|
||||
"string_length": string_length,
|
||||
"string_byte_length": string_byte_length,
|
||||
"regex_replace": regex_replace,
|
||||
"string_is_valid_utf8": string_is_valid_utf8,
|
||||
}
|
||||
}
|
||||
|
||||
func string_slice(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
start := int(s.ToNumber(2))
|
||||
|
||||
if !utf8.ValidString(str) {
|
||||
s.PushNil()
|
||||
s.PushString("invalid UTF-8")
|
||||
return 2
|
||||
}
|
||||
|
||||
runes := []rune(str)
|
||||
length := len(runes)
|
||||
startIdx := max(start-1, 0) // Convert from 1-indexed
|
||||
if startIdx >= length {
|
||||
s.PushString("")
|
||||
return 1
|
||||
}
|
||||
|
||||
endIdx := length
|
||||
if s.GetTop() >= 3 && !s.IsNil(3) {
|
||||
end := int(s.ToNumber(3))
|
||||
if end < 0 {
|
||||
endIdx = length + end + 1
|
||||
} else {
|
||||
endIdx = end
|
||||
}
|
||||
if endIdx < 0 {
|
||||
endIdx = 0
|
||||
}
|
||||
if endIdx > length {
|
||||
endIdx = length
|
||||
}
|
||||
}
|
||||
|
||||
if startIdx >= endIdx {
|
||||
s.PushString("")
|
||||
return 1
|
||||
}
|
||||
|
||||
s.PushString(string(runes[startIdx:endIdx]))
|
||||
return 1
|
||||
}
|
||||
|
||||
func string_reverse(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
|
||||
if !utf8.ValidString(str) {
|
||||
s.PushNil()
|
||||
s.PushString("invalid UTF-8")
|
||||
return 2
|
||||
}
|
||||
|
||||
runes := []rune(str)
|
||||
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
|
||||
runes[i], runes[j] = runes[j], runes[i]
|
||||
}
|
||||
s.PushString(string(runes))
|
||||
return 1
|
||||
}
|
||||
|
||||
func string_length(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
s.PushNumber(float64(utf8.RuneCountInString(str)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func string_byte_length(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
s.PushNumber(float64(len(str)))
|
||||
return 1
|
||||
}
|
||||
|
||||
func regex_replace(s *luajit.State) int {
|
||||
pattern := s.ToString(1)
|
||||
str := s.ToString(2)
|
||||
replacement := s.ToString(3)
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
s.PushString(str)
|
||||
return 1
|
||||
}
|
||||
|
||||
result := re.ReplaceAllString(str, replacement)
|
||||
s.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
func string_is_valid_utf8(s *luajit.State) int {
|
||||
str := s.ToString(1)
|
||||
s.PushBoolean(utf8.ValidString(str))
|
||||
return 1
|
||||
}
|
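The UTF-8 helpers above are installed on the moonshark table by the module registry. Called directly, indices are 1-based and count runes rather than bytes; a quick sketch:

```lua
print(moonshark.string_length("héllo"))            --> 5
print(moonshark.string_byte_length("héllo"))       --> 6
print(moonshark.string_slice("héllo", 2, 4))       --> éll
print(moonshark.string_reverse("héllo"))           --> olléh
print(moonshark.regex_replace("l+", "héllo", "L")) --> héLo
print(moonshark.string_is_valid_utf8("héllo"))     --> true
```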
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
168
moonshark.go
168
moonshark.go
@ -1,168 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"Moonshark/modules/http"
|
||||
"Moonshark/modules/sql"
|
||||
"Moonshark/state"
|
||||
)
|
||||
|
||||
var (
|
||||
watchFlag = flag.Bool("watch", false, "Watch script files for changes and restart")
|
||||
wFlag = flag.Bool("w", false, "Watch script files for changes and restart")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if flag.NArg() < 1 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [--watch|-w] <script.lua>\n", filepath.Base(os.Args[0]))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
scriptPath := flag.Arg(0)
|
||||
watchMode := *watchFlag || *wFlag
|
||||
|
||||
if watchMode {
|
||||
runWithWatcher(scriptPath)
|
||||
} else {
|
||||
runOnce(scriptPath)
|
||||
}
|
||||
}
|
||||
|
||||
func runOnce(scriptPath string) {
|
||||
luaState, err := state.NewFromScript(scriptPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer luaState.Close()
|
||||
|
||||
if err := luaState.ExecuteFile(scriptPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if http.HasActiveServers() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
fmt.Println("HTTP servers running. Press Ctrl+C to exit.")
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
fmt.Println("\nShutting down...")
|
||||
|
||||
// Close main state first (saves KV stores)
|
||||
luaState.Close()
|
||||
// Then stop servers (closes worker states)
|
||||
http.StopAllServers()
|
||||
sql.CloseAllConnections()
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
http.WaitForServers()
|
||||
}
|
||||
}
|
||||
|
||||
func runWithWatcher(scriptPath string) {
|
||||
watcher, err := NewFileWatcher(500) // 500ms debounce
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create file watcher: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
if err := watcher.DiscoverRequiredFiles(scriptPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to watch files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
restartCh := watcher.Start()
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
fmt.Printf("Starting %s in watch mode...\n", scriptPath)
|
||||
|
||||
var hadError bool
|
||||
firstRun := true
|
||||
|
||||
for {
|
||||
// If we had an error on the last run and this isn't the first run,
|
||||
// wait for file changes before retrying
|
||||
if hadError && !firstRun {
|
||||
fmt.Println("Waiting for file changes before retrying...")
|
||||
select {
|
||||
case <-restartCh:
|
||||
fmt.Println("Files changed, retrying...")
|
||||
case <-sigChan:
|
||||
fmt.Println("\nExiting...")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
firstRun = false
|
||||
hadError = false
|
||||
|
||||
// Clear cache before each run
|
||||
state.ClearCache()
|
||||
|
||||
// Create and run state
|
||||
luaState, err := state.NewFromScript(scriptPath)
|
||||
if err != nil {
|
||||
log.Printf("Error creating state: %v", err)
|
||||
hadError = true
|
||||
continue
|
||||
}
|
||||
|
||||
if err := luaState.ExecuteFile(scriptPath); err != nil {
|
||||
log.Printf("Execution error: %v", err)
|
||||
luaState.Close()
|
||||
hadError = true
|
||||
continue
|
||||
}
|
||||
|
||||
// If not a long-running process, wait for changes and restart
|
||||
if !http.HasActiveServers() {
|
||||
fmt.Println("Script completed. Waiting for changes...")
|
||||
luaState.Close()
|
||||
|
||||
select {
|
||||
case <-restartCh:
|
||||
fmt.Println("Files changed, restarting...")
|
||||
continue
|
||||
case <-sigChan:
|
||||
fmt.Println("\nExiting...")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Long-running process - wait for restart signal or exit signal
|
||||
fmt.Println("HTTP servers running. Watching for file changes...")
|
||||
|
||||
select {
|
||||
case <-restartCh:
|
||||
fmt.Println("Files changed, restarting...")
|
||||
http.StopAllServers()
|
||||
luaState.Close()
|
||||
sql.CloseAllConnections()
|
||||
time.Sleep(100 * time.Millisecond) // Brief pause for cleanup
|
||||
continue
|
||||
|
||||
case <-sigChan:
|
||||
fmt.Println("\nShutting down...")
|
||||
http.StopAllServers()
|
||||
luaState.Close()
|
||||
sql.CloseAllConnections()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
90
router/build.go
Normal file
90
router/build.go
Normal file
@ -0,0 +1,90 @@
package router

import (
	"os"
	"path/filepath"
	"strings"
)

// buildRoutes scans the routes directory and builds the routing tree
func (r *LuaRouter) buildRoutes() error {
	r.failedRoutes = make(map[string]*RouteError)
	r.middlewareFiles = make(map[string][]string)

	// First pass: collect all middleware files
	err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
			return err
		}

		if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
			relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
			if err != nil {
				return err
			}

			fsPath := "/"
			if relDir != "." {
				fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
			}

			// Use filesystem path for middleware (includes groups)
			r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
		}

		return nil
	})

	if err != nil {
		return err
	}

	// Second pass: build routes with combined middleware + handler
	return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
			return err
		}

		fileName := strings.TrimSuffix(info.Name(), ".lua")

		// Skip middleware files (already processed)
		if fileName == "middleware" {
			return nil
		}

		relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
		if err != nil {
			return err
		}

		fsPath := "/"
		if relDir != "." {
			fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
		}

		pathInfo := parsePathWithGroups(fsPath)

		// Handle index.lua files
		if fileName == "index" {
			for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"} {
				root := r.routes[method]
				node := r.findOrCreateNode(root, pathInfo.urlPath)
				node.indexFile = path
				node.modTime = info.ModTime()
				node.fsPath = pathInfo.fsPath
				r.compileWithMiddleware(node, pathInfo.fsPath, path)
			}
			return nil
		}

		// Handle method files
		method := strings.ToUpper(fileName)
		root, exists := r.routes[method]
		if !exists {
			return nil
		}

		r.addRoute(root, pathInfo, path, info.ModTime())
		return nil
	})
}
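The two-pass walk above encodes the filesystem routing convention: per-directory middleware.lua files, index.lua fallbacks, method-named handlers (get.lua, post.lua, ...), and parenthesised group directories that scope middleware without appearing in the URL (see parsePathWithGroups in router/node.go below). An illustrative layout, not taken from the repository:

```
routes/
├── middleware.lua        -- runs for every route
├── index.lua             -- handles "/" for all methods
├── users/
│   ├── get.lua           -- GET /users
│   └── :id/
│       └── get.lua       -- GET /users/:id
└── (admin)/              -- group: scopes middleware, stripped from URLs
    ├── middleware.lua    -- runs only for handlers under (admin)/
    └── dashboard/
        └── get.lua       -- GET /dashboard
```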
62
router/cache.go
Normal file
62
router/cache.go
Normal file
@ -0,0 +1,62 @@
package router

import (
	"encoding/binary"
	"hash/fnv"
	"time"

	"github.com/VictoriaMetrics/fastcache"
)

// hashString generates a hash for a string
func hashString(s string) uint64 {
	h := fnv.New64a()
	h.Write([]byte(s))
	return h.Sum64()
}

// uint64ToBytes converts a uint64 to bytes for cache key
func uint64ToBytes(n uint64) []byte {
	b := make([]byte, 8)
	binary.LittleEndian.PutUint64(b, n)
	return b
}

// getCacheKey generates a cache key for a method and path
func getCacheKey(method, path string) []byte {
	key := hashString(method + ":" + path)
	return uint64ToBytes(key)
}

// getBytecodeKey generates a cache key for a handler path
func getBytecodeKey(handlerPath string) []byte {
	key := hashString(handlerPath)
	return uint64ToBytes(key)
}

// ClearCache clears all caches
func (r *LuaRouter) ClearCache() {
	r.routeCache.Reset()
	r.bytecodeCache.Reset()
	r.middlewareCache = make(map[string][]byte)
	r.sourceCache = make(map[string][]byte)
	r.sourceMtimes = make(map[string]time.Time)
}

// GetCacheStats returns statistics about the cache
func (r *LuaRouter) GetCacheStats() map[string]any {
	var routeStats fastcache.Stats
	var bytecodeStats fastcache.Stats

	r.routeCache.UpdateStats(&routeStats)
	r.bytecodeCache.UpdateStats(&bytecodeStats)

	return map[string]any{
		"routeEntries":       routeStats.EntriesCount,
		"routeBytes":         routeStats.BytesSize,
		"routeCollisions":    routeStats.Collisions,
		"bytecodeEntries":    bytecodeStats.EntriesCount,
		"bytecodeBytes":      bytecodeStats.BytesSize,
		"bytecodeCollisions": bytecodeStats.Collisions,
	}
}
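A route-cache value produced with these keys packs the 8-byte bytecode key in front of the handler path; GetRouteInfo in router/match.go splits it back apart the same way. A sketch of that convention, expressed as a hypothetical same-package helper rather than code from the repository:

```go
// cacheRouteEntry is illustrative only; it shows the value layout used by
// the route cache, with a made-up handler path.
func cacheRouteEntry(r *LuaRouter, method, urlPath, handlerPath string) {
	routeKey := getCacheKey(method, urlPath)   // key for the route cache
	bytecodeKey := getBytecodeKey(handlerPath) // key into the bytecode cache

	entry := make([]byte, 8+len(handlerPath))
	copy(entry[:8], bytecodeKey) // first 8 bytes: bytecode key
	copy(entry[8:], handlerPath) // remainder: handler script path
	r.routeCache.Set(routeKey, entry)
}
```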
176
router/compile.go
Normal file
176
router/compile.go
Normal file
@ -0,0 +1,176 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// compileWithMiddleware combines middleware and handler source, then compiles
|
||||
func (r *LuaRouter) compileWithMiddleware(n *node, fsPath, scriptPath string) error {
|
||||
if scriptPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we need to recompile by comparing modification times
|
||||
sourceKey := r.getSourceCacheKey(fsPath, scriptPath)
|
||||
needsRecompile := false
|
||||
|
||||
// Check handler modification time
|
||||
handlerInfo, err := os.Stat(scriptPath)
|
||||
if err != nil {
|
||||
n.err = err
|
||||
return err
|
||||
}
|
||||
|
||||
lastCompiled, exists := r.sourceMtimes[sourceKey]
|
||||
if !exists || handlerInfo.ModTime().After(lastCompiled) {
|
||||
needsRecompile = true
|
||||
}
|
||||
|
||||
// Check middleware modification times
|
||||
if !needsRecompile {
|
||||
middlewareChain := r.getMiddlewareChain(fsPath)
|
||||
for _, mwPath := range middlewareChain {
|
||||
mwInfo, err := os.Stat(mwPath)
|
||||
if err != nil {
|
||||
n.err = err
|
||||
return err
|
||||
}
|
||||
if mwInfo.ModTime().After(lastCompiled) {
|
||||
needsRecompile = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use cached bytecode if available and fresh
|
||||
if !needsRecompile {
|
||||
if bytecode, exists := r.sourceCache[sourceKey]; exists {
|
||||
bytecodeKey := getBytecodeKey(scriptPath)
|
||||
r.bytecodeCache.Set(bytecodeKey, bytecode)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Build combined source
|
||||
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
|
||||
if err != nil {
|
||||
n.err = err
|
||||
return err
|
||||
}
|
||||
|
||||
// Compile combined source using shared state
|
||||
r.compileStateMu.Lock()
|
||||
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
|
||||
r.compileStateMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
n.err = err
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache everything
|
||||
bytecodeKey := getBytecodeKey(scriptPath)
|
||||
r.bytecodeCache.Set(bytecodeKey, bytecode)
|
||||
r.sourceCache[sourceKey] = bytecode
|
||||
r.sourceMtimes[sourceKey] = time.Now()
|
||||
|
||||
n.err = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildCombinedSource builds the combined middleware + handler source
|
||||
func (r *LuaRouter) buildCombinedSource(fsPath, scriptPath string) (string, error) {
|
||||
var combinedSource strings.Builder
|
||||
|
||||
// Get middleware chain using filesystem path
|
||||
middlewareChain := r.getMiddlewareChain(fsPath)
|
||||
|
||||
// Add middleware in order
|
||||
for _, mwPath := range middlewareChain {
|
||||
content, err := r.getFileContent(mwPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
combinedSource.WriteString("-- Middleware: ")
|
||||
combinedSource.WriteString(mwPath)
|
||||
combinedSource.WriteString("\n")
|
||||
combinedSource.Write(content)
|
||||
combinedSource.WriteString("\n")
|
||||
}
|
||||
|
||||
// Add main handler
|
||||
content, err := r.getFileContent(scriptPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
combinedSource.WriteString("-- Handler: ")
|
||||
combinedSource.WriteString(scriptPath)
|
||||
combinedSource.WriteString("\n")
|
||||
combinedSource.Write(content)
|
||||
|
||||
return combinedSource.String(), nil
|
||||
}
|
||||
|
||||
// getFileContent reads file content with caching
|
||||
func (r *LuaRouter) getFileContent(path string) ([]byte, error) {
|
||||
// Check cache first
|
||||
if content, exists := r.middlewareCache[path]; exists {
|
||||
// Verify file hasn't changed
|
||||
info, err := os.Stat(path)
|
||||
if err == nil {
|
||||
if cachedTime, exists := r.sourceMtimes[path]; exists && !info.ModTime().After(cachedTime) {
|
||||
return content, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read from disk
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache it
|
||||
r.middlewareCache[path] = content
|
||||
r.sourceMtimes[path] = time.Now()
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// getSourceCacheKey generates a unique key for combined source
|
||||
func (r *LuaRouter) getSourceCacheKey(fsPath, scriptPath string) string {
|
||||
middlewareChain := r.getMiddlewareChain(fsPath)
|
||||
var keyParts []string
|
||||
keyParts = append(keyParts, middlewareChain...)
|
||||
keyParts = append(keyParts, scriptPath)
|
||||
return strings.Join(keyParts, "|")
|
||||
}
|
||||
|
||||
// getMiddlewareChain returns middleware files that apply to the given filesystem path
|
||||
func (r *LuaRouter) getMiddlewareChain(fsPath string) []string {
|
||||
var chain []string
|
||||
|
||||
// Collect middleware from root to specific path using filesystem path (includes groups)
|
||||
pathParts := strings.Split(strings.Trim(fsPath, "/"), "/")
|
||||
if pathParts[0] == "" {
|
||||
pathParts = []string{}
|
||||
}
|
||||
|
||||
// Add root middleware
|
||||
if mw, exists := r.middlewareFiles["/"]; exists {
|
||||
chain = append(chain, mw...)
|
||||
}
|
||||
|
||||
// Add middleware from each path level (including groups)
|
||||
currentPath := ""
|
||||
for _, part := range pathParts {
|
||||
currentPath += "/" + part
|
||||
if mw, exists := r.middlewareFiles[currentPath]; exists {
|
||||
chain = append(chain, mw...)
|
||||
}
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
25
router/errors.go
Normal file
25
router/errors.go
Normal file
@ -0,0 +1,25 @@
package router

import "errors"

var (
	// ErrRoutesCompilationErrors indicates that some routes failed to compile
	// but the router is still operational
	ErrRoutesCompilationErrors = errors.New("some routes failed to compile")
)

// RouteError represents an error with a specific route
type RouteError struct {
	Path       string // The URL path
	Method     string // HTTP method
	ScriptPath string // Path to the Lua script
	Err        error  // The actual error
}

// Error returns the error message
func (re *RouteError) Error() string {
	if re.Err == nil {
		return "unknown route error"
	}
	return re.Err.Error()
}
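ErrRoutesCompilationErrors is a sentinel: NewLuaRouter (router/router.go below) still returns a usable router alongside it, and the per-route details live in ReportFailedRoutes. A hedged caller sketch, assuming the package is imported as Moonshark/router:

```go
package main

import (
	"errors"
	"log"

	"Moonshark/router" // assumed import path for this package
)

func main() {
	r, err := router.NewLuaRouter("routes")
	if errors.Is(err, router.ErrRoutesCompilationErrors) {
		// The router still works; log the routes that failed to compile.
		for _, re := range r.ReportFailedRoutes() {
			log.Printf("%s %s (%s): %v", re.Method, re.Path, re.ScriptPath, re.Err)
		}
	} else if err != nil {
		log.Fatalf("router init: %v", err)
	}
	defer r.Close()
}
```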
187
router/match.go
Normal file
187
router/match.go
Normal file
@ -0,0 +1,187 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Match finds a handler for the given method and path (URL path, excludes groups)
|
||||
func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
|
||||
params.Count = 0
|
||||
|
||||
r.mu.RLock()
|
||||
root, exists := r.routes[method]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
segments := strings.Split(strings.Trim(path, "/"), "/")
|
||||
return r.matchPath(root, segments, params, 0)
|
||||
}
|
||||
|
||||
// matchPath recursively matches a path against the routing tree
|
||||
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
|
||||
// Filter empty segments
|
||||
filteredSegments := segments[:0]
|
||||
for _, segment := range segments {
|
||||
if segment != "" {
|
||||
filteredSegments = append(filteredSegments, segment)
|
||||
}
|
||||
}
|
||||
segments = filteredSegments
|
||||
|
||||
if len(segments) == 0 {
|
||||
if current.handler != "" {
|
||||
return current, true
|
||||
}
|
||||
if current.indexFile != "" {
|
||||
return current, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
segment := segments[0]
|
||||
remaining := segments[1:]
|
||||
|
||||
// Try static child first
|
||||
if child, exists := current.staticChild[segment]; exists {
|
||||
if node, found := r.matchPath(child, remaining, params, depth+1); found {
|
||||
return node, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try parameter child
|
||||
if current.paramChild != nil {
|
||||
if params.Count < maxParams {
|
||||
params.Keys[params.Count] = current.paramChild.paramName
|
||||
params.Values[params.Count] = segment
|
||||
params.Count++
|
||||
}
|
||||
|
||||
if node, found := r.matchPath(current.paramChild, remaining, params, depth+1); found {
|
||||
return node, true
|
||||
}
|
||||
|
||||
params.Count--
|
||||
}
|
||||
|
||||
// Fall back to index.lua
|
||||
if current.indexFile != "" {
|
||||
return current, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetRouteInfo returns bytecode, script path, error, params, and found status
|
||||
func (r *LuaRouter) GetRouteInfo(method, path []byte) ([]byte, string, error, *Params, bool) {
|
||||
// Convert to string for internal processing
|
||||
methodStr := string(method)
|
||||
pathStr := string(path)
|
||||
|
||||
routeCacheKey := getCacheKey(methodStr, pathStr)
|
||||
routeCacheData := r.routeCache.Get(nil, routeCacheKey)
|
||||
|
||||
// Fast path: found in cache
|
||||
if len(routeCacheData) > 0 {
|
||||
handlerPath := string(routeCacheData[8:])
|
||||
bytecodeKey := routeCacheData[:8]
|
||||
|
||||
bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
|
||||
|
||||
n, exists := r.nodeForHandler(handlerPath)
|
||||
if !exists {
|
||||
r.routeCache.Del(routeCacheKey)
|
||||
return nil, "", nil, nil, false
|
||||
}
|
||||
|
||||
// Check if recompilation needed
|
||||
if len(bytecode) > 0 {
|
||||
// For cached routes, we need to re-match to get params
|
||||
params := &Params{}
|
||||
r.Match(methodStr, pathStr, params)
|
||||
return bytecode, handlerPath, n.err, params, true
|
||||
}
|
||||
|
||||
// Recompile if needed
|
||||
fileInfo, err := os.Stat(handlerPath)
|
||||
if err != nil || fileInfo.ModTime().After(n.modTime) {
|
||||
scriptPath := n.handler
|
||||
if scriptPath == "" {
|
||||
scriptPath = n.indexFile
|
||||
}
|
||||
|
||||
fsPath := n.fsPath
|
||||
if fsPath == "" {
|
||||
fsPath = "/"
|
||||
}
|
||||
|
||||
if err := r.compileWithMiddleware(n, fsPath, scriptPath); err != nil {
|
||||
params := &Params{}
|
||||
r.Match(methodStr, pathStr, params)
|
||||
return nil, handlerPath, n.err, params, true
|
||||
}
|
||||
|
||||
newBytecodeKey := getBytecodeKey(handlerPath)
|
||||
bytecode = r.bytecodeCache.Get(nil, newBytecodeKey)
|
||||
|
||||
newCacheData := make([]byte, 8+len(handlerPath))
|
||||
copy(newCacheData[:8], newBytecodeKey)
|
||||
copy(newCacheData[8:], handlerPath)
|
||||
r.routeCache.Set(routeCacheKey, newCacheData)
|
||||
|
||||
params := &Params{}
|
||||
r.Match(methodStr, pathStr, params)
|
||||
return bytecode, handlerPath, n.err, params, true
|
||||
}
|
||||
|
||||
params := &Params{}
|
||||
r.Match(methodStr, pathStr, params)
|
||||
return bytecode, handlerPath, n.err, params, true
|
||||
}
|
||||
|
||||
// Slow path: lookup and compile
|
||||
params := &Params{}
|
||||
node, found := r.Match(methodStr, pathStr, params)
|
||||
if !found {
|
||||
return nil, "", nil, nil, false
|
||||
}
|
||||
|
||||
scriptPath := node.handler
|
||||
if scriptPath == "" && node.indexFile != "" {
|
||||
scriptPath = node.indexFile
|
||||
}
|
||||
|
||||
if scriptPath == "" {
|
||||
return nil, "", nil, nil, false
|
||||
}
|
||||
|
||||
bytecodeKey := getBytecodeKey(scriptPath)
|
||||
bytecode := r.bytecodeCache.Get(nil, bytecodeKey)
|
||||
|
||||
if len(bytecode) == 0 {
|
||||
fsPath := node.fsPath
|
||||
if fsPath == "" {
|
||||
fsPath = "/"
|
||||
}
|
||||
if err := r.compileWithMiddleware(node, fsPath, scriptPath); err != nil {
|
||||
return nil, scriptPath, node.err, params, true
|
||||
}
|
||||
bytecode = r.bytecodeCache.Get(nil, bytecodeKey)
|
||||
}
|
||||
|
||||
// Cache the route
|
||||
cacheData := make([]byte, 8+len(scriptPath))
|
||||
copy(cacheData[:8], bytecodeKey)
|
||||
copy(cacheData[8:], scriptPath)
|
||||
r.routeCache.Set(routeCacheKey, cacheData)
|
||||
|
||||
return bytecode, scriptPath, node.err, params, true
|
||||
}
|
||||
|
||||
// GetRouteInfoString is a convenience method that accepts strings
|
||||
func (r *LuaRouter) GetRouteInfoString(method, path string) ([]byte, string, error, *Params, bool) {
|
||||
return r.GetRouteInfo([]byte(method), []byte(path))
|
||||
}
|
||||
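GetRouteInfo (or its string wrapper) is the per-request entry point: one call yields cached bytecode, the resolved script path, any compile error, and the URL parameters. A minimal consumer sketch, assuming Moonshark/router and the standard log package are imported; loading the bytecode into a worker Lua state is out of scope here:

```go
func resolve(r *router.LuaRouter) {
	bytecode, scriptPath, compileErr, params, found := r.GetRouteInfoString("GET", "/users/42")
	switch {
	case !found:
		log.Println("404: no handler or index.lua matched")
	case compileErr != nil:
		log.Printf("500: %s failed to compile: %v", scriptPath, compileErr)
	default:
		log.Printf("run %s (%d bytes of bytecode), id=%q",
			scriptPath, len(bytecode), params.Get("id"))
	}
}
```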
190
router/node.go
Normal file
190
router/node.go
Normal file
@ -0,0 +1,190 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// node represents a node in the radix trie
|
||||
type node struct {
|
||||
// Static children mapped by path segment
|
||||
staticChild map[string]*node
|
||||
|
||||
// Parameter child for dynamic segments (e.g., :id)
|
||||
paramChild *node
|
||||
paramName string
|
||||
|
||||
// Handler information
|
||||
handler string // Path to the handler file
|
||||
indexFile string // Path to index.lua if exists
|
||||
modTime time.Time // Modification time of the handler
|
||||
fsPath string // Filesystem path (includes groups)
|
||||
|
||||
// Compilation error if any
|
||||
err error
|
||||
}
|
||||
|
||||
// pathInfo holds both URL path and filesystem path
|
||||
type pathInfo struct {
|
||||
urlPath string // URL path without groups (e.g., /users)
|
||||
fsPath string // Filesystem path with groups (e.g., /(admin)/users)
|
||||
}
|
||||
|
||||
// parsePathWithGroups parses a filesystem path, extracting groups
|
||||
func parsePathWithGroups(fsPath string) *pathInfo {
|
||||
segments := strings.Split(strings.Trim(fsPath, "/"), "/")
|
||||
var urlSegments []string
|
||||
|
||||
for _, segment := range segments {
|
||||
if segment == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip group segments (enclosed in parentheses)
|
||||
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
|
||||
continue
|
||||
}
|
||||
|
||||
urlSegments = append(urlSegments, segment)
|
||||
}
|
||||
|
||||
urlPath := "/"
|
||||
if len(urlSegments) > 0 {
|
||||
urlPath = "/" + strings.Join(urlSegments, "/")
|
||||
}
|
||||
|
||||
return &pathInfo{
|
||||
urlPath: urlPath,
|
||||
fsPath: fsPath,
|
||||
}
|
||||
}
|
||||
|
||||
// findOrCreateNode finds or creates a node at the given URL path
|
||||
func (r *LuaRouter) findOrCreateNode(root *node, urlPath string) *node {
|
||||
segments := strings.Split(strings.Trim(urlPath, "/"), "/")
|
||||
if len(segments) == 1 && segments[0] == "" {
|
||||
return root
|
||||
}
|
||||
|
||||
current := root
|
||||
for _, segment := range segments {
|
||||
if segment == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a parameter
|
||||
if strings.HasPrefix(segment, ":") {
|
||||
paramName := segment[1:]
|
||||
if current.paramChild == nil {
|
||||
current.paramChild = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
paramName: paramName,
|
||||
}
|
||||
}
|
||||
current = current.paramChild
|
||||
} else {
|
||||
// Static segment
|
||||
if _, exists := current.staticChild[segment]; !exists {
|
||||
current.staticChild[segment] = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
}
|
||||
current = current.staticChild[segment]
|
||||
}
|
||||
}
|
||||
|
||||
return current
|
||||
}
|
||||
|
||||
// addRoute adds a route to the tree
|
||||
func (r *LuaRouter) addRoute(root *node, pathInfo *pathInfo, handlerPath string, modTime time.Time) {
|
||||
node := r.findOrCreateNode(root, pathInfo.urlPath)
|
||||
node.handler = handlerPath
|
||||
node.modTime = modTime
|
||||
node.fsPath = pathInfo.fsPath
|
||||
|
||||
// Compile the route with middleware
|
||||
r.compileWithMiddleware(node, pathInfo.fsPath, handlerPath)
|
||||
|
||||
// Track failed routes
|
||||
if node.err != nil {
|
||||
key := filepath.Base(handlerPath) + ":" + pathInfo.urlPath
|
||||
r.failedRoutes[key] = &RouteError{
|
||||
Path: pathInfo.urlPath,
|
||||
Method: strings.ToUpper(strings.TrimSuffix(filepath.Base(handlerPath), ".lua")),
|
||||
ScriptPath: handlerPath,
|
||||
Err: node.err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nodeForHandler finds a node by its handler path
|
||||
func (r *LuaRouter) nodeForHandler(handlerPath string) (*node, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, root := range r.routes {
|
||||
if node := findNodeByHandler(root, handlerPath); node != nil {
|
||||
return node, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findNodeByHandler recursively searches for a node with the given handler path
|
||||
func findNodeByHandler(n *node, handlerPath string) *node {
|
||||
if n.handler == handlerPath || n.indexFile == handlerPath {
|
||||
return n
|
||||
}
|
||||
|
||||
// Search static children
|
||||
for _, child := range n.staticChild {
|
||||
if found := findNodeByHandler(child, handlerPath); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
// Search param child
|
||||
if n.paramChild != nil {
|
||||
if found := findNodeByHandler(n.paramChild, handlerPath); found != nil {
|
||||
return found
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// countNodesAndBytecode counts nodes and bytecode size in the tree
|
||||
func countNodesAndBytecode(n *node) (int, int64) {
|
||||
if n == nil {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
bytes := int64(0)
|
||||
|
||||
// Count this node if it has a handler
|
||||
if n.handler != "" || n.indexFile != "" {
|
||||
count = 1
|
||||
// Estimate bytecode size (would need actual bytecode cache lookup for accuracy)
|
||||
bytes = 1024 // Placeholder
|
||||
}
|
||||
|
||||
// Count static children
|
||||
for _, child := range n.staticChild {
|
||||
c, b := countNodesAndBytecode(child)
|
||||
count += c
|
||||
bytes += b
|
||||
}
|
||||
|
||||
// Count param child
|
||||
if n.paramChild != nil {
|
||||
c, b := countNodesAndBytecode(n.paramChild)
|
||||
count += c
|
||||
bytes += b
|
||||
}
|
||||
|
||||
return count, bytes
|
||||
}
|
||||
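parsePathWithGroups is what keeps group directories out of URLs: parenthesised segments are dropped from urlPath but preserved in fsPath, so middleware lookup still sees them. A same-package sketch with a made-up path (parsePathWithGroups and pathInfo are unexported, so this only compiles inside the router package):

```go
func ExampleParsePathWithGroups() { // illustrative sketch, not repository code
	info := parsePathWithGroups("/(admin)/users/:id")
	fmt.Println(info.urlPath) // "/users/:id"         -> matched against requests
	fmt.Println(info.fsPath)  // "/(admin)/users/:id" -> used for middleware lookup
}
```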
44
router/params.go
Normal file
44
router/params.go
Normal file
@ -0,0 +1,44 @@
package router

// Maximum number of URL parameters per route
const maxParams = 20

// Params holds URL parameters with fixed-size arrays to avoid allocations
type Params struct {
	Keys   [maxParams]string
	Values [maxParams]string
	Count  int
}

// Get returns a parameter value by name
func (p *Params) Get(name string) string {
	for i := range p.Count {
		if p.Keys[i] == name {
			return p.Values[i]
		}
	}
	return ""
}

// Reset clears all parameters
func (p *Params) Reset() {
	p.Count = 0
}

// Set adds or updates a parameter
func (p *Params) Set(name, value string) {
	// Try to update existing
	for i := range p.Count {
		if p.Keys[i] == name {
			p.Values[i] = value
			return
		}
	}

	// Add new if space available
	if p.Count < maxParams {
		p.Keys[p.Count] = name
		p.Values[p.Count] = value
		p.Count++
	}
}
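Because Params uses fixed-size arrays, matching never allocates and a single value can be reused across requests via Reset. A small sketch, assuming Moonshark/router and fmt are imported:

```go
func paramsSketch() {
	var p router.Params // zero value is ready to use

	p.Set("id", "42")
	p.Set("id", "43")        // same key updates in place instead of appending
	fmt.Println(p.Get("id")) // "43"
	fmt.Println(p.Get("x"))  // "" for unknown keys

	p.Reset() // reuse the same value for the next request
}
```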
156
router/router.go
Normal file
156
router/router.go
Normal file
@ -0,0 +1,156 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/VictoriaMetrics/fastcache"
|
||||
)
|
||||
|
||||
// Default cache sizes
|
||||
const (
|
||||
defaultBytecodeMaxBytes = 32 * 1024 * 1024 // 32MB for bytecode cache
|
||||
defaultRouteMaxBytes = 8 * 1024 * 1024 // 8MB for route match cache
|
||||
)
|
||||
|
||||
// LuaRouter is a filesystem-based HTTP router for Lua files
|
||||
type LuaRouter struct {
|
||||
routesDir string // Root directory containing route files
|
||||
routes map[string]*node // Method -> route tree
|
||||
failedRoutes map[string]*RouteError // Track failed routes
|
||||
mu sync.RWMutex // Lock for concurrent access to routes
|
||||
|
||||
routeCache *fastcache.Cache // Cache for route lookups
|
||||
bytecodeCache *fastcache.Cache // Cache for compiled bytecode
|
||||
|
||||
// Middleware tracking for path hierarchy
|
||||
middlewareFiles map[string][]string // filesystem path -> middleware file paths
|
||||
|
||||
// Caching fields
|
||||
middlewareCache map[string][]byte // path -> content
|
||||
sourceCache map[string][]byte // combined source cache key -> compiled bytecode
|
||||
sourceMtimes map[string]time.Time // track modification times
|
||||
|
||||
// Shared Lua state for compilation
|
||||
compileState *luajit.State
|
||||
compileStateMu sync.Mutex // Protect concurrent access to Lua state
|
||||
}
|
||||
|
||||
// NewLuaRouter creates a new LuaRouter instance
|
||||
func NewLuaRouter(routesDir string) (*LuaRouter, error) {
|
||||
info, err := os.Stat(routesDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, errors.New("routes path is not a directory")
|
||||
}
|
||||
|
||||
// Create shared Lua state
|
||||
compileState := luajit.New()
|
||||
if compileState == nil {
|
||||
return nil, errors.New("failed to create Lua compile state")
|
||||
}
|
||||
|
||||
r := &LuaRouter{
|
||||
routesDir: routesDir,
|
||||
routes: make(map[string]*node),
|
||||
failedRoutes: make(map[string]*RouteError),
|
||||
middlewareFiles: make(map[string][]string),
|
||||
routeCache: fastcache.New(defaultRouteMaxBytes),
|
||||
bytecodeCache: fastcache.New(defaultBytecodeMaxBytes),
|
||||
middlewareCache: make(map[string][]byte),
|
||||
sourceCache: make(map[string][]byte),
|
||||
sourceMtimes: make(map[string]time.Time),
|
||||
compileState: compileState,
|
||||
}
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
|
||||
for _, method := range methods {
|
||||
r.routes[method] = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
}
|
||||
|
||||
err = r.buildRoutes()
|
||||
|
||||
if len(r.failedRoutes) > 0 {
|
||||
return r, ErrRoutesCompilationErrors
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Refresh rebuilds the router by rescanning the routes directory
|
||||
func (r *LuaRouter) Refresh() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for method := range r.routes {
|
||||
r.routes[method] = &node{
|
||||
staticChild: make(map[string]*node),
|
||||
}
|
||||
}
|
||||
|
||||
r.failedRoutes = make(map[string]*RouteError)
|
||||
r.middlewareFiles = make(map[string][]string)
|
||||
r.middlewareCache = make(map[string][]byte)
|
||||
r.sourceCache = make(map[string][]byte)
|
||||
r.sourceMtimes = make(map[string]time.Time)
|
||||
|
||||
err := r.buildRoutes()
|
||||
|
||||
if len(r.failedRoutes) > 0 {
|
||||
return ErrRoutesCompilationErrors
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ReportFailedRoutes returns a list of routes that failed to compile
|
||||
func (r *LuaRouter) ReportFailedRoutes() []*RouteError {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make([]*RouteError, 0, len(r.failedRoutes))
|
||||
for _, re := range r.failedRoutes {
|
||||
result = append(result, re)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Close cleans up the router and its resources
|
||||
func (r *LuaRouter) Close() {
|
||||
r.compileStateMu.Lock()
|
||||
if r.compileState != nil {
|
||||
r.compileState.Close()
|
||||
r.compileState = nil
|
||||
}
|
||||
r.compileStateMu.Unlock()
|
||||
}
|
||||
|
||||
// GetRouteStats returns statistics about the router
|
||||
func (r *LuaRouter) GetRouteStats() (int, int64) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
routeCount := 0
|
||||
bytecodeBytes := int64(0)
|
||||
|
||||
for _, root := range r.routes {
|
||||
count, bytes := countNodesAndBytecode(root)
|
||||
routeCount += count
|
||||
bytecodeBytes += bytes
|
||||
}
|
||||
|
||||
return routeCount, bytecodeBytes
|
||||
}
|
||||
|
||||
type NodeWithError struct {
|
||||
ScriptPath string
|
||||
Error error
|
||||
}
|
||||
117
runner/context.go
Normal file
117
runner/context.go
Normal file
@ -0,0 +1,117 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Context represents execution context for a Lua script
|
||||
type Context struct {
|
||||
// Values stores any context values (route params, HTTP request info, etc.)
|
||||
Values map[string]any
|
||||
|
||||
// FastHTTP context if this was created from an HTTP request
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// Buffer for efficient string operations
|
||||
buffer *bytebufferpool.ByteBuffer
|
||||
}
|
||||
|
||||
// Context pool to reduce allocations
|
||||
var contextPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Context{
|
||||
Values: make(map[string]any, 32),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// NewContext creates a new context, potentially reusing one from the pool
|
||||
func NewContext() *Context {
|
||||
ctx := contextPool.Get().(*Context)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// NewHTTPContext creates a new context from a fasthttp RequestCtx
|
||||
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
||||
ctx := NewContext()
|
||||
ctx.RequestCtx = requestCtx
|
||||
|
||||
// Extract common HTTP values that Lua might need
|
||||
if requestCtx != nil {
|
||||
ctx.Values["_request_method"] = string(requestCtx.Method())
|
||||
ctx.Values["_request_path"] = string(requestCtx.Path())
|
||||
ctx.Values["_request_url"] = string(requestCtx.RequestURI())
|
||||
|
||||
// Extract cookies
|
||||
cookies := make(map[string]any)
|
||||
requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
cookies[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_cookies"] = cookies
|
||||
|
||||
// Extract query params
|
||||
query := make(map[string]any)
|
||||
requestCtx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
query[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_query"] = query
|
||||
|
||||
// Extract form data if present
|
||||
if requestCtx.IsPost() || requestCtx.IsPut() {
|
||||
form := make(map[string]any)
|
||||
requestCtx.PostArgs().VisitAll(func(key, value []byte) {
|
||||
form[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_form"] = form
|
||||
}
|
||||
|
||||
// Extract headers
|
||||
headers := make(map[string]any)
|
||||
requestCtx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_headers"] = headers
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Release returns the context to the pool after clearing its values
|
||||
func (c *Context) Release() {
|
||||
// Clear all values to prevent data leakage
|
||||
for k := range c.Values {
|
||||
delete(c.Values, k)
|
||||
}
|
||||
|
||||
// Reset request context
|
||||
c.RequestCtx = nil
|
||||
|
||||
// Return buffer to pool if we have one
|
||||
if c.buffer != nil {
|
||||
bytebufferpool.Put(c.buffer)
|
||||
c.buffer = nil
|
||||
}
|
||||
|
||||
contextPool.Put(c)
|
||||
}
|
||||
|
||||
// GetBuffer returns a byte buffer for efficient string operations
|
||||
func (c *Context) GetBuffer() *bytebufferpool.ByteBuffer {
|
||||
if c.buffer == nil {
|
||||
c.buffer = bytebufferpool.Get()
|
||||
}
|
||||
return c.buffer
|
||||
}
|
||||
|
||||
// Set adds a value to the context
|
||||
func (c *Context) Set(key string, value any) {
|
||||
c.Values[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a value from the context
|
||||
func (c *Context) Get(key string) any {
|
||||
return c.Values[key]
|
||||
}
|
||||
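Contexts come from a pool, so the intended lifecycle is acquire, use, Release, once per request. A hedged sketch of that flow inside a fasthttp handler, assuming the package is imported as Moonshark/runner; the handler wiring itself is illustrative:

```go
func handle(reqCtx *fasthttp.RequestCtx) {
	ctx := runner.NewHTTPContext(reqCtx)
	defer ctx.Release() // clears values and returns the context to the pool

	// Request data extracted by NewHTTPContext is available to the script layer.
	method, _ := ctx.Get("_request_method").(string)
	ctx.Set("route_name", "users.show") // arbitrary per-request value

	_ = method
}
```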
404
runner/crypto.go
Normal file
404
runner/crypto.go
Normal file
@ -0,0 +1,404 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
"math"
|
||||
mrand "math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var (
|
||||
// Map to store state-specific RNGs
|
||||
stateRngs = make(map[*luajit.State]*mrand.PCG)
|
||||
stateRngsMu sync.Mutex
|
||||
)
|
||||
|
||||
// RegisterCryptoFunctions registers all crypto functions with the Lua state
|
||||
func RegisterCryptoFunctions(state *luajit.State) error {
|
||||
// Create a state-specific RNG
|
||||
stateRngsMu.Lock()
|
||||
stateRngs[state] = mrand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()>>32))
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
// Register hash functions
|
||||
if err := state.RegisterGoFunction("__crypto_hash", cryptoHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register HMAC functions
|
||||
if err := state.RegisterGoFunction("__crypto_hmac", cryptoHmac); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register UUID generation
|
||||
if err := state.RegisterGoFunction("__crypto_uuid", cryptoUuid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register random functions
|
||||
if err := state.RegisterGoFunction("__crypto_random", cryptoRandom); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__crypto_random_bytes", cryptoRandomBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__crypto_random_int", cryptoRandomInt); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__crypto_random_seed", cryptoRandomSeed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Override Lua's math.random
|
||||
if err := OverrideLuaRandom(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupCrypto cleans up resources when a state is closed
|
||||
func CleanupCrypto(state *luajit.State) {
|
||||
stateRngsMu.Lock()
|
||||
delete(stateRngs, state)
|
||||
stateRngsMu.Unlock()
|
||||
}
|
||||
|
||||
// cryptoHash generates hash digests using various algorithms
|
||||
func cryptoHash(state *luajit.State) int {
|
||||
if !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("hash: expected (string data, string algorithm)")
|
||||
return 1
|
||||
}
|
||||
|
||||
data := state.ToString(1)
|
||||
algorithm := state.ToString(2)
|
||||
|
||||
var h hash.Hash
|
||||
|
||||
switch algorithm {
|
||||
case "md5":
|
||||
h = md5.New()
|
||||
case "sha1":
|
||||
h = sha1.New()
|
||||
case "sha256":
|
||||
h = sha256.New()
|
||||
case "sha512":
|
||||
h = sha512.New()
|
||||
default:
|
||||
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
|
||||
return 1
|
||||
}
|
||||
|
||||
h.Write([]byte(data))
|
||||
hashBytes := h.Sum(nil)
|
||||
|
||||
// Output format
|
||||
outputFormat := "hex"
|
||||
if state.GetTop() >= 3 && state.IsString(3) {
|
||||
outputFormat = state.ToString(3)
|
||||
}
|
||||
|
||||
switch outputFormat {
|
||||
case "hex":
|
||||
state.PushString(hex.EncodeToString(hashBytes))
|
||||
case "binary":
|
||||
state.PushString(string(hashBytes))
|
||||
default:
|
||||
state.PushString(hex.EncodeToString(hashBytes))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoHmac generates HMAC using various hash algorithms
|
||||
func cryptoHmac(state *luajit.State) int {
|
||||
if !state.IsString(1) || !state.IsString(2) || !state.IsString(3) {
|
||||
state.PushString("hmac: expected (string data, string key, string algorithm)")
|
||||
return 1
|
||||
}
|
||||
|
||||
data := state.ToString(1)
|
||||
key := state.ToString(2)
|
||||
algorithm := state.ToString(3)
|
||||
|
||||
var h func() hash.Hash
|
||||
|
||||
switch algorithm {
|
||||
case "md5":
|
||||
h = md5.New
|
||||
case "sha1":
|
||||
h = sha1.New
|
||||
case "sha256":
|
||||
h = sha256.New
|
||||
case "sha512":
|
||||
h = sha512.New
|
||||
default:
|
||||
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
|
||||
return 1
|
||||
}
|
||||
|
||||
mac := hmac.New(h, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
macBytes := mac.Sum(nil)
|
||||
|
||||
// Output format
|
||||
outputFormat := "hex"
|
||||
if state.GetTop() >= 4 && state.IsString(4) {
|
||||
outputFormat = state.ToString(4)
|
||||
}
|
||||
|
||||
switch outputFormat {
|
||||
case "hex":
|
||||
state.PushString(hex.EncodeToString(macBytes))
|
||||
case "binary":
|
||||
state.PushString(string(macBytes))
|
||||
default:
|
||||
state.PushString(hex.EncodeToString(macBytes))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoUuid generates a random UUID v4
|
||||
func cryptoUuid(state *luajit.State) int {
|
||||
uuid := make([]byte, 16)
|
||||
_, err := rand.Read(uuid)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("uuid: generation error: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Set version (4) and variant (RFC 4122)
|
||||
uuid[6] = (uuid[6] & 0x0F) | 0x40
|
||||
uuid[8] = (uuid[8] & 0x3F) | 0x80
|
||||
|
||||
uuidStr := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
|
||||
|
||||
state.PushString(uuidStr)
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandomBytes generates random bytes
|
||||
func cryptoRandomBytes(state *luajit.State) int {
|
||||
if !state.IsNumber(1) {
|
||||
state.PushString("random_bytes: expected (number length)")
|
||||
return 1
|
||||
}
|
||||
|
||||
length := int(state.ToNumber(1))
|
||||
if length <= 0 {
|
||||
state.PushString("random_bytes: length must be positive")
|
||||
return 1
|
||||
}
|
||||
|
||||
// Check if secure
|
||||
secure := true
|
||||
if state.GetTop() >= 2 && state.IsBoolean(2) {
|
||||
secure = state.ToBoolean(2)
|
||||
}
|
||||
|
||||
bytes := make([]byte, length)
|
||||
|
||||
if secure {
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("random_bytes: error: %v", err))
|
||||
return 1
|
||||
}
|
||||
} else {
|
||||
stateRngsMu.Lock()
|
||||
stateRng, ok := stateRngs[state]
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
state.PushString("random_bytes: RNG not initialized")
|
||||
return 1
|
||||
}
|
||||
|
||||
for i := range bytes {
|
||||
bytes[i] = byte(stateRng.Uint64() & 0xFF)
|
||||
}
|
||||
}
|
||||
|
||||
// Output format
|
||||
outputFormat := "binary"
|
||||
if state.GetTop() >= 3 && state.IsString(3) {
|
||||
outputFormat = state.ToString(3)
|
||||
}
|
||||
|
||||
switch outputFormat {
|
||||
case "binary":
|
||||
state.PushString(string(bytes))
|
||||
case "hex":
|
||||
state.PushString(hex.EncodeToString(bytes))
|
||||
default:
|
||||
state.PushString(string(bytes))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandomInt generates a random integer in range [min, max]
|
||||
func cryptoRandomInt(state *luajit.State) int {
|
||||
if !state.IsNumber(1) || !state.IsNumber(2) {
|
||||
state.PushString("random_int: expected (number min, number max)")
|
||||
return 1
|
||||
}
|
||||
|
||||
min := int64(state.ToNumber(1))
|
||||
max := int64(state.ToNumber(2))
|
||||
|
||||
if max <= min {
|
||||
state.PushString("random_int: max must be greater than min")
|
||||
return 1
|
||||
}
|
||||
|
||||
// Check if secure
|
||||
secure := true
|
||||
if state.GetTop() >= 3 && state.IsBoolean(3) {
|
||||
secure = state.ToBoolean(3)
|
||||
}
|
||||
|
||||
range_size := max - min + 1
|
||||
|
||||
var result int64
|
||||
|
||||
if secure {
|
||||
bytes := make([]byte, 8)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("random_int: error: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
val := binary.BigEndian.Uint64(bytes)
|
||||
result = min + int64(val%uint64(range_size))
|
||||
} else {
|
||||
stateRngsMu.Lock()
|
||||
stateRng, ok := stateRngs[state]
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
state.PushString("random_int: RNG not initialized")
|
||||
return 1
|
||||
}
|
||||
|
||||
result = min + int64(stateRng.Uint64()%uint64(range_size))
|
||||
}
|
||||
|
||||
state.PushNumber(float64(result))
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandom implements math.random functionality
|
||||
func cryptoRandom(state *luajit.State) int {
|
||||
numArgs := state.GetTop()
|
||||
|
||||
// Check if secure
|
||||
secure := false
|
||||
|
||||
// math.random() - return [0,1)
|
||||
if numArgs == 0 {
|
||||
if secure {
|
||||
bytes := make([]byte, 8)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("random: error: %v", err))
|
||||
return 1
|
||||
}
|
||||
val := binary.BigEndian.Uint64(bytes)
|
||||
state.PushNumber(float64(val) / float64(math.MaxUint64))
|
||||
} else {
|
||||
stateRngsMu.Lock()
|
||||
stateRng, ok := stateRngs[state]
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
state.PushString("random: RNG not initialized")
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// math.random(n) - return integer [1,n]
|
||||
if numArgs == 1 && state.IsNumber(1) {
|
||||
n := int64(state.ToNumber(1))
|
||||
if n < 1 {
|
||||
state.PushString("random: upper bound must be >= 1")
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushNumber(1) // min
|
||||
state.PushNumber(float64(n)) // max
|
||||
state.PushBoolean(secure) // secure flag
|
||||
return cryptoRandomInt(state)
|
||||
}
|
||||
|
||||
// math.random(m, n) - return integer [m,n]
|
||||
if numArgs >= 2 && state.IsNumber(1) && state.IsNumber(2) {
|
||||
state.PushBoolean(secure) // secure flag
|
||||
return cryptoRandomInt(state)
|
||||
}
|
||||
|
||||
state.PushString("random: invalid arguments")
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandomSeed sets seed for non-secure RNG
|
||||
func cryptoRandomSeed(state *luajit.State) int {
|
||||
if !state.IsNumber(1) {
|
||||
state.PushString("randomseed: expected (number seed)")
|
||||
return 1
|
||||
}
|
||||
|
||||
seed := uint64(state.ToNumber(1))
|
||||
|
||||
stateRngsMu.Lock()
|
||||
stateRngs[state] = mrand.NewPCG(seed, seed>>32)
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// OverrideLuaRandom replaces Lua's math.random with Go implementation
|
||||
func OverrideLuaRandom(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("go_math_random", cryptoRandom); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := state.RegisterGoFunction("go_math_randomseed", cryptoRandomSeed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Replace original functions
|
||||
return state.DoString(`
|
||||
-- Save original functions
|
||||
_G._original_math_random = math.random
|
||||
_G._original_math_randomseed = math.randomseed
|
||||
|
||||
-- Replace with Go implementations
|
||||
math.random = go_math_random
|
||||
math.randomseed = go_math_randomseed
|
||||
|
||||
-- Clean up global namespace
|
||||
go_math_random = nil
|
||||
go_math_randomseed = nil
|
||||
`)
|
||||
}
|
||||
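Because math.random is backed by a per-state PCG held in the stateRngs map, RegisterCryptoFunctions and CleanupCrypto need to bracket the state's lifetime, otherwise entries accumulate. A minimal sketch, assuming the package is imported as Moonshark/runner and log is imported:

```go
func cryptoSketch() {
	state := luajit.New()
	defer state.Close()

	if err := runner.RegisterCryptoFunctions(state); err != nil {
		log.Fatalf("crypto setup: %v", err)
	}
	defer runner.CleanupCrypto(state) // drop this state's RNG entry

	// math.random now routes through the Go-backed RNGs.
	_ = state.DoString(`print(math.random(1, 6))`)
}
```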
166
runner/embed.go
Normal file
166
runner/embed.go
Normal file
@ -0,0 +1,166 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
//go:embed lua/sandbox.lua
|
||||
var sandboxLuaCode string
|
||||
|
||||
//go:embed lua/json.lua
|
||||
var jsonLuaCode string
|
||||
|
||||
//go:embed lua/sqlite.lua
|
||||
var sqliteLuaCode string
|
||||
|
||||
//go:embed lua/fs.lua
|
||||
var fsLuaCode string
|
||||
|
||||
//go:embed lua/util.lua
|
||||
var utilLuaCode string
|
||||
|
||||
//go:embed lua/string.lua
|
||||
var stringLuaCode string
|
||||
|
||||
//go:embed lua/table.lua
|
||||
var tableLuaCode string
|
||||
|
||||
//go:embed lua/crypto.lua
|
||||
var cryptoLuaCode string
|
||||
|
||||
//go:embed lua/time.lua
|
||||
var timeLuaCode string
|
||||
|
||||
//go:embed lua/math.lua
|
||||
var mathLuaCode string
|
||||
|
||||
//go:embed lua/env.lua
|
||||
var envLuaCode string
|
||||
|
||||
// ModuleInfo holds information about an embeddable Lua module
|
||||
type ModuleInfo struct {
|
||||
Name string // Module name
|
||||
Code string // Module source code
|
||||
Bytecode atomic.Pointer[[]byte] // Cached bytecode
|
||||
Once sync.Once // For one-time compilation
|
||||
DefinesGlobal bool // Whether module defines globals directly
|
||||
}
|
||||
|
||||
var (
|
||||
sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode}
|
||||
modules = []ModuleInfo{
|
||||
{Name: "json", Code: jsonLuaCode, DefinesGlobal: true},
|
||||
{Name: "sqlite", Code: sqliteLuaCode},
|
||||
{Name: "fs", Code: fsLuaCode, DefinesGlobal: true},
|
||||
{Name: "util", Code: utilLuaCode, DefinesGlobal: true},
|
||||
{Name: "string", Code: stringLuaCode},
|
||||
{Name: "table", Code: tableLuaCode},
|
||||
{Name: "crypto", Code: cryptoLuaCode, DefinesGlobal: true},
|
||||
{Name: "time", Code: timeLuaCode},
|
||||
{Name: "math", Code: mathLuaCode},
|
||||
{Name: "env", Code: envLuaCode, DefinesGlobal: true},
|
||||
}
|
||||
)
|
||||
|
||||
// precompileModule compiles a module's code to bytecode once
|
||||
func precompileModule(m *ModuleInfo) {
|
||||
m.Once.Do(func() {
|
||||
tempState := luajit.New()
|
||||
if tempState == nil {
|
||||
logger.Fatal("Failed to create temp Lua state for %s module compilation", m.Name)
|
||||
return
|
||||
}
|
||||
defer tempState.Close()
|
||||
defer tempState.Cleanup()
|
||||
|
||||
code, err := tempState.CompileBytecode(m.Code, m.Name+".lua")
|
||||
if err != nil {
|
||||
logger.Error("Failed to compile %s module: %v", m.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
bytecode := make([]byte, len(code))
|
||||
copy(bytecode, code)
|
||||
m.Bytecode.Store(&bytecode)
|
||||
|
||||
logger.Debug("Successfully precompiled %s.lua to bytecode (%d bytes)", m.Name, len(code))
|
||||
})
|
||||
}
|
||||
|
||||
// loadModule loads a module into a Lua state
|
||||
func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error {
|
||||
// Ensure bytecode is compiled
|
||||
precompileModule(m)
|
||||
|
||||
// Attempt to load from bytecode
|
||||
bytecode := m.Bytecode.Load()
|
||||
if bytecode != nil && len(*bytecode) > 0 {
|
||||
if verbose {
|
||||
logger.Debug("Loading %s.lua from precompiled bytecode", m.Name)
|
||||
}
|
||||
|
||||
if err := state.LoadBytecode(*bytecode, m.Name+".lua"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.DefinesGlobal {
|
||||
// Module defines its own globals, just run it
|
||||
if err := state.RunBytecode(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Module returns a table, capture and set as global
|
||||
if err := state.RunBytecodeWithResults(1); err != nil {
|
||||
return err
|
||||
}
|
||||
state.SetGlobal(m.Name)
|
||||
}
|
||||
} else {
|
||||
// Fallback to interpreting the source
|
||||
if verbose {
|
||||
logger.Warning("Using non-precompiled %s.lua", m.Name)
|
||||
}
|
||||
|
||||
if err := state.DoString(m.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadSandboxIntoState loads all modules and sandbox into a Lua state
|
||||
func loadSandboxIntoState(state *luajit.State, verbose bool) error {
|
||||
// Load all modules first
|
||||
for i := range modules {
|
||||
if err := loadModule(state, &modules[i], verbose); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize active connections tracking (specific to SQLite)
|
||||
if err := state.DoString(`__active_sqlite_connections = {}`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load the sandbox last
|
||||
precompileModule(&sandbox)
|
||||
bytecode := sandbox.Bytecode.Load()
|
||||
if bytecode != nil && len(*bytecode) > 0 {
|
||||
if verbose {
|
||||
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
||||
}
|
||||
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
|
||||
}
|
||||
|
||||
if verbose {
|
||||
logger.Warning("Using non-precompiled sandbox.lua")
|
||||
}
|
||||
return state.DoString(sandboxLuaCode)
|
||||
}
|
||||
278
runner/env.go
Normal file
278
runner/env.go
Normal file
@ -0,0 +1,278 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"Moonshark/utils/color"
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// EnvManager handles loading, storing, and saving environment variables
|
||||
type EnvManager struct {
|
||||
envPath string // Path to .env file
|
||||
vars map[string]any // Environment variables in memory
|
||||
mu sync.RWMutex // Thread-safe access
|
||||
}
|
||||
|
||||
// Global environment manager instance
|
||||
var globalEnvManager *EnvManager
|
||||
|
||||
// InitEnv initializes the environment manager with the given data directory
|
||||
func InitEnv(dataDir string) error {
|
||||
if dataDir == "" {
|
||||
return fmt.Errorf("data directory cannot be empty")
|
||||
}
|
||||
|
||||
// Create data directory if it doesn't exist
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create data directory: %w", err)
|
||||
}
|
||||
|
||||
envPath := filepath.Join(dataDir, ".env")
|
||||
|
||||
globalEnvManager = &EnvManager{
|
||||
envPath: envPath,
|
||||
vars: make(map[string]any),
|
||||
}
|
||||
|
||||
// Load existing .env file if it exists
|
||||
if err := globalEnvManager.load(); err != nil {
|
||||
logger.Warning("Failed to load .env file: %v", err)
|
||||
}
|
||||
|
||||
count := len(globalEnvManager.vars)
|
||||
if count > 0 {
|
||||
logger.Info("Environment loaded: %s vars from %s",
|
||||
color.Apply(fmt.Sprintf("%d", count), color.Yellow),
|
||||
color.Apply(envPath, color.Yellow))
|
||||
} else {
|
||||
logger.Info("Environment initialized: %s", color.Apply(envPath, color.Yellow))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGlobalEnvManager returns the global environment manager instance
|
||||
func GetGlobalEnvManager() *EnvManager {
|
||||
return globalEnvManager
|
||||
}
|
||||
|
||||
// load reads the .env file and populates the vars map
|
||||
func (e *EnvManager) load() error {
|
||||
file, err := os.Open(e.envPath)
|
||||
if os.IsNotExist(err) {
|
||||
// File doesn't exist, start with empty env
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open .env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse key=value
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
logger.Warning("Invalid .env line %d: %s", lineNum, line)
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
// Remove quotes if present
|
||||
if len(value) >= 2 {
|
||||
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
|
||||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
e.vars[key] = value
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// Save writes the current environment variables to the .env file
|
||||
func (e *EnvManager) Save() error {
|
||||
if e == nil {
|
||||
return nil // No env manager initialized
|
||||
}
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
file, err := os.Create(e.envPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create .env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Sort keys for consistent output
|
||||
keys := make([]string, 0, len(e.vars))
|
||||
for key := range e.vars {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// Write header comment
|
||||
fmt.Fprintln(file, "# Environment variables for Moonshark")
|
||||
fmt.Fprintln(file, "# Generated automatically - you can edit this file")
|
||||
fmt.Fprintln(file)
|
||||
|
||||
// Write each variable
|
||||
for _, key := range keys {
|
||||
value := e.vars[key]
|
||||
|
||||
// Convert value to string
|
||||
var strValue string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
strValue = v
|
||||
case nil:
|
||||
continue // Skip nil values
|
||||
default:
|
||||
strValue = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// Quote values that contain spaces or special characters
|
||||
if strings.ContainsAny(strValue, " \t\n\r\"'\\") {
|
||||
strValue = fmt.Sprintf("\"%s\"", strings.ReplaceAll(strValue, "\"", "\\\""))
|
||||
}
|
||||
|
||||
fmt.Fprintf(file, "%s=%s\n", key, strValue)
|
||||
}
|
||||
|
||||
logger.Debug("Environment saved: %d vars to %s", len(e.vars), e.envPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an environment variable
|
||||
func (e *EnvManager) Get(key string) (any, bool) {
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
value, exists := e.vars[key]
|
||||
return value, exists
|
||||
}
|
||||
|
||||
// Set stores an environment variable
|
||||
func (e *EnvManager) Set(key string, value any) {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.vars[key] = value
|
||||
}
|
||||
|
||||
// GetAll returns a copy of all environment variables
|
||||
func (e *EnvManager) GetAll() map[string]any {
|
||||
if e == nil {
|
||||
return make(map[string]any)
|
||||
}
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
result := make(map[string]any, len(e.vars))
|
||||
for k, v := range e.vars {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// CleanupEnv saves the environment and cleans up resources
|
||||
func CleanupEnv() error {
|
||||
if globalEnvManager != nil {
|
||||
return globalEnvManager.Save()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// envGet Lua function to get an environment variable
|
||||
func envGet(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
key := state.ToString(1)
|
||||
if value, exists := globalEnvManager.Get(key); exists {
|
||||
if err := state.PushValue(value); err != nil {
|
||||
state.PushNil()
|
||||
}
|
||||
} else {
|
||||
state.PushNil()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// envSet Lua function to set an environment variable
|
||||
func envSet(state *luajit.State) int {
|
||||
if !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
key := state.ToString(1)
|
||||
value := state.ToString(2)
|
||||
|
||||
globalEnvManager.Set(key, value)
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// envGetAll Lua function to get all environment variables
|
||||
func envGetAll(state *luajit.State) int {
|
||||
vars := globalEnvManager.GetAll()
|
||||
|
||||
if err := state.PushTable(vars); err != nil {
|
||||
state.PushNil()
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// RegisterEnvFunctions registers environment functions with the Lua state
|
||||
func RegisterEnvFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__env_get", envGet); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__env_set", envSet); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__env_get_all", envGetAll); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
579
runner/fs.go
Normal file
@ -0,0 +1,579 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"Moonshark/utils/color"
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
lru "git.sharkk.net/Go/LRU"
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/golang/snappy"
|
||||
)
|
||||
|
||||
// Global filesystem path (set during initialization)
|
||||
var fsBasePath string
|
||||
|
||||
// Global file cache with compressed data
|
||||
var fileCache *lru.LRUCache
|
||||
|
||||
// Cache entry info for statistics/debugging
|
||||
type cacheStats struct {
|
||||
hits int64
|
||||
misses int64
|
||||
}
|
||||
|
||||
var stats cacheStats
|
||||
|
||||
// InitFS initializes the filesystem with the given base path
|
||||
func InitFS(basePath string) error {
|
||||
if basePath == "" {
|
||||
return errors.New("filesystem base path cannot be empty")
|
||||
}
|
||||
|
||||
// Create the directory if it doesn't exist
|
||||
if err := os.MkdirAll(basePath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create filesystem directory: %w", err)
|
||||
}
|
||||
|
||||
// Store the absolute path
|
||||
absPath, err := filepath.Abs(basePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
fsBasePath = absPath
|
||||
|
||||
// Initialize file cache with 2000 entries (reasonable for most use cases)
|
||||
fileCache = lru.NewLRUCache(2000)
|
||||
|
||||
logger.Info("Filesystem is g2g! %s", color.Apply(fsBasePath, color.Yellow))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupFS performs any necessary cleanup
|
||||
func CleanupFS() {
|
||||
if fileCache != nil {
|
||||
fileCache.Clear()
|
||||
logger.Info(
|
||||
"File cache cleared - %s hits, %s misses",
|
||||
color.Apply(fmt.Sprintf("%d", stats.hits), color.Yellow),
|
||||
color.Apply(fmt.Sprintf("%d", stats.misses), color.Red),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ResolvePath resolves a given path relative to the filesystem base
|
||||
// Returns the actual path and an error if the path tries to escape the sandbox
|
||||
func ResolvePath(path string) (string, error) {
|
||||
if fsBasePath == "" {
|
||||
return "", errors.New("filesystem not initialized")
|
||||
}
|
||||
|
||||
// Clean the path to remove any .. or . components
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Replace backslashes with forward slashes for consistent handling
|
||||
cleanPath = strings.ReplaceAll(cleanPath, "\\", "/")
|
||||
|
||||
// Remove any leading / or drive letter to make it relative
|
||||
cleanPath = strings.TrimPrefix(cleanPath, "/")
|
||||
|
||||
// Remove drive letter on Windows (e.g. C:)
|
||||
if len(cleanPath) >= 2 && cleanPath[1] == ':' {
|
||||
cleanPath = cleanPath[2:]
|
||||
}
|
||||
|
||||
// Ensure the path doesn't contain .. to prevent escaping
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return "", errors.New("path cannot contain .. components")
|
||||
}
|
||||
|
||||
// Join with the base path
|
||||
fullPath := filepath.Join(fsBasePath, cleanPath)
|
||||
|
||||
// Verify the path is still within the base directory
|
||||
if !strings.HasPrefix(fullPath, fsBasePath) {
|
||||
return "", errors.New("path escapes the filesystem sandbox")
|
||||
}
|
||||
|
||||
return fullPath, nil
|
||||
}
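// Illustrative sketch (hypothetical base path "/srv/moonshark/fs"): how ResolvePath
// keeps lookups inside the sandbox.
//
//	p, _ := ResolvePath("logs/app.log") // -> /srv/moonshark/fs/logs/app.log
//	_, err := ResolvePath("../secret")  // -> error: path cannot contain .. components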
|
||||
|
||||
// getCacheKey creates a cache key from path and modification time
|
||||
func getCacheKey(fullPath string, modTime time.Time) string {
|
||||
return fmt.Sprintf("%s:%d", fullPath, modTime.Unix())
|
||||
}
|
||||
|
||||
// fsReadFile reads a file and returns its contents
|
||||
func fsReadFile(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.read_file: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.read_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get file info for cache key and validation
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.read_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Create cache key with path and modification time
|
||||
cacheKey := getCacheKey(fullPath, info.ModTime())
|
||||
|
||||
// Try to get from cache first
|
||||
if fileCache != nil {
|
||||
if cachedData, exists := fileCache.Get(cacheKey); exists {
|
||||
if compressedData, ok := cachedData.([]byte); ok {
|
||||
// Decompress cached data
|
||||
data, err := snappy.Decode(nil, compressedData)
|
||||
if err == nil {
|
||||
stats.hits++
|
||||
state.PushString(string(data))
|
||||
return 1
|
||||
}
|
||||
// Cache corruption - continue to disk read
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss or error - read from disk
|
||||
stats.misses++
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.read_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Compress and cache the data
|
||||
if fileCache != nil {
|
||||
compressedData := snappy.Encode(nil, data)
|
||||
fileCache.Put(cacheKey, compressedData)
|
||||
}
|
||||
|
||||
state.PushString(string(data))
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsWriteFile writes data to a file
|
||||
func fsWriteFile(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.write_file: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
if !state.IsString(2) {
|
||||
state.PushString("fs.write_file: content must be a string")
|
||||
return -1
|
||||
}
|
||||
content := state.ToString(2)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.write_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
state.PushString("fs.write_file: failed to create directory: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
err = os.WriteFile(fullPath, []byte(content), 0644)
|
||||
if err != nil {
|
||||
state.PushString("fs.write_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Cache invalidation: cache keys include the file's modification time, so entries
// for the old version simply stop matching after this write and age out of the LRU.
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsAppendFile appends data to a file
|
||||
func fsAppendFile(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.append_file: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
if !state.IsString(2) {
|
||||
state.PushString("fs.append_file: content must be a string")
|
||||
return -1
|
||||
}
|
||||
content := state.ToString(2)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.append_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
state.PushString("fs.append_file: failed to create directory: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
state.PushString("fs.append_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = file.Write([]byte(content))
|
||||
if err != nil {
|
||||
state.PushString("fs.append_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsExists checks if a file or directory exists
|
||||
func fsExists(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.exists: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.exists: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
_, err = os.Stat(fullPath)
|
||||
state.PushBoolean(err == nil)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsRemoveFile removes a file
|
||||
func fsRemoveFile(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.remove_file: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.remove_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Check if it's a directory
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.remove_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
state.PushString("fs.remove_file: cannot remove directory, use remove_dir instead")
|
||||
return -1
|
||||
}
|
||||
|
||||
err = os.Remove(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.remove_file: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsGetInfo gets information about a file
|
||||
func fsGetInfo(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.get_info: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.get_info: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.get_info: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
state.NewTable()
|
||||
|
||||
state.PushString(info.Name())
|
||||
state.SetField(-2, "name")
|
||||
|
||||
state.PushNumber(float64(info.Size()))
|
||||
state.SetField(-2, "size")
|
||||
|
||||
state.PushNumber(float64(info.Mode()))
|
||||
state.SetField(-2, "mode")
|
||||
|
||||
state.PushNumber(float64(info.ModTime().Unix()))
|
||||
state.SetField(-2, "mod_time")
|
||||
|
||||
state.PushBoolean(info.IsDir())
|
||||
state.SetField(-2, "is_dir")
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsMakeDir creates a directory
|
||||
func fsMakeDir(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.make_dir: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
perm := os.FileMode(0755)
|
||||
if state.GetTop() >= 2 && state.IsNumber(2) {
|
||||
perm = os.FileMode(state.ToNumber(2))
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.make_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
err = os.MkdirAll(fullPath, perm)
|
||||
if err != nil {
|
||||
state.PushString("fs.make_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsListDir lists the contents of a directory
|
||||
func fsListDir(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.list_dir: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.list_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.list_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
state.PushString("fs.list_dir: not a directory")
|
||||
return -1
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.list_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
state.NewTable()
|
||||
|
||||
for i, file := range files {
|
||||
state.PushNumber(float64(i + 1))
|
||||
state.PushString(file.Name())
|
||||
state.SetTable(-3)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsRemoveDir removes a directory
|
||||
func fsRemoveDir(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.remove_dir: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
recursive := false
|
||||
if state.GetTop() >= 2 {
|
||||
recursive = state.ToBoolean(2)
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
state.PushString("fs.remove_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
state.PushString("fs.remove_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
state.PushString("fs.remove_dir: not a directory")
|
||||
return -1
|
||||
}
|
||||
|
||||
if recursive {
|
||||
err = os.RemoveAll(fullPath)
|
||||
} else {
|
||||
err = os.Remove(fullPath)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
state.PushString("fs.remove_dir: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsJoinPaths joins path components
|
||||
func fsJoinPaths(state *luajit.State) int {
|
||||
nargs := state.GetTop()
|
||||
if nargs < 1 {
|
||||
state.PushString("fs.join_paths: at least one path component required")
|
||||
return -1
|
||||
}
|
||||
|
||||
components := make([]string, nargs)
|
||||
for i := 1; i <= nargs; i++ {
|
||||
if !state.IsString(i) {
|
||||
state.PushString("fs.join_paths: all arguments must be strings")
|
||||
return -1
|
||||
}
|
||||
components[i-1] = state.ToString(i)
|
||||
}
|
||||
|
||||
result := filepath.Join(components...)
|
||||
result = strings.ReplaceAll(result, "\\", "/")
|
||||
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsDirName returns the directory portion of a path
|
||||
func fsDirName(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.dir_name: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
dir = strings.ReplaceAll(dir, "\\", "/")
|
||||
|
||||
state.PushString(dir)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsBaseName returns the file name portion of a path
|
||||
func fsBaseName(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.base_name: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
base := filepath.Base(path)
|
||||
|
||||
state.PushString(base)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsExtension returns the file extension
|
||||
func fsExtension(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("fs.extension: path must be a string")
|
||||
return -1
|
||||
}
|
||||
path := state.ToString(1)
|
||||
|
||||
ext := filepath.Ext(path)
|
||||
|
||||
state.PushString(ext)
|
||||
return 1
|
||||
}
|
||||
|
||||
// RegisterFSFunctions registers filesystem functions with the Lua state
|
||||
func RegisterFSFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__fs_read_file", fsReadFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_write_file", fsWriteFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_append_file", fsAppendFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_exists", fsExists); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_remove_file", fsRemoveFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_get_info", fsGetInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_make_dir", fsMakeDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_list_dir", fsListDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_remove_dir", fsRemoveDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_join_paths", fsJoinPaths); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_dir_name", fsDirName); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_base_name", fsBaseName); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_extension", fsExtension); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
334
runner/http.go
Normal file
@ -0,0 +1,334 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Default HTTP client with sensible timeout
|
||||
var defaultFastClient = fasthttp.Client{
|
||||
MaxConnsPerHost: 1024,
|
||||
MaxIdleConnDuration: time.Minute,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
DisableHeaderNamesNormalizing: true,
|
||||
}
|
||||
|
||||
// HTTPClientConfig contains client settings
|
||||
type HTTPClientConfig struct {
|
||||
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
|
||||
DefaultTimeout time.Duration // Default request timeout
|
||||
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
|
||||
AllowRemote bool // Whether to allow remote connections
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig provides sensible defaults
|
||||
var DefaultHTTPClientConfig = HTTPClientConfig{
|
||||
MaxTimeout: 60 * time.Second,
|
||||
DefaultTimeout: 30 * time.Second,
|
||||
MaxResponseSize: 10 * 1024 * 1024, // 10MB
|
||||
AllowRemote: true,
|
||||
}
|
||||
|
||||
// ApplyResponse applies a Response to a fasthttp.RequestCtx
|
||||
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
|
||||
// Set status code
|
||||
ctx.SetStatusCode(resp.Status)
|
||||
|
||||
// Set headers
|
||||
for name, value := range resp.Headers {
|
||||
ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// Set cookies
|
||||
for _, cookie := range resp.Cookies {
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// Process the body based on its type
|
||||
if resp.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get a buffer from the pool
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Set body based on type
|
||||
switch body := resp.Body.(type) {
|
||||
case string:
|
||||
ctx.SetBodyString(body)
|
||||
case []byte:
|
||||
ctx.SetBody(body)
|
||||
case map[string]any, []any, []float64, []string, []int:
|
||||
// Marshal JSON
|
||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||
// Set content type if not already set
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
ctx.Response.Header.SetContentType("application/json")
|
||||
}
|
||||
ctx.SetBody(buf.Bytes())
|
||||
} else {
|
||||
// Fallback
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
default:
|
||||
// Default to string representation
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
}
|
||||
|
||||
// httpRequest makes an HTTP request and returns the result to Lua
|
||||
func httpRequest(state *luajit.State) int {
|
||||
// Get method (required)
|
||||
if !state.IsString(1) {
|
||||
state.PushString("http.client.request: method must be a string")
|
||||
return -1
|
||||
}
|
||||
method := strings.ToUpper(state.ToString(1))
|
||||
|
||||
// Get URL (required)
|
||||
if !state.IsString(2) {
|
||||
state.PushString("http.client.request: url must be a string")
|
||||
return -1
|
||||
}
|
||||
urlStr := state.ToString(2)
|
||||
|
||||
// Parse URL to check if it's valid
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
state.PushString("Invalid URL: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get client configuration
|
||||
config := DefaultHTTPClientConfig
|
||||
|
||||
// Check if remote connections are allowed
|
||||
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
||||
state.PushString("Remote connections are not allowed")
|
||||
return -1
|
||||
}
|
||||
|
||||
// Use bytebufferpool for request and response
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set up request
|
||||
req.Header.SetMethod(method)
|
||||
req.SetRequestURI(urlStr)
|
||||
req.Header.Set("User-Agent", "Moonshark/1.0")
|
||||
|
||||
// Get body (optional)
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if state.IsString(3) {
|
||||
// String body
|
||||
req.SetBodyString(state.ToString(3))
|
||||
} else if state.IsTable(3) {
|
||||
// Table body - convert to JSON
|
||||
luaTable, err := state.ToTable(3)
|
||||
if err != nil {
|
||||
state.PushString("Failed to parse body table: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
// Use bytebufferpool for JSON serialization
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
||||
state.PushString("Failed to convert body to JSON: " + err.Error())
|
||||
return -1
|
||||
}
|
||||
|
||||
req.SetBody(buf.Bytes())
|
||||
req.Header.SetContentType("application/json")
|
||||
} else {
|
||||
state.PushString("Body must be a string or table")
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
// Process options (headers, timeout, etc.)
|
||||
timeout := config.DefaultTimeout
|
||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
||||
// Process headers
|
||||
state.GetField(4, "headers")
|
||||
if state.IsTable(-1) {
|
||||
// Iterate through headers
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
// Stack now has key at -2 and value at -1
|
||||
if state.IsString(-2) && state.IsString(-1) {
|
||||
headerName := state.ToString(-2)
|
||||
headerValue := state.ToString(-1)
|
||||
req.Header.Set(headerName, headerValue)
|
||||
}
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
}
|
||||
state.Pop(1) // Pop headers table
|
||||
|
||||
// Get timeout
|
||||
state.GetField(4, "timeout")
|
||||
if state.IsNumber(-1) {
|
||||
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
|
||||
|
||||
// Apply max timeout if configured
|
||||
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
||||
timeout = config.MaxTimeout
|
||||
} else {
|
||||
timeout = requestTimeout
|
||||
}
|
||||
}
|
||||
state.Pop(1) // Pop timeout
|
||||
|
||||
// Process query parameters
|
||||
state.GetField(4, "query")
|
||||
if state.IsTable(-1) {
|
||||
// Create URL args
|
||||
args := req.URI().QueryArgs()
|
||||
|
||||
// Iterate through query params
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
if state.IsString(-2) {
|
||||
paramName := state.ToString(-2)
|
||||
|
||||
// Handle different value types
|
||||
if state.IsString(-1) {
|
||||
args.Add(paramName, state.ToString(-1))
|
||||
} else if state.IsNumber(-1) {
numStr := state.ToString(-1)
// Trim trailing zeros only for decimal values (e.g. "1.500" -> "1.5"); leave integers like "100" intact.
if strings.Contains(numStr, ".") {
numStr = strings.TrimRight(strings.TrimRight(numStr, "0"), ".")
}
args.Add(paramName, numStr)
|
||||
} else if state.IsBoolean(-1) {
|
||||
if state.ToBoolean(-1) {
|
||||
args.Add(paramName, "true")
|
||||
} else {
|
||||
args.Add(paramName, "false")
|
||||
}
|
||||
}
|
||||
}
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
}
|
||||
state.Pop(1) // Pop query table
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
_, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute request
|
||||
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
||||
if err != nil {
|
||||
errStr := "Request failed: " + err.Error()
|
||||
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||
errStr = "Request timed out after " + timeout.String()
|
||||
}
|
||||
state.PushString(errStr)
|
||||
return -1
|
||||
}
|
||||
|
||||
// Create response table
|
||||
state.NewTable()
|
||||
|
||||
// Set status code
|
||||
state.PushNumber(float64(resp.StatusCode()))
|
||||
state.SetField(-2, "status")
|
||||
|
||||
// Set status text
|
||||
statusText := fasthttp.StatusMessage(resp.StatusCode())
|
||||
state.PushString(statusText)
|
||||
state.SetField(-2, "status_text")
|
||||
|
||||
// Set body
|
||||
var respBody []byte
|
||||
|
||||
// Apply size limits to response
|
||||
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
||||
// Make a limited copy
|
||||
respBody = make([]byte, config.MaxResponseSize)
|
||||
copy(respBody, resp.Body())
|
||||
} else {
|
||||
respBody = resp.Body()
|
||||
}
|
||||
|
||||
state.PushString(string(respBody))
|
||||
state.SetField(-2, "body")
|
||||
|
||||
// Parse body as JSON if content type is application/json
|
||||
contentType := string(resp.Header.ContentType())
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
||||
if err := state.PushValue(jsonData); err == nil {
|
||||
state.SetField(-2, "json")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set headers
|
||||
state.NewTable()
|
||||
resp.Header.VisitAll(func(key, value []byte) {
|
||||
state.PushString(string(value))
|
||||
state.SetField(-2, string(key))
|
||||
})
|
||||
state.SetField(-2, "headers")
|
||||
|
||||
// Create ok field (true if status code is 2xx)
|
||||
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
||||
state.SetField(-2, "ok")
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// generateToken creates a cryptographically secure random token
|
||||
func generateToken(state *luajit.State) int {
|
||||
// Get the length from the Lua arguments (default to 32)
|
||||
length := 32
|
||||
if state.GetTop() >= 1 && state.IsNumber(1) {
|
||||
length = int(state.ToNumber(1))
|
||||
}
|
||||
|
||||
// Enforce minimum length for security
|
||||
if length < 16 {
|
||||
length = 16
|
||||
}
|
||||
|
||||
// Generate secure random bytes
|
||||
tokenBytes := make([]byte, length)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
logger.Error("Failed to generate secure token: %v", err)
|
||||
state.PushString("")
|
||||
return 1 // Return empty string on error
|
||||
}
|
||||
|
||||
// Encode as base64
|
||||
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Trim to requested length (base64 might be longer)
|
||||
if len(token) > length {
|
||||
token = token[:length]
|
||||
}
|
||||
|
||||
// Push the token to the Lua stack
|
||||
state.PushString(token)
|
||||
return 1 // One return value
|
||||
}
|
||||
142
runner/lua/crypto.lua
Normal file
@ -0,0 +1,142 @@
|
||||
--[[
|
||||
crypto.lua - Cryptographic functions powered by Go
|
||||
]]--
|
||||
|
||||
-- ======================================================================
|
||||
-- HASHING FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate hash digest using various algorithms
|
||||
-- Algorithms: md5, sha1, sha256, sha512
|
||||
-- Formats: hex (default), binary
|
||||
function hash(data, algorithm, format)
|
||||
if type(data) ~= "string" then
|
||||
error("hash: data must be a string", 2)
|
||||
end
|
||||
|
||||
algorithm = algorithm or "sha256"
|
||||
format = format or "hex"
|
||||
|
||||
return __crypto_hash(data, algorithm, format)
|
||||
end
|
||||
|
||||
function md5(data, format)
|
||||
return hash(data, "md5", format)
|
||||
end
|
||||
|
||||
function sha1(data, format)
|
||||
return hash(data, "sha1", format)
|
||||
end
|
||||
|
||||
function sha256(data, format)
|
||||
return hash(data, "sha256", format)
|
||||
end
|
||||
|
||||
function sha512(data, format)
|
||||
return hash(data, "sha512", format)
|
||||
end
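-- Illustrative usage sketch (not part of the original module; assumes the __crypto_hash
-- binding registered from Go):
--   local hex = sha256("hello world")      -- 64-char hex digest
--   local raw = md5("payload", "binary")   -- raw 16-byte string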
|
||||
|
||||
-- ======================================================================
|
||||
-- HMAC FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate HMAC using various algorithms
|
||||
-- Algorithms: md5, sha1, sha256, sha512
|
||||
-- Formats: hex (default), binary
|
||||
function hmac(data, key, algorithm, format)
|
||||
if type(data) ~= "string" then
|
||||
error("hmac: data must be a string", 2)
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
error("hmac: key must be a string", 2)
|
||||
end
|
||||
|
||||
algorithm = algorithm or "sha256"
|
||||
format = format or "hex"
|
||||
|
||||
return __crypto_hmac(data, key, algorithm, format)
|
||||
end
|
||||
|
||||
function hmac_md5(data, key, format)
|
||||
return hmac(data, key, "md5", format)
|
||||
end
|
||||
|
||||
function hmac_sha1(data, key, format)
|
||||
return hmac(data, key, "sha1", format)
|
||||
end
|
||||
|
||||
function hmac_sha256(data, key, format)
|
||||
return hmac(data, key, "sha256", format)
|
||||
end
|
||||
|
||||
function hmac_sha512(data, key, format)
|
||||
return hmac(data, key, "sha512", format)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- RANDOM FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate random bytes
|
||||
-- Formats: binary (default), hex
|
||||
function random_bytes(length, secure, format)
|
||||
if type(length) ~= "number" or length <= 0 then
|
||||
error("random_bytes: length must be positive", 2)
|
||||
end
|
||||
|
||||
secure = secure ~= false -- Default to secure
|
||||
format = format or "binary"
|
||||
|
||||
return __crypto_random_bytes(length, secure, format)
|
||||
end
|
||||
|
||||
-- Generate random integer in range [min, max]
|
||||
function random_int(min, max, secure)
|
||||
if type(min) ~= "number" or type(max) ~= "number" then
|
||||
error("random_int: min and max must be numbers", 2)
|
||||
end
|
||||
|
||||
if max <= min then
|
||||
error("random_int: max must be greater than min", 2)
|
||||
end
|
||||
|
||||
secure = secure ~= false -- Default to secure
|
||||
|
||||
return __crypto_random_int(min, max, secure)
|
||||
end
|
||||
|
||||
-- Generate random string of specified length
|
||||
function random_string(length, charset, secure)
|
||||
if type(length) ~= "number" or length <= 0 then
|
||||
error("random_string: length must be positive", 2)
|
||||
end
|
||||
|
||||
secure = secure ~= false -- Default to secure
|
||||
|
||||
-- Default character set: alphanumeric
|
||||
charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
if type(charset) ~= "string" or #charset == 0 then
|
||||
error("random_string: charset must be non-empty", 2)
|
||||
end
|
||||
|
||||
local result = ""
|
||||
local charset_length = #charset
|
||||
|
||||
for i = 1, length do
|
||||
local index = random_int(1, charset_length, secure)
|
||||
result = result .. charset:sub(index, index)
|
||||
end
|
||||
|
||||
return result
|
||||
end
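-- Illustrative usage sketch (hypothetical values; assumes hex output encodes each byte
-- as two characters):
--   local nonce = random_bytes(16, true, "hex")  -- 32 hex chars from a secure source
--   local code  = random_string(8)               -- e.g. "aZ3kQ9xP" from the default charset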
|
||||
|
||||
-- ======================================================================
|
||||
-- UUID FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate random UUID (v4)
|
||||
function uuid()
|
||||
return __crypto_uuid()
|
||||
end
|
||||
93
runner/lua/env.lua
Normal file
@ -0,0 +1,93 @@
|
||||
-- Environment variable module for Moonshark
|
||||
-- Provides access to persistent environment variables stored in .env file
|
||||
|
||||
-- Get an environment variable with a default value
|
||||
-- Returns the value if it exists, default_value otherwise
|
||||
function env_get(key, default_value)
|
||||
if type(key) ~= "string" then
|
||||
error("env_get: key must be a string")
|
||||
end
|
||||
|
||||
-- First check context for environment variables (no Go call needed)
|
||||
if _env and _env[key] ~= nil then
|
||||
return _env[key]
|
||||
end
|
||||
|
||||
return default_value
|
||||
end
|
||||
|
||||
-- Set an environment variable
|
||||
-- Returns true on success, false on failure
|
||||
function env_set(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("env_set: key must be a string")
|
||||
end
|
||||
|
||||
-- Update context immediately for future reads
|
||||
if not _env then
|
||||
_env = {}
|
||||
end
|
||||
_env[key] = value
|
||||
|
||||
-- Persist to Go backend
|
||||
return __env_set(key, value)
|
||||
end
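-- Illustrative usage sketch (hypothetical keys):
--   env_set("API_URL", "https://example.test")
--   local url = env_get("API_URL", "http://localhost")  -- falls back to the default if unset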
|
||||
|
||||
-- Get all environment variables as a table
|
||||
-- Returns a table with all key-value pairs
|
||||
function env_get_all()
|
||||
-- Return context table directly if available
|
||||
if _env then
|
||||
local copy = {}
|
||||
for k, v in pairs(_env) do
|
||||
copy[k] = v
|
||||
end
|
||||
return copy
|
||||
end
|
||||
|
||||
-- Fallback to Go call
|
||||
return __env_get_all()
|
||||
end
|
||||
|
||||
-- Check if an environment variable exists
|
||||
-- Returns true if the variable exists, false otherwise
|
||||
function env_exists(key)
|
||||
if type(key) ~= "string" then
|
||||
error("env_exists: key must be a string")
|
||||
end
|
||||
|
||||
-- Check context first
|
||||
if _env then
|
||||
return _env[key] ~= nil
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
-- Set multiple environment variables from a table
|
||||
-- Returns true on success, false if any setting failed
|
||||
function env_set_many(vars)
|
||||
if type(vars) ~= "table" then
|
||||
error("env_set_many: vars must be a table")
|
||||
end
|
||||
|
||||
if not _env then
|
||||
_env = {}
|
||||
end
|
||||
|
||||
local success = true
|
||||
for key, value in pairs(vars) do
|
||||
if type(key) == "string" and type(value) == "string" then
|
||||
-- Update context
|
||||
_env[key] = value
|
||||
-- Persist to Go
|
||||
if not __env_set(key, value) then
|
||||
success = false
|
||||
end
|
||||
else
|
||||
error("env_set_many: all keys and values must be strings")
|
||||
end
|
||||
end
|
||||
|
||||
return success
|
||||
end
|
||||
134
runner/lua/fs.lua
Normal file
@ -0,0 +1,134 @@
|
||||
function fs_read(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_read: path must be a string", 2)
|
||||
end
|
||||
return __fs_read_file(path)
|
||||
end
|
||||
|
||||
function fs_write(path, content)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_write: path must be a string", 2)
|
||||
end
|
||||
if type(content) ~= "string" then
|
||||
error("fs_write: content must be a string", 2)
|
||||
end
|
||||
return __fs_write_file(path, content)
|
||||
end
|
||||
|
||||
function fs_append(path, content)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_append: path must be a string", 2)
|
||||
end
|
||||
if type(content) ~= "string" then
|
||||
error("fs_append: content must be a string", 2)
|
||||
end
|
||||
return __fs_append_file(path, content)
|
||||
end
|
||||
|
||||
function fs_exists(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_exists: path must be a string", 2)
|
||||
end
|
||||
return __fs_exists(path)
|
||||
end
|
||||
|
||||
function fs_remove(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_remove: path must be a string", 2)
|
||||
end
|
||||
return __fs_remove_file(path)
|
||||
end
|
||||
|
||||
function fs_info(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_info: path must be a string", 2)
|
||||
end
|
||||
local info = __fs_get_info(path)
|
||||
|
||||
-- Convert the Unix timestamp to a readable date
|
||||
if info and info.mod_time then
|
||||
info.mod_time_str = os.date("%Y-%m-%d %H:%M:%S", info.mod_time)
|
||||
end
|
||||
|
||||
return info
|
||||
end
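-- Illustrative usage sketch (hypothetical paths; all resolved inside the sandboxed fs directory):
--   fs_write("notes/todo.txt", "ship it")
--   local text = fs_read("notes/todo.txt")
--   local info = fs_info("notes/todo.txt")  -- info.size, info.mod_time_str, ...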
|
||||
|
||||
-- Directory Operations
|
||||
function fs_mkdir(path, mode)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_mkdir: path must be a string", 2)
|
||||
end
|
||||
mode = mode or 493 -- octal 0755 (rwxr-xr-x); Lua has no octal literals, so 0755 would parse as decimal 755
|
||||
return __fs_make_dir(path, mode)
|
||||
end
|
||||
|
||||
function fs_ls(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_ls: path must be a string", 2)
|
||||
end
|
||||
return __fs_list_dir(path)
|
||||
end
|
||||
|
||||
function fs_rmdir(path, recursive)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_rmdir: path must be a string", 2)
|
||||
end
|
||||
recursive = recursive or false
|
||||
return __fs_remove_dir(path, recursive)
|
||||
end
|
||||
|
||||
-- Path Operations
|
||||
function fs_join_paths(...)
|
||||
return __fs_join_paths(...)
|
||||
end
|
||||
|
||||
function fs_dir_name(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_dir_name: path must be a string", 2)
|
||||
end
|
||||
return __fs_dir_name(path)
|
||||
end
|
||||
|
||||
function fs_base_name(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_base_name: path must be a string", 2)
|
||||
end
|
||||
return __fs_base_name(path)
|
||||
end
|
||||
|
||||
function fs_extension(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_extension: path must be a string", 2)
|
||||
end
|
||||
return __fs_extension(path)
|
||||
end
|
||||
|
||||
-- Utility Functions
|
||||
function fs_read_json(path)
|
||||
local content = fs_read(path)
|
||||
if not content then
|
||||
return nil, "Could not read file"
|
||||
end
|
||||
|
||||
local ok, result = pcall(json.decode, content)
|
||||
if not ok then
|
||||
return nil, "Invalid JSON: " .. tostring(result)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
function fs_write_json(path, data, pretty)
|
||||
if type(data) ~= "table" then
|
||||
error("fs_write_json: data must be a table", 2)
|
||||
end
|
||||
|
||||
local content
|
||||
if pretty then
|
||||
content = json.pretty_print(data)
|
||||
else
|
||||
content = json.encode(data)
|
||||
end
|
||||
|
||||
return fs_write(path, content)
|
||||
end
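-- Illustrative round-trip sketch (hypothetical file; assumes the json helpers these
-- functions call are wired up elsewhere in the runner):
--   fs_write_json("config/app.json", { port = 3117 }, true)
--   local cfg = fs_read_json("config/app.json")  -- cfg.port == 3117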
|
||||
422
runner/lua/json.lua
Normal file
@ -0,0 +1,422 @@
|
||||
-- json.lua: High-performance JSON module for Moonshark
|
||||
|
||||
-- Pre-computed escape sequences to avoid recreating table
|
||||
local escape_chars = {
|
||||
['"'] = '\\"', ['\\'] = '\\\\',
|
||||
['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t'
|
||||
}
|
||||
|
||||
function json_go_encode(value)
|
||||
return __json_marshal(value)
|
||||
end
|
||||
|
||||
function json_go_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
error("json_decode: expected string, got " .. type(str), 2)
|
||||
end
|
||||
return __json_unmarshal(str)
|
||||
end
|
||||
|
||||
function json_encode(data)
|
||||
local t = type(data)
|
||||
|
||||
if t == "nil" then return "null" end
|
||||
if t == "boolean" then return data and "true" or "false" end
|
||||
if t == "number" then return tostring(data) end
|
||||
|
||||
if t == "string" then
|
||||
return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"'
|
||||
end
|
||||
|
||||
if t == "table" then
|
||||
local isArray = true
|
||||
local count = 0
|
||||
|
||||
-- Check if it's an array in one pass
|
||||
for k, _ in pairs(data) do
|
||||
count = count + 1
|
||||
if type(k) ~= "number" or k ~= count or k < 1 then
|
||||
isArray = false
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if isArray then
|
||||
local result = {}
|
||||
for i = 1, count do
|
||||
result[i] = json_encode(data[i])
|
||||
end
|
||||
return "[" .. table.concat(result, ",") .. "]"
|
||||
else
|
||||
local result = {}
|
||||
local index = 1
|
||||
for k, v in pairs(data) do
|
||||
if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then
|
||||
result[index] = json_encode(k) .. ":" .. json_encode(v)
|
||||
index = index + 1
|
||||
end
|
||||
end
|
||||
return "{" .. table.concat(result, ",") .. "}"
|
||||
end
|
||||
end
|
||||
|
||||
return "null" -- Unsupported type
|
||||
end
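-- Illustrative sketch of the pure-Lua encoder's output:
--   json_encode({ 1, 2, 3 })            --> "[1,2,3]"
--   json_encode({ name = "moonshark" }) --> '{"name":"moonshark"}'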
|
||||
|
||||
function json_decode(data)
|
||||
local pos = 1
|
||||
local len = #data
|
||||
|
||||
-- Pre-compute byte values
|
||||
local b_space = string.byte(' ')
|
||||
local b_tab = string.byte('\t')
|
||||
local b_cr = string.byte('\r')
|
||||
local b_lf = string.byte('\n')
|
||||
local b_quote = string.byte('"')
|
||||
local b_backslash = string.byte('\\')
|
||||
local b_slash = string.byte('/')
|
||||
local b_lcurly = string.byte('{')
|
||||
local b_rcurly = string.byte('}')
|
||||
local b_lbracket = string.byte('[')
|
||||
local b_rbracket = string.byte(']')
|
||||
local b_colon = string.byte(':')
|
||||
local b_comma = string.byte(',')
|
||||
local b_0 = string.byte('0')
|
||||
local b_9 = string.byte('9')
|
||||
local b_minus = string.byte('-')
|
||||
local b_plus = string.byte('+')
|
||||
local b_dot = string.byte('.')
|
||||
local b_e = string.byte('e')
|
||||
local b_E = string.byte('E')
|
||||
|
||||
-- Skip whitespace more efficiently
|
||||
local function skip()
|
||||
local b
|
||||
while pos <= len do
|
||||
b = data:byte(pos)
|
||||
if b > b_space or (b ~= b_space and b ~= b_tab and b ~= b_cr and b ~= b_lf) then
|
||||
break
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
-- Forward declarations
|
||||
local parse_value, parse_string, parse_number, parse_object, parse_array
|
||||
|
||||
-- Parse a string more efficiently
|
||||
parse_string = function()
|
||||
pos = pos + 1 -- Skip opening quote
|
||||
|
||||
if pos > len then
|
||||
error("Unterminated string")
|
||||
end
|
||||
|
||||
-- Use a table to build the string
|
||||
local result = {}
|
||||
local result_pos = 1
|
||||
local start = pos
|
||||
local c, b
|
||||
|
||||
while pos <= len do
|
||||
b = data:byte(pos)
|
||||
|
||||
if b == b_backslash then
|
||||
-- Add the chunk before the escape character
|
||||
if pos > start then
|
||||
result[result_pos] = data:sub(start, pos - 1)
|
||||
result_pos = result_pos + 1
|
||||
end
|
||||
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Unterminated string escape")
|
||||
end
|
||||
|
||||
c = data:byte(pos)
|
||||
if c == b_quote then
|
||||
result[result_pos] = '"'
|
||||
elseif c == b_backslash then
|
||||
result[result_pos] = '\\'
|
||||
elseif c == b_slash then
|
||||
result[result_pos] = '/'
|
||||
elseif c == string.byte('b') then
|
||||
result[result_pos] = '\b'
|
||||
elseif c == string.byte('f') then
|
||||
result[result_pos] = '\f'
|
||||
elseif c == string.byte('n') then
|
||||
result[result_pos] = '\n'
|
||||
elseif c == string.byte('r') then
|
||||
result[result_pos] = '\r'
|
||||
elseif c == string.byte('t') then
|
||||
result[result_pos] = '\t'
|
||||
else
|
||||
result[result_pos] = data:sub(pos, pos)
|
||||
end
|
||||
|
||||
result_pos = result_pos + 1
|
||||
pos = pos + 1
|
||||
start = pos
|
||||
elseif b == b_quote then
|
||||
-- Add the final chunk
|
||||
if pos > start then
|
||||
result[result_pos] = data:sub(start, pos - 1)
|
||||
result_pos = result_pos + 1
|
||||
end
|
||||
|
||||
pos = pos + 1
|
||||
return table.concat(result)
|
||||
else
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
error("Unterminated string")
|
||||
end
|
||||
|
||||
-- Parse a number more efficiently
|
||||
parse_number = function()
|
||||
local start = pos
|
||||
local b = data:byte(pos)
|
||||
|
||||
-- Skip any sign
|
||||
if b == b_minus then
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Malformed number")
|
||||
end
|
||||
b = data:byte(pos)
|
||||
end
|
||||
|
||||
-- Integer part
|
||||
if b < b_0 or b > b_9 then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
repeat
|
||||
pos = pos + 1
|
||||
if pos > len then break end
|
||||
b = data:byte(pos)
|
||||
until b < b_0 or b > b_9
|
||||
|
||||
-- Fractional part
|
||||
if pos <= len and b == b_dot then
|
||||
pos = pos + 1
|
||||
if pos > len or data:byte(pos) < b_0 or data:byte(pos) > b_9 then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
repeat
|
||||
pos = pos + 1
|
||||
if pos > len then break end
|
||||
b = data:byte(pos)
|
||||
until b < b_0 or b > b_9
|
||||
end
|
||||
|
||||
-- Exponent
|
||||
if pos <= len and (b == b_e or b == b_E) then
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
b = data:byte(pos)
|
||||
if b == b_plus or b == b_minus then
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Malformed number")
|
||||
end
|
||||
b = data:byte(pos)
|
||||
end
|
||||
|
||||
if b < b_0 or b > b_9 then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
repeat
|
||||
pos = pos + 1
|
||||
if pos > len then break end
|
||||
b = data:byte(pos)
|
||||
until b < b_0 or b > b_9
|
||||
end
|
||||
|
||||
return tonumber(data:sub(start, pos - 1))
|
||||
end
|
||||
|
||||
-- Parse an object more efficiently
|
||||
parse_object = function()
|
||||
pos = pos + 1 -- Skip opening brace
|
||||
local obj = {}
|
||||
|
||||
skip()
|
||||
if pos <= len and data:byte(pos) == b_rcurly then
|
||||
pos = pos + 1
|
||||
return obj
|
||||
end
|
||||
|
||||
while pos <= len do
|
||||
skip()
|
||||
|
||||
if data:byte(pos) ~= b_quote then
|
||||
error("Expected string key")
|
||||
end
|
||||
|
||||
local key = parse_string()
|
||||
skip()
|
||||
|
||||
if data:byte(pos) ~= b_colon then
|
||||
error("Expected colon")
|
||||
end
|
||||
pos = pos + 1
|
||||
|
||||
obj[key] = parse_value()
|
||||
skip()
|
||||
|
||||
local b = data:byte(pos)
|
||||
if b == b_rcurly then
|
||||
pos = pos + 1
|
||||
return obj
|
||||
end
|
||||
|
||||
if b ~= b_comma then
|
||||
error("Expected comma or closing brace")
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
error("Unterminated object")
|
||||
end
|
||||
|
||||
-- Parse an array more efficiently
|
||||
parse_array = function()
|
||||
pos = pos + 1 -- Skip opening bracket
|
||||
local arr = {}
|
||||
local index = 1
|
||||
|
||||
skip()
|
||||
if pos <= len and data:byte(pos) == b_rbracket then
|
||||
pos = pos + 1
|
||||
return arr
|
||||
end
|
||||
|
||||
while pos <= len do
|
||||
arr[index] = parse_value()
|
||||
index = index + 1
|
||||
|
||||
skip()
|
||||
|
||||
local b = data:byte(pos)
|
||||
if b == b_rbracket then
|
||||
pos = pos + 1
|
||||
return arr
|
||||
end
|
||||
|
||||
if b ~= b_comma then
|
||||
error("Expected comma or closing bracket")
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
error("Unterminated array")
|
||||
end
|
||||
|
||||
-- Parse a value more efficiently
|
||||
parse_value = function()
|
||||
skip()
|
||||
|
||||
if pos > len then
|
||||
error("Unexpected end of input")
|
||||
end
|
||||
|
||||
local b = data:byte(pos)
|
||||
|
||||
if b == b_quote then
|
||||
return parse_string()
|
||||
elseif b == b_lcurly then
|
||||
return parse_object()
|
||||
elseif b == b_lbracket then
|
||||
return parse_array()
|
||||
elseif b == string.byte('n') and pos + 3 <= len and data:sub(pos, pos + 3) == "null" then
|
||||
pos = pos + 4
|
||||
return nil
|
||||
elseif b == string.byte('t') and pos + 3 <= len and data:sub(pos, pos + 3) == "true" then
|
||||
pos = pos + 4
|
||||
return true
|
||||
elseif b == string.byte('f') and pos + 4 <= len and data:sub(pos, pos + 4) == "false" then
|
||||
pos = pos + 5
|
||||
return false
|
||||
elseif b == b_minus or (b >= b_0 and b <= b_9) then
|
||||
return parse_number()
|
||||
else
|
||||
error("Unexpected character: " .. string.char(b))
|
||||
end
|
||||
end
|
||||
|
||||
skip()
|
||||
local result = parse_value()
|
||||
skip()
|
||||
|
||||
if pos <= len then
|
||||
error("Unexpected trailing characters")
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
function json_is_valid(str)
|
||||
if type(str) ~= "string" then return false end
|
||||
local status, _ = pcall(json_decode, str)
|
||||
return status
|
||||
end
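-- Illustrative sketch:
--   json_is_valid('{"ok":true}')  --> true
--   json_is_valid("{oops")        --> false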
|
||||
|
||||
function json_pretty_print(value)
|
||||
if type(value) == "string" then
|
||||
value = json_decode(value)
|
||||
end
|
||||
|
||||
local function stringify(val, indent, visited)
|
||||
visited = visited or {}
|
||||
indent = indent or 0
|
||||
local spaces = string.rep(" ", indent)
|
||||
|
||||
if type(val) == "table" then
|
||||
if visited[val] then return "{...}" end
|
||||
visited[val] = true
|
||||
|
||||
local isArray = true
|
||||
local i = 1
|
||||
for k in pairs(val) do
|
||||
if type(k) ~= "number" or k ~= i then
|
||||
isArray = false
|
||||
break
|
||||
end
|
||||
i = i + 1
|
||||
end
|
||||
|
||||
local result = isArray and "[\n" or "{\n"
|
||||
local first = true
|
||||
|
||||
if isArray then
|
||||
for i, v in ipairs(val) do
|
||||
if not first then result = result .. ",\n" end
|
||||
first = false
|
||||
result = result .. spaces .. " " .. stringify(v, indent + 1, visited)
|
||||
end
|
||||
else
|
||||
for k, v in pairs(val) do
|
||||
if not first then result = result .. ",\n" end
|
||||
first = false
|
||||
result = result .. spaces .. " \"" .. tostring(k) .. "\": " .. stringify(v, indent + 1, visited)
|
||||
end
|
||||
end
|
||||
|
||||
return result .. "\n" .. spaces .. (isArray and "]" or "}")
|
||||
elseif type(val) == "string" then
|
||||
return "\"" .. val:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n') .. "\""
|
||||
else
|
||||
return tostring(val)
|
||||
end
|
||||
end
|
||||
|
||||
return stringify(value)
|
||||
end
|
||||
802
runner/lua/math.lua
Normal file
@ -0,0 +1,802 @@
|
||||
--[[
|
||||
math.lua - High-performance math library
|
||||
]]--
|
||||
|
||||
local math_ext = {}
|
||||
|
||||
-- Import standard math functions
|
||||
for name, func in pairs(_G.math) do
|
||||
math_ext[name] = func
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- CONSTANTS (higher precision)
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.pi = 3.14159265358979323846
|
||||
math_ext.tau = 6.28318530717958647693 -- 2*pi
|
||||
math_ext.e = 2.71828182845904523536
|
||||
math_ext.phi = 1.61803398874989484820 -- Golden ratio
|
||||
math_ext.sqrt2 = 1.41421356237309504880
|
||||
math_ext.sqrt3 = 1.73205080756887729353
|
||||
math_ext.ln2 = 0.69314718055994530942
|
||||
math_ext.ln10 = 2.30258509299404568402
|
||||
math_ext.infinity = 1/0
|
||||
math_ext.nan = 0/0
|
||||
|
||||
-- ======================================================================
|
||||
-- EXTENDED FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Cube root (handles negative numbers correctly)
|
||||
function math_ext.cbrt(x)
|
||||
return x < 0 and -(-x)^(1/3) or x^(1/3)
|
||||
end
|
||||
|
||||
-- Hypotenuse of right-angled triangle
|
||||
function math_ext.hypot(x, y)
|
||||
return math.sqrt(x * x + y * y)
|
||||
end
|
||||
|
||||
-- Check if value is NaN
|
||||
function math_ext.isnan(x)
|
||||
return x ~= x
|
||||
end
|
||||
|
||||
-- Check if value is finite
|
||||
function math_ext.isfinite(x)
|
||||
return x > -math_ext.infinity and x < math_ext.infinity
|
||||
end
|
||||
|
||||
-- Sign function (-1, 0, 1)
|
||||
function math_ext.sign(x)
|
||||
return x > 0 and 1 or (x < 0 and -1 or 0)
|
||||
end
|
||||
|
||||
-- Clamp value between min and max
|
||||
function math_ext.clamp(x, min, max)
|
||||
return x < min and min or (x > max and max or x)
|
||||
end
|
||||
|
||||
-- Linear interpolation
|
||||
function math_ext.lerp(a, b, t)
|
||||
return a + (b - a) * t
|
||||
end
|
||||
|
||||
-- Smooth step interpolation
|
||||
function math_ext.smoothstep(a, b, t)
|
||||
t = math_ext.clamp((t - a) / (b - a), 0, 1)
|
||||
return t * t * (3 - 2 * t)
|
||||
end
|
||||
|
||||
-- Map value from one range to another
|
||||
function math_ext.map(x, in_min, in_max, out_min, out_max)
|
||||
return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min
|
||||
end
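-- Illustrative sketch of the interpolation helpers:
--   math_ext.clamp(12, 0, 10)        --> 10
--   math_ext.lerp(0, 100, 0.25)      --> 25
--   math_ext.map(5, 0, 10, 0, 100)   --> 50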
|
||||
|
||||
-- Round to nearest integer
|
||||
function math_ext.round(x)
|
||||
return x >= 0 and math.floor(x + 0.5) or math.ceil(x - 0.5)
|
||||
end
|
||||
|
||||
-- Round to specified decimal places
|
||||
function math_ext.roundto(x, decimals)
|
||||
local mult = 10 ^ (decimals or 0)
|
||||
return math.floor(x * mult + 0.5) / mult
|
||||
end
|
||||
|
||||
-- Normalize angle to [-π, π]
|
||||
function math_ext.normalize_angle(angle)
|
||||
return angle - 2 * math_ext.pi * math.floor((angle + math_ext.pi) / (2 * math_ext.pi))
|
||||
end
|
||||
|
||||
-- Distance between points
|
||||
function math_ext.distance(x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- RANDOM NUMBER FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Random float in range [min, max)
|
||||
function math_ext.randomf(min, max)
|
||||
if not min and not max then
|
||||
return math.random()
|
||||
elseif not max then
|
||||
max = min
|
||||
min = 0
|
||||
end
|
||||
return min + math.random() * (max - min)
|
||||
end
|
||||
|
||||
-- Random integer in range [min, max]
|
||||
function math_ext.randint(min, max)
|
||||
if not max then
|
||||
max = min
|
||||
min = 1
|
||||
end
|
||||
return math.floor(math.random() * (max - min + 1) + min)
|
||||
end
|
||||
|
||||
-- Random boolean with probability p (default 0.5)
|
||||
function math_ext.randboolean(p)
|
||||
p = p or 0.5
|
||||
return math.random() < p
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- STATISTICS FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Sum of values
|
||||
function math_ext.sum(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
|
||||
local sum = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
sum = sum + t[i]
|
||||
end
|
||||
end
|
||||
return sum
|
||||
end
|
||||
|
||||
-- Mean (average) of values
|
||||
function math_ext.mean(t)
|
||||
if type(t) ~= "table" or #t == 0 then return 0 end
|
||||
|
||||
local sum = 0
|
||||
local count = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
sum = sum + t[i]
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
return count > 0 and sum / count or 0
|
||||
end
|
||||
|
||||
-- Median of values
|
||||
function math_ext.median(t)
|
||||
if type(t) ~= "table" or #t == 0 then return 0 end
|
||||
|
||||
local nums = {}
|
||||
local count = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
count = count + 1
|
||||
nums[count] = t[i]
|
||||
end
|
||||
end
|
||||
|
||||
if count == 0 then return 0 end
|
||||
|
||||
table.sort(nums)
|
||||
|
||||
if count % 2 == 0 then
|
||||
return (nums[count/2] + nums[count/2 + 1]) / 2
|
||||
else
|
||||
return nums[math.ceil(count/2)]
|
||||
end
|
||||
end
|
||||
|
||||
-- Sample variance of values (Bessel's correction, divides by n - 1)
|
||||
function math_ext.variance(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
|
||||
local count = 0
|
||||
local m = math_ext.mean(t)
|
||||
local sum = 0
|
||||
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
local dev = t[i] - m
|
||||
sum = sum + dev * dev
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
|
||||
return count > 1 and sum / (count - 1) or 0
|
||||
end
|
||||
|
||||
-- Standard deviation
|
||||
function math_ext.stdev(t)
|
||||
return math.sqrt(math_ext.variance(t))
|
||||
end
|
||||
|
||||
-- Population variance
|
||||
function math_ext.pvariance(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
|
||||
local count = 0
|
||||
local m = math_ext.mean(t)
|
||||
local sum = 0
|
||||
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
local dev = t[i] - m
|
||||
sum = sum + dev * dev
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
|
||||
return count > 0 and sum / count or 0
|
||||
end
|
||||
|
||||
-- Population standard deviation
|
||||
function math_ext.pstdev(t)
|
||||
return math.sqrt(math_ext.pvariance(t))
|
||||
end
|
||||
|
||||
-- Mode (most common value)
|
||||
function math_ext.mode(t)
|
||||
if type(t) ~= "table" or #t == 0 then return nil end
|
||||
|
||||
local counts = {}
|
||||
local most_frequent = nil
|
||||
local max_count = 0
|
||||
|
||||
for i=1, #t do
|
||||
local v = t[i]
|
||||
counts[v] = (counts[v] or 0) + 1
|
||||
if counts[v] > max_count then
|
||||
max_count = counts[v]
|
||||
most_frequent = v
|
||||
end
|
||||
end
|
||||
|
||||
return most_frequent
|
||||
end
|
||||
|
||||
-- Min and max simultaneously (faster than calling both separately)
|
||||
function math_ext.minmax(t)
|
||||
if type(t) ~= "table" or #t == 0 then return nil, nil end
|
||||
|
||||
local min, max
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
min = t[i]
|
||||
max = t[i]
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if min == nil then return nil, nil end
|
||||
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
if t[i] < min then min = t[i] end
|
||||
if t[i] > max then max = t[i] end
|
||||
end
|
||||
end
|
||||
|
||||
return min, max
|
||||
end
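The statistics helpers above work on plain array-style tables and skip non-numeric entries. A minimal usage sketch follows; the require name is an assumption about how the runner maps runner/lua/math.lua over the standard library, which this diff does not show.

```lua
-- Illustrative only: the require name "math" is an assumption about module wiring.
local m = require("math")

local samples = {12, 7, 3, 9, 9, 15}
print(m.sum(samples))      --> 55
print(m.mean(samples))     --> 9.1666...
print(m.median(samples))   --> 9 (even count: average of the two middle values)
print(m.stdev(samples))    -- sample standard deviation
print(m.minmax(samples))   --> 3   15
```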
|
||||
|
||||
-- ======================================================================
|
||||
-- VECTOR OPERATIONS (2D/3D vectors)
|
||||
-- ======================================================================
|
||||
|
||||
-- 2D Vector operations
|
||||
math_ext.vec2 = {
|
||||
new = function(x, y)
|
||||
return {x = x or 0, y = y or 0}
|
||||
end,
|
||||
|
||||
copy = function(v)
|
||||
return {x = v.x, y = v.y}
|
||||
end,
|
||||
|
||||
add = function(a, b)
|
||||
return {x = a.x + b.x, y = a.y + b.y}
|
||||
end,
|
||||
|
||||
sub = function(a, b)
|
||||
return {x = a.x - b.x, y = a.y - b.y}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
if type(b) == "number" then
|
||||
return {x = a.x * b, y = a.y * b}
|
||||
end
|
||||
return {x = a.x * b.x, y = a.y * b.y}
|
||||
end,
|
||||
|
||||
div = function(a, b)
|
||||
if type(b) == "number" then
|
||||
local inv = 1 / b
|
||||
return {x = a.x * inv, y = a.y * inv}
|
||||
end
|
||||
return {x = a.x / b.x, y = a.y / b.y}
|
||||
end,
|
||||
|
||||
dot = function(a, b)
|
||||
return a.x * b.x + a.y * b.y
|
||||
end,
|
||||
|
||||
length = function(v)
|
||||
return math.sqrt(v.x * v.x + v.y * v.y)
|
||||
end,
|
||||
|
||||
length_squared = function(v)
|
||||
return v.x * v.x + v.y * v.y
|
||||
end,
|
||||
|
||||
distance = function(a, b)
|
||||
local dx, dy = b.x - a.x, b.y - a.y
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
end,
|
||||
|
||||
distance_squared = function(a, b)
|
||||
local dx, dy = b.x - a.x, b.y - a.y
|
||||
return dx * dx + dy * dy
|
||||
end,
|
||||
|
||||
normalize = function(v)
|
||||
local len = math.sqrt(v.x * v.x + v.y * v.y)
|
||||
if len > 1e-10 then
|
||||
local inv_len = 1 / len
|
||||
return {x = v.x * inv_len, y = v.y * inv_len}
|
||||
end
|
||||
return {x = 0, y = 0}
|
||||
end,
|
||||
|
||||
rotate = function(v, angle)
|
||||
local c, s = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
x = v.x * c - v.y * s,
|
||||
y = v.x * s + v.y * c
|
||||
}
|
||||
end,
|
||||
|
||||
angle = function(v)
|
||||
return math.atan2(v.y, v.x)
|
||||
end,
|
||||
|
||||
lerp = function(a, b, t)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
return {
|
||||
x = a.x + (b.x - a.x) * t,
|
||||
y = a.y + (b.y - a.y) * t
|
||||
}
|
||||
end,
|
||||
|
||||
reflect = function(v, normal)
|
||||
local dot = v.x * normal.x + v.y * normal.y
|
||||
return {
|
||||
x = v.x - 2 * dot * normal.x,
|
||||
y = v.y - 2 * dot * normal.y
|
||||
}
|
||||
end
|
||||
}
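A short sketch of the vec2 helpers. Vectors are plain {x, y} tables, so the functions compose without any metatable setup; the m alias assumes the module was loaded as in the earlier example.

```lua
local vec2 = m.vec2

local a = vec2.new(3, 4)
print(vec2.length(a))          --> 5

local dir = vec2.normalize(a)  --> {x = 0.6, y = 0.8}
local moved = vec2.add(a, vec2.mul(dir, 2))
print(moved.x, moved.y)        --> 4.2   5.6
```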
|
||||
|
||||
-- 3D Vector operations
|
||||
math_ext.vec3 = {
|
||||
new = function(x, y, z)
|
||||
return {x = x or 0, y = y or 0, z = z or 0}
|
||||
end,
|
||||
|
||||
copy = function(v)
|
||||
return {x = v.x, y = v.y, z = v.z}
|
||||
end,
|
||||
|
||||
add = function(a, b)
|
||||
return {x = a.x + b.x, y = a.y + b.y, z = a.z + b.z}
|
||||
end,
|
||||
|
||||
sub = function(a, b)
|
||||
return {x = a.x - b.x, y = a.y - b.y, z = a.z - b.z}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
if type(b) == "number" then
|
||||
return {x = a.x * b, y = a.y * b, z = a.z * b}
|
||||
end
|
||||
return {x = a.x * b.x, y = a.y * b.y, z = a.z * b.z}
|
||||
end,
|
||||
|
||||
div = function(a, b)
|
||||
if type(b) == "number" then
|
||||
local inv = 1 / b
|
||||
return {x = a.x * inv, y = a.y * inv, z = a.z * inv}
|
||||
end
|
||||
return {x = a.x / b.x, y = a.y / b.y, z = a.z / b.z}
|
||||
end,
|
||||
|
||||
dot = function(a, b)
|
||||
return a.x * b.x + a.y * b.y + a.z * b.z
|
||||
end,
|
||||
|
||||
cross = function(a, b)
|
||||
return {
|
||||
x = a.y * b.z - a.z * b.y,
|
||||
y = a.z * b.x - a.x * b.z,
|
||||
z = a.x * b.y - a.y * b.x
|
||||
}
|
||||
end,
|
||||
|
||||
length = function(v)
|
||||
return math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
|
||||
end,
|
||||
|
||||
length_squared = function(v)
|
||||
return v.x * v.x + v.y * v.y + v.z * v.z
|
||||
end,
|
||||
|
||||
distance = function(a, b)
|
||||
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
|
||||
return math.sqrt(dx * dx + dy * dy + dz * dz)
|
||||
end,
|
||||
|
||||
distance_squared = function(a, b)
|
||||
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
|
||||
return dx * dx + dy * dy + dz * dz
|
||||
end,
|
||||
|
||||
normalize = function(v)
|
||||
local len = math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
|
||||
if len > 1e-10 then
|
||||
local inv_len = 1 / len
|
||||
return {x = v.x * inv_len, y = v.y * inv_len, z = v.z * inv_len}
|
||||
end
|
||||
return {x = 0, y = 0, z = 0}
|
||||
end,
|
||||
|
||||
lerp = function(a, b, t)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
return {
|
||||
x = a.x + (b.x - a.x) * t,
|
||||
y = a.y + (b.y - a.y) * t,
|
||||
z = a.z + (b.z - a.z) * t
|
||||
}
|
||||
end,
|
||||
|
||||
reflect = function(v, normal)
|
||||
local dot = v.x * normal.x + v.y * normal.y + v.z * normal.z
|
||||
return {
|
||||
x = v.x - 2 * dot * normal.x,
|
||||
y = v.y - 2 * dot * normal.y,
|
||||
z = v.z - 2 * dot * normal.z
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- MATRIX OPERATIONS (2x2 and 3x3 matrices)
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.mat2 = {
|
||||
-- Create a new 2x2 matrix
|
||||
new = function(a, b, c, d)
|
||||
return {
|
||||
{a or 1, b or 0},
|
||||
{c or 0, d or 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Create identity matrix
|
||||
identity = function()
|
||||
return {{1, 0}, {0, 1}}
|
||||
end,
|
||||
|
||||
-- Matrix multiplication
|
||||
mul = function(a, b)
|
||||
return {
|
||||
{
|
||||
a[1][1] * b[1][1] + a[1][2] * b[2][1],
|
||||
a[1][1] * b[1][2] + a[1][2] * b[2][2]
|
||||
},
|
||||
{
|
||||
a[2][1] * b[1][1] + a[2][2] * b[2][1],
|
||||
a[2][1] * b[1][2] + a[2][2] * b[2][2]
|
||||
}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Determinant
|
||||
det = function(m)
|
||||
return m[1][1] * m[2][2] - m[1][2] * m[2][1]
|
||||
end,
|
||||
|
||||
-- Inverse matrix
|
||||
inverse = function(m)
|
||||
local det = m[1][1] * m[2][2] - m[1][2] * m[2][1]
|
||||
if math.abs(det) < 1e-10 then
|
||||
return nil -- Matrix is not invertible
|
||||
end
|
||||
|
||||
local inv_det = 1 / det
|
||||
return {
|
||||
{m[2][2] * inv_det, -m[1][2] * inv_det},
|
||||
{-m[2][1] * inv_det, m[1][1] * inv_det}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Rotation matrix
|
||||
rotation = function(angle)
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos, -sin},
|
||||
{sin, cos}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Apply matrix to vector
|
||||
transform = function(m, v)
|
||||
return {
|
||||
x = m[1][1] * v.x + m[1][2] * v.y,
|
||||
y = m[2][1] * v.x + m[2][2] * v.y
|
||||
}
|
||||
end,
|
||||
|
||||
-- Scale matrix
|
||||
scale = function(sx, sy)
|
||||
sy = sy or sx
|
||||
return {
|
||||
{sx, 0},
|
||||
{0, sy}
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
math_ext.mat3 = {
|
||||
-- Create identity matrix 3x3
|
||||
identity = function()
|
||||
return {
|
||||
{1, 0, 0},
|
||||
{0, 1, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Create a 2D transformation matrix (translation, rotation, scale)
|
||||
transform = function(x, y, angle, sx, sy)
|
||||
sx = sx or 1
|
||||
sy = sy or sx
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos * sx, -sin * sy, x},
|
||||
{sin * sx, cos * sy, y},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Matrix multiplication
|
||||
mul = function(a, b)
|
||||
local result = {
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}
|
||||
}
|
||||
|
||||
for i = 1, 3 do
|
||||
for j = 1, 3 do
|
||||
for k = 1, 3 do
|
||||
result[i][j] = result[i][j] + a[i][k] * b[k][j]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end,
|
||||
|
||||
-- Apply matrix to point (homogeneous coordinates)
|
||||
transform_point = function(m, v)
|
||||
local x = m[1][1] * v.x + m[1][2] * v.y + m[1][3]
|
||||
local y = m[2][1] * v.x + m[2][2] * v.y + m[2][3]
|
||||
local w = m[3][1] * v.x + m[3][2] * v.y + m[3][3]
|
||||
|
||||
if math.abs(w) < 1e-10 then
|
||||
return {x = 0, y = 0}
|
||||
end
|
||||
|
||||
return {x = x / w, y = y / w}
|
||||
end,
|
||||
|
||||
-- Translation matrix
|
||||
translation = function(x, y)
|
||||
return {
|
||||
{1, 0, x},
|
||||
{0, 1, y},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Rotation matrix
|
||||
rotation = function(angle)
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos, -sin, 0},
|
||||
{sin, cos, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Scale matrix
|
||||
scale = function(sx, sy)
|
||||
sy = sy or sx
|
||||
return {
|
||||
{sx, 0, 0},
|
||||
{0, sy, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Determinant
|
||||
det = function(m)
|
||||
return m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2]) -
|
||||
m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1]) +
|
||||
m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1])
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- GEOMETRY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.geometry = {
|
||||
-- Distance from point to line segment (t is clamped to the segment endpoints)
|
||||
point_line_distance = function(px, py, x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
local len_sq = dx * dx + dy * dy
|
||||
|
||||
if len_sq < 1e-10 then
|
||||
return math_ext.distance(px, py, x1, y1)
|
||||
end
|
||||
|
||||
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
|
||||
local nearestX = x1 + t * dx
|
||||
local nearestY = y1 + t * dy
|
||||
|
||||
return math_ext.distance(px, py, nearestX, nearestY)
|
||||
end,
|
||||
|
||||
-- Check if point is inside polygon
|
||||
point_in_polygon = function(px, py, vertices)
|
||||
local inside = false
|
||||
local n = #vertices / 2
|
||||
|
||||
for i = 1, n do
|
||||
local x1, y1 = vertices[i*2-1], vertices[i*2]
|
||||
local x2, y2
|
||||
|
||||
if i == n then
|
||||
x2, y2 = vertices[1], vertices[2]
|
||||
else
|
||||
x2, y2 = vertices[i*2+1], vertices[i*2+2]
|
||||
end
|
||||
|
||||
if ((y1 > py) ~= (y2 > py)) and
|
||||
(px < (x2 - x1) * (py - y1) / (y2 - y1) + x1) then
|
||||
inside = not inside
|
||||
end
|
||||
end
|
||||
|
||||
return inside
|
||||
end,
|
||||
|
||||
-- Area of a triangle
|
||||
triangle_area = function(x1, y1, x2, y2, x3, y3)
|
||||
return math.abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2)
|
||||
end,
|
||||
|
||||
-- Check if point is inside triangle
|
||||
point_in_triangle = function(px, py, x1, y1, x2, y2, x3, y3)
|
||||
local area = math_ext.geometry.triangle_area(x1, y1, x2, y2, x3, y3)
|
||||
local area1 = math_ext.geometry.triangle_area(px, py, x2, y2, x3, y3)
|
||||
local area2 = math_ext.geometry.triangle_area(x1, y1, px, py, x3, y3)
|
||||
local area3 = math_ext.geometry.triangle_area(x1, y1, x2, y2, px, py)
|
||||
|
||||
return math.abs(area - (area1 + area2 + area3)) < 1e-10
|
||||
end,
|
||||
|
||||
-- Check if two line segments intersect
|
||||
line_intersect = function(x1, y1, x2, y2, x3, y3, x4, y4)
|
||||
local d = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
|
||||
|
||||
if math.abs(d) < 1e-10 then
|
||||
return false, nil, nil -- Lines are parallel
|
||||
end
|
||||
|
||||
local ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / d
|
||||
local ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / d
|
||||
|
||||
if ua >= 0 and ua <= 1 and ub >= 0 and ub <= 1 then
|
||||
local x = x1 + ua * (x2 - x1)
|
||||
local y = y1 + ua * (y2 - y1)
|
||||
return true, x, y
|
||||
end
|
||||
|
||||
return false, nil, nil
|
||||
end,
|
||||
|
||||
-- Closest point on line segment to point
|
||||
closest_point_on_segment = function(px, py, x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
local len_sq = dx * dx + dy * dy
|
||||
|
||||
if len_sq < 1e-10 then
|
||||
return x1, y1
|
||||
end
|
||||
|
||||
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
|
||||
return x1 + t * dx, y1 + t * dy
|
||||
end
|
||||
}
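The geometry helpers take flat coordinate lists and bare numbers rather than vector tables. A hedged sketch:

```lua
local geo = m.geometry

-- Unit square as a flat {x1, y1, x2, y2, ...} vertex list
local square = {0, 0, 1, 0, 1, 1, 0, 1}
print(geo.point_in_polygon(0.5, 0.5, square))   --> true
print(geo.point_in_polygon(2, 2, square))       --> false

-- Segment (0,0)-(2,2) crosses segment (0,2)-(2,0) at (1,1)
local hit, x, y = geo.line_intersect(0, 0, 2, 2, 0, 2, 2, 0)
print(hit, x, y)                                --> true   1   1
```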
|
||||
|
||||
-- ======================================================================
|
||||
-- INTERPOLATION FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.interpolation = {
|
||||
-- Cubic Bezier interpolation
|
||||
bezier = function(t, p0, p1, p2, p3)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local t2 = t * t
|
||||
local t3 = t2 * t
|
||||
local mt = 1 - t
|
||||
local mt2 = mt * mt
|
||||
local mt3 = mt2 * mt
|
||||
|
||||
return p0 * mt3 + 3 * p1 * mt2 * t + 3 * p2 * mt * t2 + p3 * t3
|
||||
end,
|
||||
|
||||
-- Catmull-Rom spline interpolation
|
||||
catmull_rom = function(t, p0, p1, p2, p3)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local t2 = t * t
|
||||
local t3 = t2 * t
|
||||
|
||||
return 0.5 * (
|
||||
(2 * p1) +
|
||||
(-p0 + p2) * t +
|
||||
(2 * p0 - 5 * p1 + 4 * p2 - p3) * t2 +
|
||||
(-p0 + 3 * p1 - 3 * p2 + p3) * t3
|
||||
)
|
||||
end,
|
||||
|
||||
-- Hermite interpolation
|
||||
hermite = function(t, p0, p1, m0, m1)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local t2 = t * t
|
||||
local t3 = t2 * t
|
||||
local h00 = 2 * t3 - 3 * t2 + 1
|
||||
local h10 = t3 - 2 * t2 + t
|
||||
local h01 = -2 * t3 + 3 * t2
|
||||
local h11 = t3 - t2
|
||||
|
||||
return h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||
end,
|
||||
|
||||
-- Quadratic Bezier interpolation
|
||||
quadratic_bezier = function(t, p0, p1, p2)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local mt = 1 - t
|
||||
return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2
|
||||
end,
|
||||
|
||||
-- Step interpolation
|
||||
step = function(t, edge, x)
|
||||
return t < edge and 0 or x
|
||||
end,
|
||||
|
||||
-- Smoothstep interpolation
|
||||
smoothstep = function(edge0, edge1, x)
|
||||
local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1)
|
||||
return t * t * (3 - 2 * t)
|
||||
end,
|
||||
|
||||
-- Smootherstep interpolation (Ken Perlin)
|
||||
smootherstep = function(edge0, edge1, x)
|
||||
local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1)
|
||||
return t * t * t * (t * (t * 6 - 15) + 10)
|
||||
end
|
||||
}
|
||||
|
||||
return math_ext
|
||||
667
runner/lua/sandbox.lua
Normal file
@ -0,0 +1,667 @@
|
||||
--[[
|
||||
sandbox.lua
|
||||
]]--
|
||||
|
||||
__http_response = {}
|
||||
__module_paths = {}
|
||||
__module_bytecode = {}
|
||||
__ready_modules = {}
|
||||
__EXIT_SENTINEL = {} -- Unique object for exit identification
|
||||
|
||||
-- ======================================================================
|
||||
-- CORE SANDBOX FUNCTIONALITY
|
||||
-- ======================================================================
|
||||
|
||||
function exit()
|
||||
error(__EXIT_SENTINEL)
|
||||
end
|
||||
|
||||
-- Create environment inheriting from _G
|
||||
function __create_env(ctx)
|
||||
local env = setmetatable({}, {__index = _G})
|
||||
|
||||
if ctx then
|
||||
env.ctx = ctx
|
||||
|
||||
if ctx._env then
|
||||
env._env = ctx._env
|
||||
end
|
||||
end
|
||||
|
||||
if __setup_require then
|
||||
__setup_require(env)
|
||||
end
|
||||
|
||||
return env
|
||||
end
|
||||
|
||||
-- Execute script with clean environment
|
||||
function __execute_script(fn, ctx)
|
||||
__http_response = nil
|
||||
|
||||
local env = __create_env(ctx)
|
||||
env.exit = exit
|
||||
setfenv(fn, env)
|
||||
|
||||
local ok, result = pcall(fn)
|
||||
if not ok then
|
||||
if result == __EXIT_SENTINEL then
|
||||
return
|
||||
end
|
||||
|
||||
error(result, 0)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- Ensure __http_response exists, then return it
|
||||
function __ensure_response()
|
||||
if not __http_response then
|
||||
__http_response = {}
|
||||
end
|
||||
return __http_response
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- HTTP FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Set HTTP status code
|
||||
function http_set_status(code)
|
||||
if type(code) ~= "number" then
|
||||
error("http_set_status: status code must be a number", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.status = code
|
||||
end
|
||||
|
||||
-- Set HTTP header
|
||||
function http_set_header(name, value)
|
||||
if type(name) ~= "string" or type(value) ~= "string" then
|
||||
error("http_set_header: name and value must be strings", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.headers = resp.headers or {}
|
||||
resp.headers[name] = value
|
||||
end
|
||||
|
||||
-- Set content type; http_set_header helper
|
||||
function http_set_content_type(content_type)
|
||||
http_set_header("Content-Type", content_type)
|
||||
end
|
||||
|
||||
-- Set metadata (arbitrary data to be returned with response)
|
||||
function http_set_metadata(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("http_set_metadata: key must be a string", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.metadata = resp.metadata or {}
|
||||
resp.metadata[key] = value
|
||||
end
|
||||
|
||||
-- Generic HTTP request function
|
||||
function http_request(method, url, body, options)
|
||||
if type(method) ~= "string" then
|
||||
error("http_request: method must be a string", 2)
|
||||
end
|
||||
if type(url) ~= "string" then
|
||||
error("http_request: url must be a string", 2)
|
||||
end
|
||||
|
||||
-- Call native implementation
|
||||
local result = __http_request(method, url, body, options)
|
||||
return result
|
||||
end
|
||||
|
||||
-- Shorthand function to directly get JSON
|
||||
function http_get_json(url, options)
|
||||
options = options or {}
|
||||
local response = http_get(url, options)
|
||||
if response.ok and response.json then
|
||||
return response.json
|
||||
end
|
||||
return nil, response
|
||||
end
|
||||
|
||||
-- Utility to build a URL with query parameters
|
||||
function http_build_url(base_url, params)
|
||||
if not params or type(params) ~= "table" then
|
||||
return base_url
|
||||
end
|
||||
|
||||
local query = {}
|
||||
for k, v in pairs(params) do
|
||||
if type(v) == "table" then
|
||||
for _, item in ipairs(v) do
|
||||
table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(item)))
|
||||
end
|
||||
else
|
||||
table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(v)))
|
||||
end
|
||||
end
|
||||
|
||||
if #query > 0 then
|
||||
if string.contains(base_url, "?") then
|
||||
return base_url .. "&" .. table.concat(query, "&")
|
||||
else
|
||||
return base_url .. "?" .. table.concat(query, "&")
|
||||
end
|
||||
end
|
||||
|
||||
return base_url
|
||||
end
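http_build_url pairs naturally with http_get_json. The sketch below uses a hypothetical endpoint; url_encode and the underlying __http_request are native functions registered by the Go runner and are not part of this file.

```lua
-- Hypothetical endpoint for illustration.
local url = http_build_url("https://api.example.com/search", {
    q = "moonshark",
    tags = {"lua", "web"},   -- array values become repeated query parameters
})
-- e.g. https://api.example.com/search?q=moonshark&tags=lua&tags=web
-- (parameter order depends on pairs() iteration)

local data, response = http_get_json(url)
if not data then
    http_set_status(response and response.status or 502)
end
```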
|
||||
|
||||
local function make_method(method, needs_body)
|
||||
return function(url, body_or_options, options)
|
||||
if needs_body then
|
||||
options = options or {}
|
||||
return http_request(method, url, body_or_options, options)
|
||||
else
|
||||
body_or_options = body_or_options or {}
|
||||
return http_request(method, url, nil, body_or_options)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
http_get = make_method("GET", false)
|
||||
http_delete = make_method("DELETE", false)
|
||||
http_head = make_method("HEAD", false)
|
||||
http_options = make_method("OPTIONS", false)
|
||||
http_post = make_method("POST", true)
|
||||
http_put = make_method("PUT", true)
|
||||
http_patch = make_method("PATCH", true)
|
||||
|
||||
function http_redirect(url, status)
|
||||
if type(url) ~= "string" then
|
||||
error("http_redirect: url must be a string", 2)
|
||||
end
|
||||
|
||||
status = status or 302 -- Default to temporary redirect
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.status = status
|
||||
|
||||
resp.headers = resp.headers or {}
|
||||
resp.headers["Location"] = url
|
||||
|
||||
exit()
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- COOKIE FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Set a cookie
|
||||
function cookie_set(name, value, options)
|
||||
if type(name) ~= "string" then
|
||||
error("cookie_set: name must be a string", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.cookies = resp.cookies or {}
|
||||
|
||||
local opts = options or {}
|
||||
local cookie = {
|
||||
name = name,
|
||||
value = value or "",
|
||||
path = opts.path or "/",
|
||||
domain = opts.domain
|
||||
}
|
||||
|
||||
if opts.expires then
|
||||
if type(opts.expires) == "number" then
|
||||
if opts.expires > 0 then
|
||||
cookie.max_age = opts.expires
|
||||
local now = os.time()
|
||||
cookie.expires = now + opts.expires
|
||||
elseif opts.expires < 0 then
|
||||
cookie.expires = 1
|
||||
cookie.max_age = 0
|
||||
end
|
||||
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
|
||||
end
|
||||
end
|
||||
|
||||
cookie.secure = (opts.secure ~= false)
|
||||
cookie.http_only = (opts.http_only ~= false)
|
||||
|
||||
if opts.same_site then
|
||||
local same_site = string.trim(opts.same_site):lower()
|
||||
local valid_values = {none = true, lax = true, strict = true}
|
||||
|
||||
if not valid_values[same_site] then
|
||||
error("cookie_set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
|
||||
end
|
||||
|
||||
-- If SameSite=None, the cookie must be secure
|
||||
if same_site == "none" and not cookie.secure then
|
||||
cookie.secure = true
|
||||
end
|
||||
|
||||
cookie.same_site = opts.same_site
|
||||
else
|
||||
cookie.same_site = "Lax"
|
||||
end
|
||||
|
||||
table.insert(resp.cookies, cookie)
|
||||
return true
|
||||
end
|
||||
|
||||
-- Get a cookie value
|
||||
function cookie_get(name)
|
||||
if type(name) ~= "string" then
|
||||
error("cookie_get: name must be a string", 2)
|
||||
end
|
||||
|
||||
local env = getfenv(2)
|
||||
|
||||
if env.ctx and env.ctx.cookies then
|
||||
return env.ctx.cookies[name]
|
||||
end
|
||||
|
||||
if env.ctx and env.ctx._request_cookies then
|
||||
return env.ctx._request_cookies[name]
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
-- Remove a cookie
|
||||
function cookie_remove(name, path, domain)
|
||||
if type(name) ~= "string" then
|
||||
error("cookie_remove: name must be a string", 2)
|
||||
end
|
||||
|
||||
return cookie_set(name, "", {expires = 0, path = path or "/", domain = domain})
|
||||
end
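Inside a route script the cookie helpers are globals. A small sketch, using the defaults above (Secure and HttpOnly on, SameSite=Lax unless overridden):

```lua
local theme = cookie_get("theme") or "light"

cookie_set("theme", "dark", {
    expires = 60 * 60 * 24 * 30,   -- 30 days; emitted as Max-Age plus Expires
    same_site = "Lax",
})

cookie_remove("legacy_session")    -- expires the cookie immediately
```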
|
||||
|
||||
-- ======================================================================
|
||||
-- SESSION FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function session_get(key)
|
||||
if type(key) ~= "string" then
|
||||
error("session_get: key must be a string", 2)
|
||||
end
|
||||
|
||||
local env = getfenv(2)
|
||||
|
||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
||||
return env.ctx.session.data[key]
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
function session_set(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("session_set: key must be a string", 2)
|
||||
end
|
||||
if value == nil then
|
||||
error("session_set: value cannot be nil", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.session = resp.session or {}
|
||||
resp.session[key] = value
|
||||
|
||||
local env = getfenv(2)
|
||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
||||
env.ctx.session.data[key] = value
|
||||
end
|
||||
end
|
||||
|
||||
function session_id()
|
||||
local env = getfenv(2)
|
||||
|
||||
if env.ctx and env.ctx.session then
|
||||
return env.ctx.session.id
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
function session_get_all()
|
||||
local env = getfenv(2)
|
||||
|
||||
if env.ctx and env.ctx.session then
|
||||
return env.ctx.session.data
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
function session_delete(key)
|
||||
if type(key) ~= "string" then
|
||||
error("session_delete: key must be a string", 2)
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.session = resp.session or {}
|
||||
resp.session[key] = "__SESSION_DELETE_MARKER__"
|
||||
|
||||
local env = getfenv(2)
|
||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
||||
env.ctx.session.data[key] = nil
|
||||
end
|
||||
end
|
||||
|
||||
function session_clear()
|
||||
local env = getfenv(2)
|
||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
||||
for k, _ in pairs(env.ctx.session.data) do
|
||||
env.ctx.session.data[k] = nil
|
||||
end
|
||||
end
|
||||
|
||||
local resp = __ensure_response()
|
||||
resp.session = {}
|
||||
resp.session["__clear_all"] = true
|
||||
end
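Session reads come from ctx.session as populated by the runner; writes are collected on the response and persisted by the Go side. A minimal sketch:

```lua
local visits = (session_get("visits") or 0) + 1
session_set("visits", visits)

if visits > 100 then
    session_delete("visits")       -- marks the key for removal on save
end

local sid = session_id()           -- nil outside a session-backed request
```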
|
||||
|
||||
-- ======================================================================
|
||||
-- CSRF FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function csrf_generate()
|
||||
local token = generate_token(32)
|
||||
session_set("_csrf_token", token)
|
||||
return token
|
||||
end
|
||||
|
||||
function csrf_field()
|
||||
local token = session_get("_csrf_token")
|
||||
if not token then
|
||||
token = csrf_generate()
|
||||
end
|
||||
return string.format('<input type="hidden" name="_csrf_token" value="%s" />',
|
||||
html_special_chars(token))
|
||||
end
|
||||
|
||||
function csrf_validate()
|
||||
local env = getfenv(2)
|
||||
local token = false
|
||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
||||
token = env.ctx.session.data["_csrf_token"]
|
||||
end
|
||||
|
||||
if not token then
|
||||
http_set_status(403)
|
||||
__http_response.body = "CSRF validation failed"
|
||||
exit()
|
||||
end
|
||||
|
||||
local request_token = nil
|
||||
if env.ctx and env.ctx.form then
|
||||
request_token = env.ctx.form._csrf_token
|
||||
end
|
||||
|
||||
if not request_token and env.ctx and env.ctx._request_headers then
|
||||
request_token = env.ctx._request_headers["x-csrf-token"] or
|
||||
env.ctx._request_headers["csrf-token"]
|
||||
end
|
||||
|
||||
if not request_token or request_token ~= token then
|
||||
http_set_status(403)
|
||||
__http_response.body = "CSRF validation failed"
|
||||
exit()
|
||||
end
|
||||
|
||||
return true
|
||||
end
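Typical CSRF flow: embed csrf_field() in the form markup, then call csrf_validate() in the handler that receives the POST. The route layout below is illustrative; only ctx.form is assumed, as in csrf_validate itself.

```lua
-- Rendering the form (e.g. GET /notes/new):
local form_html = render([[
  <form method="post" action="/notes">
    {{{ csrf_field() }}}
    <input name="title">
    <button>Save</button>
  </form>
]])

-- Handling the submission (e.g. POST /notes):
csrf_validate()                       -- responds 403 and exits on mismatch
local title = ctx.form and ctx.form.title
```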
|
||||
|
||||
-- ======================================================================
|
||||
-- TEMPLATE RENDER FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Template processing with code execution
|
||||
_G.render = function(template_str, env)
|
||||
local function get_line(s, ln)
|
||||
for line in s:gmatch("([^\n]*)\n?") do
|
||||
if ln == 1 then return line end
|
||||
ln = ln - 1
|
||||
end
|
||||
end
|
||||
|
||||
local function pos_to_line(s, pos)
|
||||
local line = 1
|
||||
for _ in s:sub(1, pos):gmatch("\n") do line = line + 1 end
|
||||
return line
|
||||
end
|
||||
|
||||
local pos, chunks = 1, {}
|
||||
while pos <= #template_str do
|
||||
local unescaped_start = template_str:find("{{{", pos, true)
|
||||
local escaped_start = template_str:find("{{", pos, true)
|
||||
|
||||
local start, tag_type, open_len
|
||||
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
|
||||
start, tag_type, open_len = unescaped_start, "-", 3
|
||||
elseif escaped_start then
|
||||
start, tag_type, open_len = escaped_start, "=", 2
|
||||
else
|
||||
table.insert(chunks, template_str:sub(pos))
|
||||
break
|
||||
end
|
||||
|
||||
if start > pos then
|
||||
table.insert(chunks, template_str:sub(pos, start-1))
|
||||
end
|
||||
|
||||
pos = start + open_len
|
||||
local close_tag = tag_type == "-" and "}}}" or "}}"
|
||||
local close_start, close_stop = template_str:find(close_tag, pos, true)
|
||||
if not close_start then
|
||||
error("Failed to find closing tag at position " .. pos)
|
||||
end
|
||||
|
||||
local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$")
|
||||
|
||||
-- Check if it's a simple variable name for escaped output
|
||||
local is_simple_var = tag_type == "=" and code:match("^[%w_]+$")
|
||||
|
||||
table.insert(chunks, {tag_type, code, pos, is_simple_var})
|
||||
pos = close_stop + 1
|
||||
end
|
||||
|
||||
local buffer = {"local _tostring, _escape, _b, _b_i = ...\n"}
|
||||
for _, chunk in ipairs(chunks) do
|
||||
local t = type(chunk)
|
||||
if t == "string" then
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n")
|
||||
else
|
||||
t = chunk[1]
|
||||
if t == "=" then
|
||||
if chunk[4] then -- is_simple_var
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _escape(_tostring(" .. chunk[2] .. "))\n")
|
||||
else
|
||||
table.insert(buffer, "--[[" .. chunk[3] .. "]] " .. chunk[2] .. "\n")
|
||||
end
|
||||
elseif t == "-" then
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _tostring(" .. chunk[2] .. ")\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(buffer, "return _b")
|
||||
|
||||
local fn, err = loadstring(table.concat(buffer))
|
||||
if not fn then error(err) end
|
||||
|
||||
env = env or {}
|
||||
local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end})
|
||||
setfenv(fn, runtime_env)
|
||||
|
||||
local output_buffer = {}
|
||||
fn(tostring, html_special_chars, output_buffer, 0)
|
||||
return table.concat(output_buffer)
|
||||
end
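render compiles the template into a Lua chunk: {{ name }} escapes a simple variable, {{ any Lua statement }} runs verbatim, and {{{ expr }}} emits raw output. A small sketch:

```lua
local html = render([[
  <h1>{{ title }}</h1>
  <ul>
  {{ for _, item in ipairs(items) do }}
    <li>{{ item }}</li>
  {{ end }}
  </ul>
]], {
  title = "Tags & <filters>",   -- escaped because {{ title }} is a simple name
  items = {"lua", "go"},
})
```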
|
||||
|
||||
-- Named placeholder processing
|
||||
_G.parse = function(template_str, env)
|
||||
local pos, output = 1, {}
|
||||
env = env or {}
|
||||
|
||||
while pos <= #template_str do
|
||||
local unescaped_start, unescaped_end, unescaped_name = template_str:find("{{{%s*([%w_]+)%s*}}}", pos)
|
||||
local escaped_start, escaped_end, escaped_name = template_str:find("{{%s*([%w_]+)%s*}}", pos)
|
||||
|
||||
local next_pos, placeholder_end, name, escaped
|
||||
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
|
||||
next_pos, placeholder_end, name, escaped = unescaped_start, unescaped_end, unescaped_name, false
|
||||
elseif escaped_start then
|
||||
next_pos, placeholder_end, name, escaped = escaped_start, escaped_end, escaped_name, true
|
||||
else
|
||||
local text = template_str:sub(pos)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
local text = template_str:sub(pos, next_pos - 1)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
|
||||
local value = env[name]
|
||||
local str = tostring(value or "")
|
||||
if escaped then
|
||||
str = html_special_chars(str)
|
||||
end
|
||||
table.insert(output, str)
|
||||
|
||||
pos = placeholder_end + 1
|
||||
end
|
||||
|
||||
return table.concat(output)
|
||||
end
|
||||
|
||||
-- Indexed placeholder processing
|
||||
_G.iparse = function(template_str, values)
|
||||
local pos, output, value_index = 1, {}, 1
|
||||
values = values or {}
|
||||
|
||||
while pos <= #template_str do
|
||||
local unescaped_start, unescaped_end = template_str:find("{{{}}}", pos, true)
|
||||
local escaped_start, escaped_end = template_str:find("{{}}", pos, true)
|
||||
|
||||
local next_pos, placeholder_end, escaped
|
||||
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
|
||||
next_pos, placeholder_end, escaped = unescaped_start, unescaped_end, false
|
||||
elseif escaped_start then
|
||||
next_pos, placeholder_end, escaped = escaped_start, escaped_end, true
|
||||
else
|
||||
local text = template_str:sub(pos)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
local text = template_str:sub(pos, next_pos - 1)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
|
||||
local value = values[value_index]
|
||||
local str = tostring(value or "")
|
||||
if escaped then
|
||||
str = html_special_chars(str)
|
||||
end
|
||||
table.insert(output, str)
|
||||
|
||||
pos = placeholder_end + 1
|
||||
value_index = value_index + 1
|
||||
end
|
||||
|
||||
return table.concat(output)
|
||||
end
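parse and iparse are the lightweight, no-code variants of render: named and positional substitution only, with {{ }} escaping and {{{ }}} passing raw. For example:

```lua
local greeting = parse("Hello, {{ name }}!", { name = "<admin>" })
--> Hello, &lt;admin&gt;!

local cell = iparse("<td>{{}}</td><td>{{{}}}</td>", { "5 < 6", "<b>ok</b>" })
--> <td>5 &lt; 6</td><td><b>ok</b></td>
```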
|
||||
|
||||
-- ======================================================================
|
||||
-- PASSWORD FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Hash a password using Argon2id
|
||||
-- Options:
|
||||
-- memory: Amount of memory to use in KB (default: 128MB)
|
||||
-- iterations: Number of iterations (default: 4)
|
||||
-- parallelism: Number of threads (default: 4)
|
||||
-- salt_length: Length of salt in bytes (default: 16)
|
||||
-- key_length: Length of the derived key in bytes (default: 32)
|
||||
function password_hash(plain_password, options)
|
||||
if type(plain_password) ~= "string" then
|
||||
error("password_hash: expected string password", 2)
|
||||
end
|
||||
|
||||
return __password_hash(plain_password, options)
|
||||
end
|
||||
|
||||
-- Verify a password against a hash
|
||||
function password_verify(plain_password, hash_string)
|
||||
if type(plain_password) ~= "string" then
|
||||
error("password_verify: expected string password", 2)
|
||||
end
|
||||
|
||||
if type(hash_string) ~= "string" then
|
||||
error("password_verify: expected string hash", 2)
|
||||
end
|
||||
|
||||
return __password_verify(plain_password, hash_string)
|
||||
end
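__password_hash and __password_verify are native Argon2id bindings registered by the Go runner, so the Lua side only validates argument types. A hedged usage sketch; the session write is hypothetical application logic:

```lua
local hash = password_hash("correct horse battery staple", {
    memory = 64 * 1024,   -- KB
    iterations = 3,
})

if password_verify(ctx.form.password, hash) then
    session_set("user_id", 42)    -- hypothetical application logic
else
    http_set_status(401)
end
```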
|
||||
|
||||
-- ======================================================================
|
||||
-- SEND FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function send_html(content)
|
||||
http_set_content_type("text/html")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_json(content)
|
||||
http_set_content_type("application/json")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_text(content)
|
||||
http_set_content_type("text/plain")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_xml(content)
|
||||
http_set_content_type("application/xml")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_javascript(content)
|
||||
http_set_content_type("application/javascript")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_css(content)
|
||||
http_set_content_type("text/css")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_svg(content)
|
||||
http_set_content_type("image/svg+xml")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_csv(content)
|
||||
http_set_content_type("text/csv")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_binary(content, mime_type)
|
||||
http_set_content_type(mime_type or "application/octet-stream")
|
||||
return content
|
||||
end
|
||||
297
runner/lua/sqlite.lua
Normal file
@ -0,0 +1,297 @@
|
||||
local function normalize_params(params, ...)
|
||||
if type(params) == "table" then return params end
|
||||
local args = {...}
|
||||
if #args > 0 or params ~= nil then
|
||||
table.insert(args, 1, params)
|
||||
return args
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
local connection_mt = {
|
||||
__index = {
|
||||
query = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:query: query must be a string", 2)
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_query(self.db_name, query, normalized_params)
|
||||
end,
|
||||
|
||||
exec = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:exec: query must be a string", 2)
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_exec(self.db_name, query, normalized_params)
|
||||
end,
|
||||
|
||||
get_one = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:get_one: query must be a string", 2)
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_get_one(self.db_name, query, normalized_params)
|
||||
end,
|
||||
|
||||
insert = function(self, table_name, data, columns)
|
||||
if type(data) ~= "table" then
|
||||
error("connection:insert: data must be a table", 2)
|
||||
end
|
||||
|
||||
-- Single object: {col1=val1, col2=val2}
|
||||
if data[1] == nil and next(data) ~= nil then
|
||||
local cols = table.keys(data)
|
||||
local placeholders = table.map(cols, function(_, i) return ":p" .. i end)
|
||||
local params = {}
|
||||
for i, col in ipairs(cols) do
|
||||
params["p" .. i] = data[col]
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES (%s)",
|
||||
table_name,
|
||||
table.concat(cols, ", "),
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
end
|
||||
|
||||
-- Array data with columns
|
||||
if columns and type(columns) == "table" then
|
||||
if #data > 0 and type(data[1]) == "table" then
|
||||
-- Multiple rows
|
||||
local value_groups = {}
|
||||
local params = {}
|
||||
local param_idx = 1
|
||||
|
||||
for _, row in ipairs(data) do
|
||||
local row_placeholders = {}
|
||||
for j = 1, #columns do
|
||||
local param_name = "p" .. param_idx
|
||||
table.insert(row_placeholders, ":" .. param_name)
|
||||
params[param_name] = row[j]
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES %s",
|
||||
table_name,
|
||||
table.concat(columns, ", "),
|
||||
table.concat(value_groups, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
else
|
||||
-- Single row array
|
||||
local placeholders = table.map(columns, function(_, i) return ":p" .. i end)
|
||||
local params = {}
|
||||
for i = 1, #columns do
|
||||
params["p" .. i] = data[i]
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES (%s)",
|
||||
table_name,
|
||||
table.concat(columns, ", "),
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
end
|
||||
end
|
||||
|
||||
-- Array of objects
|
||||
if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
|
||||
local cols = table.keys(data[1])
|
||||
local value_groups = {}
|
||||
local params = {}
|
||||
local param_idx = 1
|
||||
|
||||
for _, row in ipairs(data) do
|
||||
local row_placeholders = {}
|
||||
for _, col in ipairs(cols) do
|
||||
local param_name = "p" .. param_idx
|
||||
table.insert(row_placeholders, ":" .. param_name)
|
||||
params[param_name] = row[col]
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES %s",
|
||||
table_name,
|
||||
table.concat(cols, ", "),
|
||||
table.concat(value_groups, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
end
|
||||
|
||||
error("connection:insert: invalid data format", 2)
|
||||
end,
|
||||
|
||||
update = function(self, table_name, data, where, where_params, ...)
|
||||
if type(data) ~= "table" or next(data) == nil then
|
||||
return 0
|
||||
end
|
||||
|
||||
local sets = {}
|
||||
local params = {}
|
||||
local param_idx = 1
|
||||
|
||||
for col, val in pairs(data) do
|
||||
local param_name = "p" .. param_idx
|
||||
table.insert(sets, col .. " = :" .. param_name)
|
||||
params[param_name] = val
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
|
||||
local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
|
||||
|
||||
if where then
|
||||
query = query .. " WHERE " .. where
|
||||
if where_params then
|
||||
local normalized = normalize_params(where_params, ...)
|
||||
if type(normalized) == "table" then
|
||||
for k, v in pairs(normalized) do
|
||||
if type(k) == "string" then
|
||||
params[k] = v
|
||||
else
|
||||
params["w" .. param_idx] = v
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return self:exec(query, params)
|
||||
end,
|
||||
|
||||
create_table = function(self, table_name, ...)
|
||||
local column_definitions = {}
|
||||
local index_definitions = {}
|
||||
|
||||
for _, def_string in ipairs({...}) do
|
||||
if type(def_string) == "string" then
|
||||
local is_unique = false
|
||||
local index_def = def_string
|
||||
|
||||
if string.starts_with(def_string, "UNIQUE INDEX:") then
|
||||
is_unique = true
|
||||
index_def = string.trim(def_string:sub(14))
|
||||
elseif string.starts_with(def_string, "INDEX:") then
|
||||
index_def = string.trim(def_string:sub(7))
|
||||
else
|
||||
table.insert(column_definitions, def_string)
|
||||
goto continue
|
||||
end
|
||||
|
||||
local paren_pos = index_def:find("%(")
|
||||
if not paren_pos then goto continue end
|
||||
|
||||
local index_name = string.trim(index_def:sub(1, paren_pos - 1))
|
||||
local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$")
|
||||
if not columns_part then goto continue end
|
||||
|
||||
local columns = table.map(string.split(columns_part, ","), string.trim)
|
||||
|
||||
if #columns > 0 then
|
||||
table.insert(index_definitions, {
|
||||
name = index_name,
|
||||
columns = columns,
|
||||
unique = is_unique
|
||||
})
|
||||
end
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
|
||||
if #column_definitions == 0 then
|
||||
error("connection:create_table: no column definitions specified for table " .. table_name, 2)
|
||||
end
|
||||
|
||||
local statements = {}
|
||||
|
||||
table.insert(statements, string.format(
|
||||
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||
table_name,
|
||||
table.concat(column_definitions, ", ")
|
||||
))
|
||||
|
||||
for _, idx in ipairs(index_definitions) do
|
||||
local unique_prefix = idx.unique and "UNIQUE " or ""
|
||||
table.insert(statements, string.format(
|
||||
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
unique_prefix,
|
||||
idx.name,
|
||||
table_name,
|
||||
table.concat(idx.columns, ", ")
|
||||
))
|
||||
end
|
||||
|
||||
return self:exec(table.concat(statements, ";\n"))
|
||||
end,
|
||||
|
||||
delete = function(self, table_name, where, params, ...)
|
||||
local query = "DELETE FROM " .. table_name
|
||||
if where then
|
||||
query = query .. " WHERE " .. where
|
||||
end
|
||||
return self:exec(query, normalize_params(params, ...))
|
||||
end,
|
||||
|
||||
exists = function(self, table_name, where, params, ...)
|
||||
if type(table_name) ~= "string" then
|
||||
error("connection:exists: table_name must be a string", 2)
|
||||
end
|
||||
|
||||
local query = "SELECT 1 FROM " .. table_name
|
||||
if where then
|
||||
query = query .. " WHERE " .. where
|
||||
end
|
||||
query = query .. " LIMIT 1"
|
||||
|
||||
local results = self:query(query, normalize_params(params, ...))
|
||||
return #results > 0
|
||||
end,
|
||||
|
||||
begin = function(self)
|
||||
return self:exec("BEGIN TRANSACTION")
|
||||
end,
|
||||
|
||||
commit = function(self)
|
||||
return self:exec("COMMIT")
|
||||
end,
|
||||
|
||||
rollback = function(self)
|
||||
return self:exec("ROLLBACK")
|
||||
end,
|
||||
|
||||
transaction = function(self, callback)
|
||||
self:begin()
|
||||
local success, result = pcall(callback, self)
|
||||
if success then
|
||||
self:commit()
|
||||
return result
|
||||
else
|
||||
self:rollback()
|
||||
error(result, 2)
|
||||
end
|
||||
end
|
||||
}
|
||||
}
|
||||
|
||||
return function(db_name)
|
||||
if type(db_name) ~= "string" then
|
||||
error("sqlite: database name must be a string", 2)
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
db_name = db_name
|
||||
}, connection_mt)
|
||||
end
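How the connection wrapper might be used from a route script. The require name and the "app" database name are assumptions; the actual wiring of the __sqlite_* natives and module paths lives on the Go side. Placeholders follow the :pN style that connection:insert generates.

```lua
local sqlite = require("sqlite")          -- require name is an assumption
local db = sqlite("app")                  -- must match a configured database

db:create_table("users",
    "id INTEGER PRIMARY KEY",
    "email TEXT NOT NULL",
    "name TEXT",
    "UNIQUE INDEX: users_email (email)"
)

db:insert("users", { email = "ada@example.com", name = "Ada" })

local user = db:get_one("SELECT * FROM users WHERE email = :p1",
                        { p1 = "ada@example.com" })

db:transaction(function(conn)
    conn:update("users", { name = "Ada L." },
                "email = :who", { who = "ada@example.com" })
end)
```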
|
||||
197
runner/lua/string.lua
Normal file
@ -0,0 +1,197 @@
|
||||
--[[
|
||||
string.lua - Extended string library functions
|
||||
]]--
|
||||
|
||||
local string_ext = {}
|
||||
|
||||
-- ======================================================================
|
||||
-- STRING UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Trim whitespace from both ends
|
||||
function string_ext.trim(s)
|
||||
if type(s) ~= "string" then return s end
|
||||
return s:match("^%s*(.-)%s*$")
|
||||
end
|
||||
|
||||
-- Split string by delimiter
|
||||
function string_ext.split(s, delimiter)
|
||||
if type(s) ~= "string" then return {} end
|
||||
|
||||
delimiter = delimiter or ","
|
||||
local result = {}
|
||||
for match in (s..delimiter):gmatch("(.-)"..delimiter) do
|
||||
table.insert(result, match)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Check if string starts with prefix
|
||||
function string_ext.starts_with(s, prefix)
|
||||
if type(s) ~= "string" or type(prefix) ~= "string" then return false end
|
||||
return s:sub(1, #prefix) == prefix
|
||||
end
|
||||
|
||||
-- Check if string ends with suffix
|
||||
function string_ext.ends_with(s, suffix)
|
||||
if type(s) ~= "string" or type(suffix) ~= "string" then return false end
|
||||
return suffix == "" or s:sub(-#suffix) == suffix
|
||||
end
|
||||
|
||||
-- Left pad a string
|
||||
function string_ext.pad_left(s, len, char)
|
||||
if type(s) ~= "string" or type(len) ~= "number" then return s end
|
||||
|
||||
char = char or " "
|
||||
if #s >= len then return s end
|
||||
|
||||
return string.rep(char:sub(1,1), len - #s) .. s
|
||||
end
|
||||
|
||||
-- Right pad a string
|
||||
function string_ext.pad_right(s, len, char)
|
||||
if type(s) ~= "string" or type(len) ~= "number" then return s end
|
||||
|
||||
char = char or " "
|
||||
if #s >= len then return s end
|
||||
|
||||
return s .. string.rep(char:sub(1,1), len - #s)
|
||||
end
|
||||
|
||||
-- Center a string
|
||||
function string_ext.center(s, width, char)
|
||||
if type(s) ~= "string" or width <= #s then return s end
|
||||
|
||||
char = char or " "
|
||||
local pad_len = width - #s
|
||||
local left_pad = math.floor(pad_len / 2)
|
||||
local right_pad = pad_len - left_pad
|
||||
|
||||
return string.rep(char:sub(1,1), left_pad) .. s .. string.rep(char:sub(1,1), right_pad)
|
||||
end
|
||||
|
||||
-- Count occurrences of substring
|
||||
function string_ext.count(s, substr)
|
||||
if type(s) ~= "string" or type(substr) ~= "string" or #substr == 0 then return 0 end
|
||||
|
||||
local count, pos = 0, 1
|
||||
while true do
|
||||
pos = s:find(substr, pos, true)
|
||||
if not pos then break end
|
||||
count = count + 1
|
||||
pos = pos + 1
|
||||
end
|
||||
return count
|
||||
end
|
||||
|
||||
-- Capitalize first letter
|
||||
function string_ext.capitalize(s)
|
||||
if type(s) ~= "string" or #s == 0 then return s end
|
||||
return s:sub(1,1):upper() .. s:sub(2)
|
||||
end
|
||||
|
||||
-- Capitalize all words
|
||||
function string_ext.title(s)
|
||||
if type(s) ~= "string" then return s end
|
||||
|
||||
return s:gsub("(%w)([%w]*)", function(first, rest)
|
||||
return first:upper() .. rest:lower()
|
||||
end)
|
||||
end
|
||||
|
||||
-- Insert string at position
|
||||
function string_ext.insert(s, pos, insert_str)
|
||||
if type(s) ~= "string" or type(insert_str) ~= "string" then return s end
|
||||
|
||||
pos = math.max(1, math.min(pos, #s + 1))
|
||||
return s:sub(1, pos - 1) .. insert_str .. s:sub(pos)
|
||||
end
|
||||
|
||||
-- Remove substring
|
||||
function string_ext.remove(s, start, length)
|
||||
if type(s) ~= "string" then return s end
|
||||
|
||||
length = length or 1
|
||||
if start < 1 or start > #s then return s end
|
||||
|
||||
return s:sub(1, start - 1) .. s:sub(start + length)
|
||||
end
|
||||
|
||||
-- Replace substring once
|
||||
function string_ext.replace(s, old, new, n)
|
||||
if type(s) ~= "string" or type(old) ~= "string" or #old == 0 then return s end
|
||||
|
||||
new = new or ""
|
||||
n = n or 1
|
||||
|
||||
return s:gsub(old:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1"), new, n)
|
||||
end
|
||||
|
||||
-- Check if string contains substring
|
||||
function string_ext.contains(s, substr)
|
||||
if type(s) ~= "string" or type(substr) ~= "string" then return false end
|
||||
return s:find(substr, 1, true) ~= nil
|
||||
end
|
||||
|
||||
-- Escape pattern magic characters
|
||||
function string_ext.escape_pattern(s)
|
||||
if type(s) ~= "string" then return s end
|
||||
return s:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1")
|
||||
end
|
||||
|
||||
-- Wrap text at specified width
|
||||
function string_ext.wrap(s, width, indent_first, indent_rest)
|
||||
if type(s) ~= "string" or type(width) ~= "number" then return s end
|
||||
|
||||
width = math.max(1, width)
|
||||
indent_first = indent_first or ""
|
||||
indent_rest = indent_rest or indent_first
|
||||
|
||||
local result = {}
|
||||
local line_prefix = indent_first
|
||||
local pos = 1
|
||||
|
||||
while pos <= #s do
|
||||
local line_width = width - #line_prefix
|
||||
local end_pos = math.min(pos + line_width - 1, #s)
|
||||
|
||||
if end_pos < #s then
|
||||
local last_space = s:sub(pos, end_pos):match(".*%s()")
|
||||
if last_space then
|
||||
end_pos = pos + last_space - 2
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(result, line_prefix .. s:sub(pos, end_pos))
|
||||
pos = end_pos + 1
|
||||
|
||||
-- Skip leading spaces on next line
|
||||
while s:sub(pos, pos) == " " do
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
line_prefix = indent_rest
|
||||
end
|
||||
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
-- Limit string length with ellipsis
|
||||
function string_ext.truncate(s, length, ellipsis)
|
||||
if type(s) ~= "string" then return s end
|
||||
|
||||
ellipsis = ellipsis or "..."
|
||||
if #s <= length then return s end
|
||||
|
||||
return s:sub(1, length - #ellipsis) .. ellipsis
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- INSTALL EXTENSIONS INTO STRING LIBRARY
|
||||
-- ======================================================================
|
||||
|
||||
for name, func in pairs(string) do
|
||||
string_ext[name] = func
|
||||
end
|
||||
|
||||
return string_ext
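A short sketch of the string extensions. It assumes the runner installs this module in place of the stock string library, which is what copying the standard functions into string_ext is for.

```lua
local s = "  Moonshark routes requests to Lua  "
print(string.trim(s))                              --> "Moonshark routes requests to Lua"

local parts = string.split("a,b,c", ",")           -- { "a", "b", "c" }
print(string.starts_with("config.lua", "config"))  --> true
print(string.truncate("a very long headline", 10)) --> "a very ..."
print(string.count("banana", "an"))                --> 2
```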
|
||||
1092
runner/lua/table.lua
Normal file
File diff suppressed because it is too large
130
runner/lua/time.lua
Normal file
@ -0,0 +1,130 @@
|
||||
--[[
|
||||
time.lua - High performance timing functions
|
||||
]]--
|
||||
|
||||
local ffi = require('ffi')
|
||||
local is_windows = (ffi.os == "Windows")
|
||||
|
||||
-- Define C structures and functions based on platform
|
||||
if is_windows then
|
||||
ffi.cdef[[
|
||||
typedef struct {
|
||||
int64_t QuadPart;
|
||||
} LARGE_INTEGER;
|
||||
int QueryPerformanceCounter(LARGE_INTEGER* lpPerformanceCount);
|
||||
int QueryPerformanceFrequency(LARGE_INTEGER* lpFrequency);
|
||||
]]
|
||||
else
|
||||
ffi.cdef[[
|
||||
typedef long time_t;
|
||||
typedef struct timeval {
|
||||
long tv_sec;
|
||||
long tv_usec;
|
||||
} timeval;
|
||||
int gettimeofday(struct timeval* tv, void* tz);
|
||||
time_t time(time_t* t);
|
||||
]]
|
||||
end
|
||||
|
||||
local time = {}
|
||||
local has_initialized = false
|
||||
local start_time, timer_freq
|
||||
|
||||
-- Initialize timing system based on platform
|
||||
local function init()
|
||||
if has_initialized then return end
|
||||
|
||||
if ffi.os == "Windows" then
|
||||
local frequency = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceFrequency(frequency)
|
||||
timer_freq = tonumber(frequency.QuadPart)
|
||||
|
||||
local counter = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceCounter(counter)
|
||||
start_time = tonumber(counter.QuadPart)
|
||||
else
|
||||
-- Nothing special needed for Unix platform init
|
||||
start_time = tonumber(ffi.C.time(nil))
|
||||
end
|
||||
|
||||
has_initialized = true
|
||||
end
|
||||
|
||||
-- PHP-compatible microtime implementation
|
||||
function time.microtime(get_as_float)
|
||||
init()
|
||||
|
||||
if ffi.os == "Windows" then
|
||||
local counter = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceCounter(counter)
|
||||
local now = tonumber(counter.QuadPart)
|
||||
local seconds = math.floor((now - start_time) / timer_freq)
|
||||
local microseconds = ((now - start_time) % timer_freq) * 1000000 / timer_freq
|
||||
|
||||
if get_as_float then
|
||||
return seconds + microseconds / 1000000
|
||||
else
|
||||
return string.format("0.%06d %d", microseconds, seconds)
|
||||
end
|
||||
else
|
||||
local tv = ffi.new("struct timeval")
|
||||
ffi.C.gettimeofday(tv, nil)
|
||||
|
||||
if get_as_float then
|
||||
return tonumber(tv.tv_sec) + tonumber(tv.tv_usec) / 1000000
|
||||
else
|
||||
return string.format("0.%06d %d", tv.tv_usec, tv.tv_sec)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- High-precision monotonic timer (returns seconds with microsecond precision)
|
||||
function time.monotonic()
|
||||
init()
|
||||
|
||||
if ffi.os == "Windows" then
|
||||
local counter = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceCounter(counter)
|
||||
local now = tonumber(counter.QuadPart)
|
||||
return (now - start_time) / timer_freq
|
||||
else
|
||||
local tv = ffi.new("struct timeval")
|
||||
ffi.C.gettimeofday(tv, nil)
|
||||
return tonumber(tv.tv_sec) - start_time + tonumber(tv.tv_usec) / 1000000
|
||||
end
|
||||
end
|
||||
|
||||
-- Benchmark function that measures execution time
|
||||
function time.benchmark(func, iterations, warmup)
|
||||
iterations = iterations or 1000
|
||||
warmup = warmup or 10
|
||||
|
||||
-- Warmup
|
||||
for i=1, warmup do func() end
|
||||
|
||||
local start = time.microtime(true)
|
||||
for i=1, iterations do
|
||||
func()
|
||||
end
|
||||
local finish = time.microtime(true)
|
||||
|
||||
local elapsed = (finish - start) * 1000000 -- Convert to microseconds
|
||||
return elapsed / iterations
|
||||
end
|
||||
|
||||
-- Simple sleep function using coroutine yielding
|
||||
function time.sleep(seconds)
|
||||
if type(seconds) ~= "number" or seconds <= 0 then
|
||||
return
|
||||
end
|
||||
|
||||
local start = time.monotonic()
|
||||
while time.monotonic() - start < seconds do
|
||||
-- Use coroutine.yield to avoid consuming CPU
|
||||
coroutine.yield()
|
||||
end
|
||||
end
|
||||
|
||||
_G.microtime = time.microtime
|
||||
|
||||
return time
|
||||
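A short sketch of the timing API above; the require name is an assumption, while microtime is also installed as the global microtime by the assignment above. Note that on Windows microtime is relative to process start, per the QueryPerformanceCounter branch.

local time = require("time")            -- assumed require name
local stamp = time.microtime(true)      -- PHP-style float timestamp
local t0 = time.monotonic()             -- monotonic seconds
local avg_us = time.benchmark(function()
  local x = 0
  for i = 1, 100 do x = x + i end
end, 1000, 10)                          -- 1000 iterations after 10 warmup runs
print(stamp, time.monotonic() - t0, avg_us)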
293
runner/lua/util.lua
Normal file
@ -0,0 +1,293 @@
|
||||
--[[
|
||||
util.lua - Utility functions for the Lua sandbox
|
||||
]]--
|
||||
|
||||
-- ======================================================================
|
||||
-- CORE UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate a random token
|
||||
function generate_token(length)
|
||||
return __generate_token(length or 32)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- HTML ENTITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Convert special characters to HTML entities (like htmlspecialchars)
|
||||
function html_special_chars(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __html_special_chars(str)
|
||||
end
|
||||
|
||||
-- Convert all applicable characters to HTML entities (like htmlentities)
|
||||
function html_entities(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __html_entities(str)
|
||||
end
|
||||
|
||||
-- Convert HTML entities back to characters (simple version)
|
||||
function html_entity_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
str = str:gsub("&lt;", "<")
|
||||
str = str:gsub("&gt;", ">")
|
||||
str = str:gsub("&quot;", '"')
|
||||
str = str:gsub("&#39;", "'")
|
||||
str = str:gsub("&amp;", "&")
|
||||
|
||||
return str
|
||||
end
|
||||
|
||||
-- Convert newlines to <br> tags
|
||||
function nl2br(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return str:gsub("\r\n", "<br>"):gsub("\n", "<br>"):gsub("\r", "<br>")
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- URL FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- URL encode a string
|
||||
function url_encode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
str = str:gsub("\n", "\r\n")
|
||||
str = str:gsub("([^%w %-%_%.%~])", function(c)
|
||||
return string.format("%%%02X", string.byte(c))
|
||||
end)
|
||||
str = str:gsub(" ", "+")
|
||||
return str
|
||||
end
|
||||
|
||||
-- URL decode a string
|
||||
function url_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
str = str:gsub("+", " ")
|
||||
str = str:gsub("%%(%x%x)", function(h)
|
||||
return string.char(tonumber(h, 16))
|
||||
end)
|
||||
return str
|
||||
end
|
||||
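A round-trip sketch for the encoding helpers above; these are defined as plain globals, and html_special_chars delegates to the Go-registered __html_special_chars binding.

local enc = url_encode("name=Lua & Go?")   -- "name%3DLua+%26+Go%3F"
local dec = url_decode(enc)                -- "name=Lua & Go?"
local html = nl2br(html_entity_decode("&lt;b&gt;hi&lt;/b&gt;\nthere"))  -- "<b>hi</b><br>there"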
|
||||
-- ======================================================================
|
||||
-- VALIDATION FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Email validation
|
||||
function is_email(str)
|
||||
if type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Simple email validation pattern
|
||||
local pattern = "^[%w%.%%%+%-]+@[%w%.%%%+%-]+%.%w%w%w?%w?$"
|
||||
return str:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
-- URL validation
|
||||
function is_url(str)
|
||||
if type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Simple URL validation
|
||||
local pattern = "^https?://[%w-_%.%?%.:/%+=&%%]+$"
|
||||
return str:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
-- IP address validation (IPv4)
|
||||
function is_ipv4(str)
|
||||
if type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
local pattern = "^(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)$"
|
||||
local a, b, c, d = str:match(pattern)
|
||||
|
||||
if not (a and b and c and d) then
|
||||
return false
|
||||
end
|
||||
|
||||
a, b, c, d = tonumber(a), tonumber(b), tonumber(c), tonumber(d)
|
||||
return a <= 255 and b <= 255 and c <= 255 and d <= 255
|
||||
end
|
||||
|
||||
-- Integer validation
|
||||
function is_int(str)
|
||||
if type(str) == "number" then
|
||||
return math.floor(str) == str
|
||||
elseif type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
return str:match("^-?%d+$") ~= nil
|
||||
end
|
||||
|
||||
-- Float validation
|
||||
function is_float(str)
|
||||
if type(str) == "number" then
|
||||
return true
|
||||
elseif type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
return str:match("^-?%d+%.?%d*$") ~= nil
|
||||
end
|
||||
|
||||
-- Boolean validation
|
||||
function is_bool(value)
|
||||
if type(value) == "boolean" then
|
||||
return true
|
||||
elseif type(value) ~= "string" and type(value) ~= "number" then
|
||||
return false
|
||||
end
|
||||
|
||||
local v = type(value) == "string" and value:lower() or value
|
||||
return v == "1" or v == "true" or v == "on" or v == "yes" or
|
||||
v == "0" or v == "false" or v == "off" or v == "no" or
|
||||
v == 1 or v == 0
|
||||
end
|
||||
|
||||
-- Convert to boolean
|
||||
function to_bool(value)
|
||||
if type(value) == "boolean" then
|
||||
return value
|
||||
elseif type(value) ~= "string" and type(value) ~= "number" then
|
||||
return false
|
||||
end
|
||||
|
||||
local v = type(value) == "string" and value:lower() or value
|
||||
return v == "1" or v == "true" or v == "on" or v == "yes" or v == 1
|
||||
end
|
||||
|
||||
-- Sanitize string (simple version)
|
||||
function sanitize_string(str)
|
||||
if type(str) ~= "string" then
|
||||
return ""
|
||||
end
|
||||
|
||||
return html_special_chars(str)
|
||||
end
|
||||
|
||||
-- Sanitize to integer
|
||||
function sanitize_int(value)
|
||||
if type(value) ~= "string" and type(value) ~= "number" then
|
||||
return 0
|
||||
end
|
||||
|
||||
value = tostring(value)
|
||||
local result = value:match("^-?%d+")
|
||||
return result and tonumber(result) or 0
|
||||
end
|
||||
|
||||
-- Sanitize to float
|
||||
function sanitize_float(value)
|
||||
if type(value) ~= "string" and type(value) ~= "number" then
|
||||
return 0
|
||||
end
|
||||
|
||||
value = tostring(value)
|
||||
local result = value:match("^-?%d+%.?%d*")
|
||||
return result and tonumber(result) or 0
|
||||
end
|
||||
|
||||
-- Sanitize URL
|
||||
function sanitize_url(str)
|
||||
if type(str) ~= "string" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Basic sanitization by removing control characters
|
||||
str = str:gsub("[\000-\031]", "")
|
||||
|
||||
-- Make sure it's a valid URL
|
||||
if is_url(str) then
|
||||
return str
|
||||
end
|
||||
|
||||
-- Try to prepend http:// if it's missing
|
||||
if not str:match("^https?://") and is_url("http://" .. str) then
|
||||
return "http://" .. str
|
||||
end
|
||||
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Sanitize email
|
||||
function sanitize_email(str)
|
||||
if type(str) ~= "string" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Remove all characters except common email characters
|
||||
str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "")
|
||||
|
||||
-- Return only if it's a valid email
|
||||
if is_email(str) then
|
||||
return str
|
||||
end
|
||||
|
||||
return ""
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- SECURITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Basic XSS prevention
|
||||
function xss_clean(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
-- Convert problematic characters to entities
|
||||
local result = html_special_chars(str)
|
||||
|
||||
-- Remove JavaScript event handlers
|
||||
result = result:gsub("on%w+%s*=", "")
|
||||
|
||||
-- Remove JavaScript protocol
|
||||
result = result:gsub("javascript:", "")
|
||||
|
||||
-- Remove CSS expression
|
||||
result = result:gsub("expression%s*%(", "")
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- Base64 encode
|
||||
function base64_encode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __base64_encode(str)
|
||||
end
|
||||
|
||||
-- Base64 decode
|
||||
function base64_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __base64_decode(str)
|
||||
end
|
||||
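A sketch exercising the validation and sanitization globals defined above; base64_encode requires the Go-registered __base64_encode binding to be present in the state.

assert(is_email("dev@sharkk.net"))
assert(is_ipv4("127.0.0.1") and not is_ipv4("999.1.1.1"))
assert(is_int("-42") and is_float("3.14"))
assert(to_bool("yes") and not to_bool("off"))
print(sanitize_int("37signals"))    -- 37
print(sanitize_float("9.99 USD"))   -- 9.99
print(base64_encode("moonshark"))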
430
runner/moduleLoader.go
Normal file
@ -0,0 +1,430 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// ModuleConfig holds configuration for Lua's module loading system
|
||||
type ModuleConfig struct {
|
||||
ScriptDir string // Base directory for script being executed
|
||||
LibDirs []string // Additional library directories
|
||||
}
|
||||
|
||||
// ModuleLoader manages module loading and caching
|
||||
type ModuleLoader struct {
|
||||
config *ModuleConfig
|
||||
pathCache map[string]string // Cache module paths for fast lookups
|
||||
bytecodeCache map[string][]byte // Cache of compiled bytecode
|
||||
debug bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewModuleLoader creates a new module loader
|
||||
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
||||
if config == nil {
|
||||
config = &ModuleConfig{
|
||||
ScriptDir: "",
|
||||
LibDirs: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
return &ModuleLoader{
|
||||
config: config,
|
||||
pathCache: make(map[string]string),
|
||||
bytecodeCache: make(map[string][]byte),
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
// EnableDebug turns on debug logging
|
||||
func (l *ModuleLoader) EnableDebug() {
|
||||
l.debug = true
|
||||
}
|
||||
|
||||
// SetScriptDir sets the script directory
|
||||
func (l *ModuleLoader) SetScriptDir(dir string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.config.ScriptDir = dir
|
||||
}
|
||||
|
||||
// debugLog logs a message if debug mode is enabled
|
||||
func (l *ModuleLoader) debugLog(format string, args ...interface{}) {
|
||||
if l.debug {
|
||||
logger.Debug("ModuleLoader "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupRequire configures the require system in a Lua state
|
||||
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
// Initialize our module registry in Lua
|
||||
err := state.DoString(`
|
||||
-- Initialize global module registry
|
||||
__module_paths = {}
|
||||
__module_bytecode = {}
|
||||
__ready_modules = {}
|
||||
|
||||
-- Create module preload table
|
||||
package.preload = package.preload or {}
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up package.path based on search paths
|
||||
paths := l.getSearchPaths()
|
||||
pathStr := strings.Join(paths, ";")
|
||||
escapedPathStr := escapeLuaString(pathStr)
|
||||
|
||||
return state.DoString(`package.path = "` + escapedPathStr + `"`)
|
||||
}
|
||||
|
||||
// getSearchPaths returns a list of Lua search paths
|
||||
func (l *ModuleLoader) getSearchPaths() []string {
|
||||
absPaths := []string{}
|
||||
seen := map[string]bool{}
|
||||
|
||||
// Add script directory (highest priority)
|
||||
if l.config.ScriptDir != "" {
|
||||
absPath, err := filepath.Abs(l.config.ScriptDir)
|
||||
if err == nil && !seen[absPath] {
|
||||
absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
|
||||
seen[absPath] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Add lib directories
|
||||
for _, dir := range l.config.LibDirs {
|
||||
if dir == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(dir)
|
||||
if err == nil && !seen[absPath] {
|
||||
absPaths = append(absPaths, filepath.Join(absPath, "?.lua"))
|
||||
seen[absPath] = true
|
||||
}
|
||||
}
|
||||
|
||||
return absPaths
|
||||
}
|
||||
|
||||
// PreloadModules preloads modules from library directories
|
||||
func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Reset caches
|
||||
l.pathCache = make(map[string]string)
|
||||
l.bytecodeCache = make(map[string][]byte)
|
||||
|
||||
// Reset module registry in Lua
|
||||
if err := state.DoString(`
|
||||
-- Reset module registry
|
||||
__module_paths = {}
|
||||
__module_bytecode = {}
|
||||
__ready_modules = {}
|
||||
|
||||
-- Clear non-core modules from package.loaded
|
||||
local core_modules = {
|
||||
string = true, table = true, math = true, os = true,
|
||||
package = true, io = true, coroutine = true, debug = true, _G = true
|
||||
}
|
||||
|
||||
for name in pairs(package.loaded) do
|
||||
if not core_modules[name] then
|
||||
package.loaded[name] = nil
|
||||
end
|
||||
end
|
||||
|
||||
-- Reset preload table
|
||||
package.preload = {}
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Scan and preload modules from all library directories
|
||||
for _, dir := range l.config.LibDirs {
|
||||
if dir == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
l.debugLog("Scanning directory: %s", absDir)
|
||||
|
||||
// Find all Lua files
|
||||
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get module name from path
|
||||
relPath, err := filepath.Rel(absDir, path)
|
||||
if err != nil || strings.HasPrefix(relPath, "..") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert path to module name
|
||||
modName := strings.TrimSuffix(relPath, ".lua")
|
||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||
|
||||
l.debugLog("Found module: %s at %s", modName, path)
|
||||
|
||||
// Register in our caches
|
||||
l.pathCache[modName] = path
|
||||
|
||||
// Load file content
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
l.debugLog("Failed to read module file: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile to bytecode
|
||||
bytecode, err := state.CompileBytecode(string(content), path)
|
||||
if err != nil {
|
||||
l.debugLog("Failed to compile module: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cache bytecode
|
||||
l.bytecodeCache[modName] = bytecode
|
||||
|
||||
// Register in Lua - store path info
|
||||
escapedPath := escapeLuaString(path)
|
||||
escapedName := escapeLuaString(modName)
|
||||
|
||||
if err := state.DoString(`__module_paths["` + escapedName + `"] = "` + escapedPath + `"`); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load bytecode and register in package.preload properly
|
||||
if err := state.LoadBytecode(bytecode, path); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store the function in package.preload - the function is on the stack
|
||||
state.GetGlobal("package")
|
||||
state.GetField(-1, "preload")
|
||||
state.PushString(modName)
|
||||
state.PushCopy(-4) // Copy the compiled function
|
||||
state.SetTable(-3) // preload[modName] = function
|
||||
state.Pop(2) // Pop package and preload tables
|
||||
|
||||
// Mark as ready
|
||||
if err := state.DoString(`__ready_modules["` + escapedName + `"] = true`); err != nil {
|
||||
state.Pop(1) // Remove the function from stack
|
||||
return nil
|
||||
}
|
||||
|
||||
state.Pop(1) // Remove the function from stack
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Install optimized require implementation
|
||||
return state.DoString(`
|
||||
-- Setup environment-aware require function
|
||||
function __setup_require(env)
|
||||
-- Create require function specific to this environment
|
||||
env.require = function(modname)
|
||||
-- Check if already loaded
|
||||
if package.loaded[modname] then
|
||||
return package.loaded[modname]
|
||||
end
|
||||
|
||||
-- Check preloaded modules
|
||||
if __ready_modules[modname] then
|
||||
local loader = package.preload[modname]
|
||||
if loader then
|
||||
-- Set environment for loader
|
||||
setfenv(loader, env)
|
||||
|
||||
-- Execute and store result
|
||||
local result = loader()
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
-- Direct file load as fallback
|
||||
if __module_paths[modname] then
|
||||
local path = __module_paths[modname]
|
||||
local chunk, err = loadfile(path)
|
||||
if chunk then
|
||||
setfenv(chunk, env)
|
||||
local result = chunk()
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
-- Full path search as last resort
|
||||
local errors = {}
|
||||
for path in package.path:gmatch("[^;]+") do
|
||||
local file_path = path:gsub("?", modname:gsub("%.", "/"))
|
||||
local chunk, err = loadfile(file_path)
|
||||
if chunk then
|
||||
setfenv(chunk, env)
|
||||
local result = chunk()
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
table.insert(errors, "\tno file '" .. file_path .. "'")
|
||||
end
|
||||
|
||||
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
|
||||
end
|
||||
|
||||
return env
|
||||
end
|
||||
`)
|
||||
}
|
||||
|
||||
// GetModuleByPath finds the module name for a file path
|
||||
func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
// Convert to absolute path for consistent comparison
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
absPath = filepath.Clean(path)
|
||||
}
|
||||
|
||||
// Try direct lookup from cache with absolute path
|
||||
for modName, modPath := range l.pathCache {
|
||||
if modPath == absPath {
|
||||
return modName, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try to construct module name from lib dirs
|
||||
for _, dir := range l.config.LibDirs {
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the file is under this lib directory
|
||||
relPath, err := filepath.Rel(absDir, absPath)
|
||||
if err != nil || strings.HasPrefix(relPath, "..") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(relPath, ".lua") {
|
||||
modName := strings.TrimSuffix(relPath, ".lua")
|
||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||
|
||||
l.debugLog("Found module %s for path %s", modName, path)
|
||||
return modName, true
|
||||
}
|
||||
}
|
||||
|
||||
l.debugLog("No module found for path %s", path)
|
||||
return "", false
|
||||
}
|
||||
|
||||
// RefreshModule recompiles and updates a specific module
|
||||
func (l *ModuleLoader) RefreshModule(state *luajit.State, moduleName string) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Get module path
|
||||
path, exists := l.pathCache[moduleName]
|
||||
if !exists {
|
||||
l.debugLog("Module not found in cache: %s", moduleName)
|
||||
return fmt.Errorf("module %s not found", moduleName)
|
||||
}
|
||||
|
||||
l.debugLog("Refreshing module: %s at %s", moduleName, path)
|
||||
|
||||
// Read updated file content
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read module file: %w", err)
|
||||
}
|
||||
|
||||
// Recompile to bytecode
|
||||
bytecode, err := state.CompileBytecode(string(content), path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compile module: %w", err)
|
||||
}
|
||||
|
||||
// Update bytecode cache
|
||||
l.bytecodeCache[moduleName] = bytecode
|
||||
|
||||
// Load new bytecode
|
||||
if err := state.LoadBytecode(bytecode, path); err != nil {
|
||||
return fmt.Errorf("failed to load bytecode: %w", err)
|
||||
}
|
||||
|
||||
// Update package.preload with new function (function is on stack)
|
||||
state.GetGlobal("package")
|
||||
state.GetField(-1, "preload")
|
||||
state.PushString(moduleName)
|
||||
state.PushCopy(-4) // Copy the new compiled function
|
||||
state.SetTable(-3) // preload[moduleName] = new_function
|
||||
state.Pop(2) // Pop package and preload tables
|
||||
state.Pop(1) // Pop the function
|
||||
|
||||
// Clear from package.loaded so it gets reloaded
|
||||
escapedName := escapeLuaString(moduleName)
|
||||
if err := state.DoString(`package.loaded["` + escapedName + `"] = nil`); err != nil {
|
||||
return fmt.Errorf("failed to clear loaded module: %w", err)
|
||||
}
|
||||
|
||||
l.debugLog("Successfully refreshed module: %s", moduleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshModuleByPath refreshes a module by its file path
|
||||
func (l *ModuleLoader) RefreshModuleByPath(state *luajit.State, filePath string) error {
|
||||
moduleName, exists := l.GetModuleByPath(filePath)
|
||||
if !exists {
|
||||
return fmt.Errorf("no module found for path: %s", filePath)
|
||||
}
|
||||
return l.RefreshModule(state, moduleName)
|
||||
}
|
||||
|
||||
// escapeLuaString escapes special characters in a string for Lua
|
||||
func escapeLuaString(s string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
"\\", "\\\\",
|
||||
"\"", "\\\"",
|
||||
"\n", "\\n",
|
||||
"\r", "\\r",
|
||||
"\t", "\\t",
|
||||
)
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
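Given the loader above, a file under a configured libs directory becomes require-able by its dotted relative path. A hypothetical example; the file name and contents are illustrative only:

-- libs/util/greeting.lua
local greeting = {}
function greeting.hello(name)
  return "Hello, " .. tostring(name) .. "!"
end
return greeting

-- In a route script, resolved through package.preload by the loader above:
local greeting = require("util.greeting")
print(greeting.hello("Moonshark"))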
98
runner/password.go
Normal file
@ -0,0 +1,98 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/alexedwards/argon2id"
|
||||
)
|
||||
|
||||
// RegisterPasswordFunctions registers password-related functions in the Lua state
|
||||
func RegisterPasswordFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__password_hash", passwordHash); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__password_verify", passwordVerify); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id
|
||||
func passwordHash(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("password_hash error: expected string password")
|
||||
return 1
|
||||
}
|
||||
|
||||
password := state.ToString(1)
|
||||
|
||||
params := &argon2id.Params{
|
||||
Memory: 128 * 1024,
|
||||
Iterations: 4,
|
||||
Parallelism: 4,
|
||||
SaltLength: 16,
|
||||
KeyLength: 32,
|
||||
}
|
||||
|
||||
if state.IsTable(2) {
|
||||
state.GetField(2, "memory")
|
||||
if state.IsNumber(-1) {
|
||||
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(2, "iterations")
|
||||
if state.IsNumber(-1) {
|
||||
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(2, "parallelism")
|
||||
if state.IsNumber(-1) {
|
||||
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(2, "salt_length")
|
||||
if state.IsNumber(-1) {
|
||||
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(2, "key_length")
|
||||
if state.IsNumber(-1) {
|
||||
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
hash, err := argon2id.CreateHash(password, params)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("password_hash error: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushString(hash)
|
||||
return 1
|
||||
}
|
||||
|
||||
// passwordVerify verifies a password against a hash
|
||||
func passwordVerify(state *luajit.State) int {
|
||||
if !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
password := state.ToString(1)
|
||||
hash := state.ToString(2)
|
||||
|
||||
match, err := argon2id.ComparePasswordAndHash(password, hash)
|
||||
if err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushBoolean(match)
|
||||
return 1
|
||||
}
|
||||
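The functions above are registered as the globals __password_hash and __password_verify; whether a friendlier Lua wrapper exists is not shown in this diff, so this sketch calls them directly. The option names mirror the Go code above.

local hash = __password_hash("s3cret", {
  memory      = 64 * 1024,  -- KiB, clamped to at least 8192 by the Go side
  iterations  = 3,
  parallelism = 2,
})
assert(__password_verify("s3cret", hash))
assert(not __password_verify("wrong", hash))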
56
runner/response.go
Normal file
@ -0,0 +1,56 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Response represents a unified response from script execution
|
||||
type Response struct {
|
||||
// Basic properties
|
||||
Body any // Body content (any type)
|
||||
Metadata map[string]any // Additional metadata
|
||||
|
||||
// HTTP specific properties
|
||||
Status int // HTTP status code
|
||||
Headers map[string]string // HTTP headers
|
||||
Cookies []*fasthttp.Cookie // HTTP cookies
|
||||
|
||||
// Session information
|
||||
SessionData map[string]any
|
||||
}
|
||||
|
||||
// Response pool to reduce allocations
|
||||
var responsePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Response{
|
||||
Status: 200,
|
||||
Headers: make(map[string]string, 8),
|
||||
Metadata: make(map[string]any, 8),
|
||||
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
||||
SessionData: make(map[string]any, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// NewResponse creates a new response object from the pool
|
||||
func NewResponse() *Response {
|
||||
return responsePool.Get().(*Response)
|
||||
}
|
||||
|
||||
// ReleaseResponse returns a response to the pool after cleaning it
|
||||
func ReleaseResponse(resp *Response) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp.Body = nil
|
||||
resp.Status = 200
|
||||
resp.Headers = make(map[string]string, 8)
|
||||
resp.Metadata = make(map[string]any, 8)
|
||||
resp.Cookies = resp.Cookies[:0]
|
||||
resp.SessionData = make(map[string]any, 8)
|
||||
|
||||
responsePool.Put(resp)
|
||||
}
|
||||
566
runner/runner.go
Normal file
@ -0,0 +1,566 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"Moonshark/utils/color"
|
||||
"Moonshark/utils/logger"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrRunnerClosed = errors.New("lua runner is closed")
|
||||
ErrInitFailed = errors.New("initialization failed")
|
||||
ErrStateNotReady = errors.New("lua state not ready")
|
||||
ErrTimeout = errors.New("operation timed out")
|
||||
)
|
||||
|
||||
// RunnerOption defines a functional option for configuring the Runner
|
||||
type RunnerOption func(*Runner)
|
||||
|
||||
// State wraps a Lua state with its sandbox
|
||||
type State struct {
|
||||
L *luajit.State // The Lua state
|
||||
sandbox *Sandbox // Associated sandbox
|
||||
index int // Index for debugging
|
||||
inUse atomic.Bool // Whether the state is currently in use
|
||||
}
|
||||
|
||||
// Runner runs Lua scripts using a pool of Lua states
|
||||
type Runner struct {
|
||||
states []*State // All states managed by this runner
|
||||
statePool chan int // Pool of available state indexes
|
||||
poolSize int // Size of the state pool
|
||||
moduleLoader *ModuleLoader // Module loader
|
||||
dataDir string // Data directory for SQLite databases
|
||||
fsDir string // Virtual filesystem directory
|
||||
isRunning atomic.Bool // Whether the runner is active
|
||||
mu sync.RWMutex // Mutex for thread safety
|
||||
scriptDir string // Current script directory
|
||||
}
|
||||
|
||||
// WithPoolSize sets the state pool size
|
||||
func WithPoolSize(size int) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
if size > 0 {
|
||||
r.poolSize = size
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithLibDirs sets additional library directories
|
||||
func WithLibDirs(dirs ...string) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
if r.moduleLoader == nil {
|
||||
r.moduleLoader = NewModuleLoader(&ModuleConfig{
|
||||
LibDirs: dirs,
|
||||
})
|
||||
} else {
|
||||
r.moduleLoader.config.LibDirs = dirs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithDataDir sets the data directory for SQLite databases
|
||||
func WithDataDir(dataDir string) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
if dataDir != "" {
|
||||
r.dataDir = dataDir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithFsDir sets the virtual filesystem directory
|
||||
func WithFsDir(fsDir string) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
if fsDir != "" {
|
||||
r.fsDir = fsDir
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewRunner creates a new Runner with a pool of states
|
||||
func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||
// Default configuration
|
||||
runner := &Runner{
|
||||
poolSize: runtime.GOMAXPROCS(0),
|
||||
dataDir: "data",
|
||||
fsDir: "fs",
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range options {
|
||||
opt(runner)
|
||||
}
|
||||
|
||||
// Set up module loader if not already initialized
|
||||
if runner.moduleLoader == nil {
|
||||
config := &ModuleConfig{
|
||||
ScriptDir: "",
|
||||
LibDirs: []string{},
|
||||
}
|
||||
runner.moduleLoader = NewModuleLoader(config)
|
||||
}
|
||||
|
||||
InitSQLite(runner.dataDir)
|
||||
InitFS(runner.fsDir)
|
||||
|
||||
SetSQLitePoolSize(runner.poolSize)
|
||||
|
||||
// Initialize states and pool
|
||||
runner.states = make([]*State, runner.poolSize)
|
||||
runner.statePool = make(chan int, runner.poolSize)
|
||||
|
||||
// Create and initialize all states
|
||||
if err := runner.initializeStates(); err != nil {
|
||||
CleanupSQLite()
|
||||
runner.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runner.isRunning.Store(true)
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
// initializeStates creates and initializes all states in the pool
|
||||
func (r *Runner) initializeStates() error {
|
||||
logger.Info("[LuaRunner] Creating %s states...", color.Apply(strconv.Itoa(r.poolSize), color.Yellow))
|
||||
|
||||
for i := range r.poolSize {
|
||||
state, err := r.createState(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.states[i] = state
|
||||
r.statePool <- i // Add index to the pool
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createState initializes a new Lua state
|
||||
func (r *Runner) createState(index int) (*State, error) {
|
||||
verbose := index == 0
|
||||
if verbose {
|
||||
logger.Debug("Creating Lua state %d", index)
|
||||
}
|
||||
|
||||
L := luajit.New()
|
||||
if L == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
|
||||
sb := NewSandbox()
|
||||
|
||||
// Set up sandbox
|
||||
if err := sb.Setup(L, verbose); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Set up module loader
|
||||
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Preload modules
|
||||
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, errors.New("failed to preload modules")
|
||||
}
|
||||
|
||||
if verbose {
|
||||
logger.Debug("Lua state %d initialized successfully", index)
|
||||
}
|
||||
|
||||
return &State{
|
||||
L: L,
|
||||
sandbox: sb,
|
||||
index: index,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Execute runs a script in a sandbox with context
|
||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
}
|
||||
|
||||
// Set script directory if provided
|
||||
if scriptPath != "" {
|
||||
r.mu.Lock()
|
||||
r.scriptDir = filepath.Dir(scriptPath)
|
||||
r.moduleLoader.SetScriptDir(r.scriptDir)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// Get a state from the pool
|
||||
var stateIndex int
|
||||
select {
|
||||
case stateIndex = <-r.statePool:
|
||||
// Got a state
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(1 * time.Second):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
state := r.states[stateIndex]
|
||||
if state == nil {
|
||||
r.statePool <- stateIndex
|
||||
return nil, ErrStateNotReady
|
||||
}
|
||||
|
||||
// Use atomic operations
|
||||
state.inUse.Store(true)
|
||||
|
||||
defer func() {
|
||||
state.inUse.Store(false)
|
||||
if r.isRunning.Load() {
|
||||
select {
|
||||
case r.statePool <- stateIndex:
|
||||
default:
|
||||
// Pool is full or closed, state will be cleaned up by Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Execute in sandbox
|
||||
response, err := state.sandbox.Execute(state.L, bytecode, execCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Run executes a Lua script with immediate context
|
||||
func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
|
||||
return r.Execute(context.Background(), bytecode, execCtx, scriptPath)
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the Runner
|
||||
func (r *Runner) Close() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return ErrRunnerClosed
|
||||
}
|
||||
|
||||
r.isRunning.Store(false)
|
||||
|
||||
// Drain all states from the pool
|
||||
for {
|
||||
select {
|
||||
case <-r.statePool:
|
||||
default:
|
||||
goto waitForInUse
|
||||
}
|
||||
}
|
||||
|
||||
waitForInUse:
|
||||
// Wait for in-use states to finish (with timeout)
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
allIdle := true
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse.Load() {
|
||||
allIdle = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allIdle {
|
||||
break
|
||||
}
|
||||
|
||||
if time.Now().After(timeout) {
|
||||
logger.Warning("Timeout waiting for states to finish during shutdown, forcing close")
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now safely close all states
|
||||
for i, state := range r.states {
|
||||
if state != nil {
|
||||
if state.inUse.Load() {
|
||||
logger.Warning("Force closing state %d that is still in use", i)
|
||||
}
|
||||
state.L.Cleanup()
|
||||
state.L.Close()
|
||||
r.states[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
CleanupFS()
|
||||
CleanupSQLite()
|
||||
|
||||
logger.Debug("Runner closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshStates rebuilds all states in the pool
|
||||
func (r *Runner) RefreshStates() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return ErrRunnerClosed
|
||||
}
|
||||
|
||||
logger.Info("Runner is refreshing all states...")
|
||||
|
||||
// Drain all states from the pool
|
||||
for {
|
||||
select {
|
||||
case <-r.statePool:
|
||||
default:
|
||||
goto waitForInUse
|
||||
}
|
||||
}
|
||||
|
||||
waitForInUse:
|
||||
// Wait for in-use states to finish (with timeout)
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
allIdle := true
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse.Load() {
|
||||
allIdle = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allIdle {
|
||||
break
|
||||
}
|
||||
|
||||
if time.Now().After(timeout) {
|
||||
logger.Warning("Timeout waiting for states to finish, forcing refresh")
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now safely destroy all states
|
||||
for i, state := range r.states {
|
||||
if state != nil {
|
||||
if state.inUse.Load() {
|
||||
logger.Warning("Force closing state %d that is still in use", i)
|
||||
}
|
||||
state.L.Cleanup()
|
||||
state.L.Close()
|
||||
r.states[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Reinitialize all states
|
||||
if err := r.initializeStates(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Debug("All states refreshed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotifyFileChanged alerts the runner about file changes
|
||||
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
||||
logger.Debug("Runner notified of file change: %s", filePath)
|
||||
|
||||
module, isModule := r.moduleLoader.GetModuleByPath(filePath)
|
||||
if isModule {
|
||||
logger.Debug("Refreshing module: %s", module)
|
||||
return r.RefreshModule(module)
|
||||
}
|
||||
|
||||
logger.Debug("File change noted but no refresh needed: %s", filePath)
|
||||
return true
|
||||
}
|
||||
|
||||
// RefreshModule refreshes a specific module across all states
|
||||
func (r *Runner) RefreshModule(moduleName string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("Refreshing module: %s", moduleName)
|
||||
|
||||
success := true
|
||||
for _, state := range r.states {
|
||||
if state == nil || state.inUse.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the enhanced module loader refresh
|
||||
if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil {
|
||||
success = false
|
||||
logger.Debug("Failed to refresh module %s in state %d: %v", moduleName, state.index, err)
|
||||
}
|
||||
}
|
||||
|
||||
if success {
|
||||
logger.Debug("Successfully refreshed module: %s", moduleName)
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// RefreshModuleByPath refreshes a module by its file path
|
||||
func (r *Runner) RefreshModuleByPath(filePath string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("Refreshing module by path: %s", filePath)
|
||||
|
||||
success := true
|
||||
for _, state := range r.states {
|
||||
if state == nil || state.inUse.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the enhanced module loader refresh by path
|
||||
if err := r.moduleLoader.RefreshModuleByPath(state.L, filePath); err != nil {
|
||||
success = false
|
||||
logger.Debug("Failed to refresh module at %s in state %d: %v", filePath, state.index, err)
|
||||
}
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// GetStateCount returns the number of initialized states
|
||||
func (r *Runner) GetStateCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// GetActiveStateCount returns the number of states currently in use
|
||||
func (r *Runner) GetActiveStateCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse.Load() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// RunScriptFile loads, compiles and executes a Lua script file
|
||||
func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("script file not found: %s", filePath)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
scriptDir := filepath.Dir(absPath)
|
||||
|
||||
r.mu.Lock()
|
||||
prevScriptDir := r.scriptDir
|
||||
r.scriptDir = scriptDir
|
||||
r.moduleLoader.SetScriptDir(scriptDir)
|
||||
r.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
r.mu.Lock()
|
||||
r.scriptDir = prevScriptDir
|
||||
r.moduleLoader.SetScriptDir(prevScriptDir)
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
var stateIndex int
|
||||
select {
|
||||
case stateIndex = <-r.statePool:
|
||||
// Got a state
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
state := r.states[stateIndex]
|
||||
if state == nil {
|
||||
r.statePool <- stateIndex
|
||||
return nil, ErrStateNotReady
|
||||
}
|
||||
|
||||
state.inUse.Store(true)
|
||||
|
||||
defer func() {
|
||||
state.inUse.Store(false)
|
||||
if r.isRunning.Load() {
|
||||
select {
|
||||
case r.statePool <- stateIndex:
|
||||
// State returned to pool
|
||||
default:
|
||||
// Pool is full or closed
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compilation error: %w", err)
|
||||
}
|
||||
|
||||
ctx := NewContext()
|
||||
defer ctx.Release()
|
||||
|
||||
ctx.Set("_script_path", absPath)
|
||||
ctx.Set("_script_dir", scriptDir)
|
||||
|
||||
response, err := state.sandbox.Execute(state.L, bytecode, ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execution error: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
334
runner/sandbox.go
Normal file
@ -0,0 +1,334 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
"maps"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Error represents a simple error string
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Error types
|
||||
var (
|
||||
ErrSandboxNotInitialized = Error("sandbox not initialized")
|
||||
)
|
||||
|
||||
// Sandbox provides a secure execution environment for Lua scripts
|
||||
type Sandbox struct {
|
||||
modules map[string]any
|
||||
debug bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox environment
|
||||
func NewSandbox() *Sandbox {
|
||||
return &Sandbox{
|
||||
modules: make(map[string]any, 8),
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.modules[name] = module
|
||||
logger.Debug("Added module: %s", name)
|
||||
}
|
||||
|
||||
// Setup initializes the sandbox in a Lua state
|
||||
func (s *Sandbox) Setup(state *luajit.State, verbose bool) error {
|
||||
if verbose {
|
||||
logger.Debug("Setting up sandbox...")
|
||||
}
|
||||
|
||||
if err := loadSandboxIntoState(state, verbose); err != nil {
|
||||
logger.Error("Failed to load sandbox: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.registerCoreFunctions(state); err != nil {
|
||||
logger.Error("Failed to register core functions: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
for name, module := range s.modules {
|
||||
logger.Debug("Registering module: %s", name)
|
||||
if err := state.PushValue(module); err != nil {
|
||||
s.mu.RUnlock()
|
||||
logger.Error("Failed to register module %s: %v", name, err)
|
||||
return err
|
||||
}
|
||||
state.SetGlobal(name)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if verbose {
|
||||
logger.Debug("Sandbox setup complete")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerCoreFunctions registers all built-in functions in the Lua state
|
||||
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := state.RegisterGoFunction("__json_marshal", jsonMarshal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := state.RegisterGoFunction("__json_unmarshal", jsonUnmarshal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterSQLiteFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterFSFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterPasswordFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterUtilFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterCryptoFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := RegisterEnvFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs a Lua script in the sandbox with the given context
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||
state.GetGlobal("__execute_script")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(1)
|
||||
return nil, ErrSandboxNotInitialized
|
||||
}
|
||||
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
state.Pop(1) // Pop the __execute_script function
|
||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||
}
|
||||
|
||||
// Push context values
|
||||
if err := state.PushTable(ctx.Values); err != nil {
|
||||
state.Pop(2) // Pop bytecode and __execute_script
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute with 2 args, 1 result
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||
}
|
||||
|
||||
body, err := state.ToValue(-1)
|
||||
state.Pop(1)
|
||||
|
||||
response := NewResponse()
|
||||
if err == nil {
|
||||
response.Body = body
|
||||
}
|
||||
|
||||
extractHTTPResponseData(state, response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// extractHTTPResponseData pulls response info from the Lua state
|
||||
func extractHTTPResponseData(state *luajit.State, response *Response) {
|
||||
state.GetGlobal("__http_response")
|
||||
if !state.IsTable(-1) {
|
||||
state.Pop(1)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract status
|
||||
state.GetField(-1, "status")
|
||||
if state.IsNumber(-1) {
|
||||
response.Status = int(state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract headers
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsTable(-1) {
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
if state.IsString(-2) && state.IsString(-1) {
|
||||
key := state.ToString(-2)
|
||||
value := state.ToString(-1)
|
||||
response.Headers[key] = value
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract cookies
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsTable(-1) {
|
||||
length := state.GetTableLength(-1)
|
||||
for i := 1; i <= length; i++ {
|
||||
state.PushNumber(float64(i))
|
||||
state.GetTable(-2)
|
||||
|
||||
if state.IsTable(-1) {
|
||||
extractCookie(state, response)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract metadata
|
||||
state.GetField(-1, "metadata")
|
||||
if state.IsTable(-1) {
|
||||
table, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
maps.Copy(response.Metadata, table)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract session data
|
||||
state.GetField(-1, "session")
|
||||
if state.IsTable(-1) {
|
||||
table, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
maps.Copy(response.SessionData, table)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.Pop(1) // Pop __http_response
|
||||
}
|
||||
|
||||
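Based on the fields read above, a script can shape its HTTP reply by filling the __http_response table; whether the sandbox provides wrapper helpers around this table is not shown in this diff.

__http_response = {
  status  = 201,
  headers = { ["Content-Type"] = "application/json" },
  cookies = {
    { name = "session", value = "abc123", path = "/", http_only = true, max_age = 3600 },
  },
  metadata = { cached = false },
  session  = { user_id = 42 },
}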
// extractCookie pulls cookie data from the current table on the stack
|
||||
func extractCookie(state *luajit.State, response *Response) {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
|
||||
// Get name (required)
|
||||
state.GetField(-1, "name")
|
||||
if !state.IsString(-1) {
|
||||
state.Pop(1)
|
||||
fasthttp.ReleaseCookie(cookie)
|
||||
return
|
||||
}
|
||||
cookie.SetKey(state.ToString(-1))
|
||||
state.Pop(1)
|
||||
|
||||
// Get value
|
||||
state.GetField(-1, "value")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetValue(state.ToString(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get path
|
||||
state.GetField(-1, "path")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetPath(state.ToString(-1))
|
||||
} else {
|
||||
cookie.SetPath("/") // Default
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get domain
|
||||
state.GetField(-1, "domain")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetDomain(state.ToString(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get other parameters
|
||||
state.GetField(-1, "http_only")
|
||||
if state.IsBoolean(-1) {
|
||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(-1, "secure")
|
||||
if state.IsBoolean(-1) {
|
||||
cookie.SetSecure(state.ToBoolean(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(-1, "max_age")
|
||||
if state.IsNumber(-1) {
|
||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
response.Cookies = append(response.Cookies, cookie)
|
||||
}
|
||||
|
||||
// jsonMarshal converts a Lua value to a JSON string
|
||||
func jsonMarshal(state *luajit.State) int {
|
||||
value, err := state.ToValue(1)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("json marshal error: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("json marshal error: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushString(string(bytes))
|
||||
return 1
|
||||
}
|
||||
|
||||
// jsonUnmarshal converts a JSON string to a Lua value
|
||||
func jsonUnmarshal(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushString("json unmarshal error: expected string")
|
||||
return 1
|
||||
}
|
||||
jsonStr := state.ToString(1)
|
||||
|
||||
var value any
|
||||
err := json.Unmarshal([]byte(jsonStr), &value)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
if err := state.PushValue(value); err != nil {
|
||||
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
|
||||
return 1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
427
runner/sqlite.go
Normal file
@ -0,0 +1,427 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlite "zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
|
||||
"Moonshark/utils/color"
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var (
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
poolsMu sync.RWMutex
|
||||
dataDir string
|
||||
poolSize = 8 // Default, will be set to match runner pool size
|
||||
connTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// InitSQLite initializes the SQLite subsystem
|
||||
func InitSQLite(dir string) {
|
||||
dataDir = dir
|
||||
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
|
||||
}
|
||||
|
||||
// SetSQLitePoolSize sets the pool size to match the runner pool size
|
||||
func SetSQLitePoolSize(size int) {
|
||||
if size > 0 {
|
||||
poolSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupSQLite closes all database connections
|
||||
func CleanupSQLite() {
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
for name, pool := range dbPools {
|
||||
if err := pool.Close(); err != nil {
|
||||
logger.Error("Failed to close database %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
logger.Debug("SQLite connections closed")
|
||||
}
|
||||
|
||||
// getPool returns a connection pool for the database
|
||||
func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
// Validate database name
|
||||
dbName = filepath.Base(dbName)
|
||||
if dbName == "" || dbName[0] == '.' {
|
||||
return nil, fmt.Errorf("invalid database name")
|
||||
}
|
||||
|
||||
// Check for existing pool
|
||||
poolsMu.RLock()
|
||||
pool, exists := dbPools[dbName]
|
||||
if exists {
|
||||
poolsMu.RUnlock()
|
||||
return pool, nil
|
||||
}
|
||||
poolsMu.RUnlock()
|
||||
|
||||
// Create new pool under write lock
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
// Double-check if a pool was created while waiting for lock
|
||||
if pool, exists = dbPools[dbName]; exists {
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Create new pool with proper size
|
||||
dbPath := filepath.Join(dataDir, dbName+".db")
|
||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||
PoolSize: poolSize,
|
||||
PrepareConn: func(conn *sqlite.Conn) error {
|
||||
// Execute PRAGMA statements individually
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode = WAL",
|
||||
"PRAGMA synchronous = NORMAL",
|
||||
"PRAGMA cache_size = 1000",
|
||||
"PRAGMA foreign_keys = ON",
|
||||
"PRAGMA temp_store = MEMORY",
|
||||
}
|
||||
for _, pragma := range pragmas {
|
||||
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
dbPools[dbName] = pool
|
||||
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// sqlQuery executes a SQL query and returns results
|
||||
func sqlQuery(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("sqlite.query: requires database name and query")
|
||||
return -1
|
||||
}
|
||||
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
// Create execution options
|
||||
var execOpts sqlitex.ExecOptions
|
||||
rows := make([]map[string]any, 0, 16)
|
||||
|
||||
// Set up parameters if provided
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
// Set up result function
|
||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
row := make(map[string]any)
|
||||
colCount := stmt.ColumnCount()
|
||||
|
||||
for i := range colCount {
|
||||
colName := stmt.ColumnName(i)
|
||||
switch stmt.ColumnType(i) {
|
||||
case sqlite.TypeInteger:
|
||||
row[colName] = stmt.ColumnInt64(i)
|
||||
case sqlite.TypeFloat:
|
||||
row[colName] = stmt.ColumnFloat(i)
|
||||
case sqlite.TypeText:
|
||||
row[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
if blobSize > 0 {
|
||||
buf := make([]byte, blobSize)
|
||||
row[colName] = stmt.ColumnBytes(i, buf)
|
||||
} else {
|
||||
row[colName] = []byte{}
|
||||
}
|
||||
case sqlite.TypeNull:
|
||||
row[colName] = nil
|
||||
}
|
||||
}
|
||||
rows = append(rows, row)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Create result table
|
||||
state.NewTable()
|
||||
for i, row := range rows {
|
||||
state.PushNumber(float64(i + 1))
|
||||
if err := state.PushTable(row); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.SetTable(-3)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// sqlExec executes a SQL statement without returning results
|
||||
func sqlExec(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("sqlite.exec: requires database name and query")
|
||||
return -1
|
||||
}
|
||||
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
// Check if parameters are provided
|
||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||
|
||||
// Fast path for multi-statement scripts
|
||||
if strings.Contains(query, ";") && !hasParams {
|
||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Fast path for simple queries with no parameters
|
||||
if !hasParams {
|
||||
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Create execution options for parameterized query
|
||||
var execOpts sqlitex.ExecOptions
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Execute with parameters
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return affected rows
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// setupParams configures execution options with parameters from Lua
|
||||
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
||||
if state.IsTable(paramIndex) {
|
||||
params, err := state.ToTable(paramIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid parameters: %w", err)
|
||||
}
|
||||
|
||||
// Check for array-style params
|
||||
if arr, ok := params[""]; ok {
|
||||
if arrParams, ok := arr.([]any); ok {
|
||||
execOpts.Args = arrParams
|
||||
} else if floatArr, ok := arr.([]float64); ok {
|
||||
args := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
} else {
|
||||
// Named parameters
|
||||
named := make(map[string]any, len(params))
|
||||
for k, v := range params {
|
||||
if len(k) > 0 && k[0] != ':' {
|
||||
named[":"+k] = v
|
||||
} else {
|
||||
named[k] = v
|
||||
}
|
||||
}
|
||||
execOpts.Named = named
|
||||
}
|
||||
} else {
|
||||
// Positional parameters from stack
|
||||
count := state.GetTop() - 2
|
||||
args := make([]any, count)
|
||||
for i := range count {
|
||||
idx := i + 3
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
|
||||
}
|
||||
args[i] = val
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
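The three parameter shapes setupParams accepts map directly onto sqlitex.ExecOptions: a Lua array becomes positional Args, a Lua map becomes Named with ":"-prefixed keys, and trailing stack values also become positional Args. A minimal same-package sketch of the options it produces (the values and queries are illustrative only):

```go
// Illustrative only: the equivalent ExecOptions setupParams would build.
func exampleParamShapes() (positional, named sqlitex.ExecOptions) {
	// Lua array {42, "hello"} or trailing stack arguments -> positional Args,
	// for placeholders like "INSERT INTO notes (id, body) VALUES (?, ?)".
	positional = sqlitex.ExecOptions{Args: []any{int64(42), "hello"}}

	// Lua table {id = 42, body = "hello"} -> Named with ":" prefixes,
	// for placeholders like "INSERT INTO notes (id, body) VALUES (:id, :body)".
	named = sqlitex.ExecOptions{Named: map[string]any{":id": int64(42), ":body": "hello"}}

	return positional, named
}
```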
|
||||
// sqlGetOne executes a query and returns only the first row
|
||||
func sqlGetOne(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("sqlite.get_one: requires database name and query")
|
||||
return -1
|
||||
}
|
||||
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
// Create execution options
|
||||
var execOpts sqlitex.ExecOptions
|
||||
var result map[string]any
|
||||
|
||||
// Set up parameters if provided
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
// Set up result function to get only first row
|
||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
if result != nil {
|
||||
return nil // Already got first row
|
||||
}
|
||||
|
||||
result = make(map[string]any)
|
||||
colCount := stmt.ColumnCount()
|
||||
|
||||
for i := range colCount {
|
||||
colName := stmt.ColumnName(i)
|
||||
switch stmt.ColumnType(i) {
|
||||
case sqlite.TypeInteger:
|
||||
result[colName] = stmt.ColumnInt64(i)
|
||||
case sqlite.TypeFloat:
|
||||
result[colName] = stmt.ColumnFloat(i)
|
||||
case sqlite.TypeText:
|
||||
result[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
if blobSize > 0 {
|
||||
buf := make([]byte, blobSize)
|
||||
result[colName] = stmt.ColumnBytes(i, buf)
|
||||
} else {
|
||||
result[colName] = []byte{}
|
||||
}
|
||||
case sqlite.TypeNull:
|
||||
result[colName] = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return result or nil if no rows
|
||||
if result == nil {
|
||||
state.PushNil()
|
||||
} else {
|
||||
if err := state.PushTable(result); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
||||
func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
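A minimal same-package sketch of driving these bindings from Lua. It assumes the module's data directory has already been configured elsewhere (so "app" resolves to a file under dataDir) and uses only LuaJIT-to-Go calls that appear elsewhere in this diff; the database name and Lua chunk are illustrative.

```go
// Sketch: register the bindings in a fresh state and run an illustrative chunk.
// Assumes dataDir has been configured so "app" resolves to <dataDir>/app.db.
func exampleSQLiteFromLua() error {
	L := luajit.New(true) // open the standard Lua libraries
	if L == nil {
		return fmt.Errorf("failed to create Lua state")
	}
	defer L.Close()

	if err := RegisterSQLiteFunctions(L); err != nil {
		return err
	}

	chunk := `
		__sqlite_exec("app", "CREATE TABLE IF NOT EXISTS notes (id INTEGER PRIMARY KEY, body TEXT)")
		__sqlite_exec("app", "INSERT INTO notes (body) VALUES (:body)", { body = "hello" })
		local rows = __sqlite_query("app", "SELECT id, body FROM notes")
		print(#rows)
	`

	bytecode, err := L.CompileBytecode(chunk, "sqlite_demo")
	if err != nil {
		return err
	}
	return L.LoadAndRunBytecode(bytecode, "sqlite_demo")
}
```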
116
runner/util.go
Normal file
@ -0,0 +1,116 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"html"
|
||||
"strings"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// RegisterUtilFunctions registers utility functions with the Lua state
|
||||
func RegisterUtilFunctions(state *luajit.State) error {
|
||||
// HTML special chars
|
||||
if err := state.RegisterGoFunction("__html_special_chars", htmlSpecialChars); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// HTML entities
|
||||
if err := state.RegisterGoFunction("__html_entities", htmlEntities); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Base64 encode
|
||||
if err := state.RegisterGoFunction("__base64_encode", base64Encode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Base64 decode
|
||||
if err := state.RegisterGoFunction("__base64_decode", base64Decode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// htmlSpecialChars converts special characters to HTML entities
|
||||
func htmlSpecialChars(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input := state.ToString(1)
|
||||
result := html.EscapeString(input)
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// htmlEntities is a more comprehensive version of htmlSpecialChars
|
||||
func htmlEntities(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input := state.ToString(1)
|
||||
// First use HTML escape for standard entities
|
||||
result := html.EscapeString(input)
|
||||
|
||||
// Additional entities beyond what html.EscapeString handles
|
||||
replacements := map[string]string{
|
||||
"©": "©",
|
||||
"®": "®",
|
||||
"™": "™",
|
||||
"€": "€",
|
||||
"£": "£",
|
||||
"¥": "¥",
|
||||
"—": "—",
|
||||
"–": "–",
|
||||
"…": "…",
|
||||
"•": "•",
|
||||
"°": "°",
|
||||
"±": "±",
|
||||
"¼": "¼",
|
||||
"½": "½",
|
||||
"¾": "¾",
|
||||
}
|
||||
|
||||
for char, entity := range replacements {
|
||||
result = strings.ReplaceAll(result, char, entity)
|
||||
}
|
||||
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// base64Encode encodes a string to base64
|
||||
func base64Encode(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input := state.ToString(1)
|
||||
result := base64.StdEncoding.EncodeToString([]byte(input))
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// base64Decode decodes a base64 string
|
||||
func base64Decode(state *luajit.State) int {
|
||||
if !state.IsString(1) {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input := state.ToString(1)
|
||||
result, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushString(string(result))
|
||||
return 1
|
||||
}
|
||||
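For reference, a standalone sketch of the standard-library behavior these wrappers expose (the inputs are illustrative):

```go
package main

import (
	"encoding/base64"
	"fmt"
	"html"
)

func main() {
	// What __html_special_chars returns for markup-ish input.
	fmt.Println(html.EscapeString(`<a href="x">Tom & Jerry</a>`))
	// -> &lt;a href=&#34;x&#34;&gt;Tom &amp; Jerry&lt;/a&gt;

	// What the __base64_encode / __base64_decode round trip looks like.
	enc := base64.StdEncoding.EncodeToString([]byte("Moonshark"))
	dec, err := base64.StdEncoding.DecodeString(enc)
	if err != nil {
		panic(err)
	}
	fmt.Println(enc, string(dec)) // TW9vbnNoYXJr Moonshark
}
```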
212
sessions/manager.go
Normal file
@ -0,0 +1,212 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/VictoriaMetrics/fastcache"
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxSessions = 10000
|
||||
DefaultCookieName = "MoonsharkSID"
|
||||
DefaultCookiePath = "/"
|
||||
DefaultMaxAge = 86400 // 1 day in seconds
|
||||
CleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// SessionManager handles multiple sessions
|
||||
type SessionManager struct {
|
||||
cache *fastcache.Cache
|
||||
cookieName string
|
||||
cookiePath string
|
||||
cookieDomain string
|
||||
cookieSecure bool
|
||||
cookieHTTPOnly bool
|
||||
cookieMaxAge int
|
||||
cookieMu sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager
|
||||
func NewSessionManager(maxSessions int) *SessionManager {
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = DefaultMaxSessions
|
||||
}
|
||||
|
||||
sm := &SessionManager{
|
||||
cache: fastcache.New(maxSessions * 4096),
|
||||
cookieName: DefaultCookieName,
|
||||
cookiePath: DefaultCookiePath,
|
||||
cookieHTTPOnly: true,
|
||||
cookieMaxAge: DefaultMaxAge,
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Pre-populate session pool
|
||||
for i := 0; i < 100; i++ {
|
||||
s := NewSession("", 0)
|
||||
s.Release()
|
||||
}
|
||||
|
||||
sm.cleanupTicker = time.NewTicker(CleanupInterval)
|
||||
go sm.cleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// Stop shuts down the session manager's cleanup routine
|
||||
func (sm *SessionManager) Stop() {
|
||||
close(sm.cleanupDone)
|
||||
}
|
||||
|
||||
func (sm *SessionManager) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-sm.cleanupTicker.C:
|
||||
sm.CleanupExpired()
|
||||
case <-sm.cleanupDone:
|
||||
sm.cleanupTicker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
|
||||
func (sm *SessionManager) GetSession(id string) *Session {
|
||||
if id != "" {
|
||||
if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 {
|
||||
if s, err := Unmarshal(data); err == nil && !s.IsExpired() {
|
||||
s.UpdateLastUsed()
|
||||
s.ResetDirty()
|
||||
return s
|
||||
}
|
||||
sm.cache.Del([]byte(id))
|
||||
}
|
||||
}
|
||||
return sm.CreateSession()
|
||||
}
|
||||
|
||||
// CreateSession generates a new session with a unique ID
|
||||
func (sm *SessionManager) CreateSession() *Session {
|
||||
id, _ := gonanoid.New()
|
||||
|
||||
// Ensure uniqueness (max 3 attempts)
|
||||
for i := 0; i < 3 && sm.cache.Has([]byte(id)); i++ {
|
||||
id, _ = gonanoid.New()
|
||||
}
|
||||
|
||||
s := NewSession(id, sm.cookieMaxAge)
|
||||
if data, err := s.Marshal(); err == nil {
|
||||
sm.cache.Set([]byte(id), data)
|
||||
}
|
||||
s.ResetDirty()
|
||||
return s
|
||||
}
|
||||
|
||||
// DestroySession removes a session
|
||||
func (sm *SessionManager) DestroySession(id string) {
|
||||
if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 {
|
||||
if s, err := Unmarshal(data); err == nil {
|
||||
s.Release()
|
||||
}
|
||||
}
|
||||
sm.cache.Del([]byte(id))
|
||||
}
|
||||
|
||||
// CleanupExpired removes all expired sessions
|
||||
func (sm *SessionManager) CleanupExpired() int {
|
||||
// fastcache doesn't support iteration
|
||||
return 0
|
||||
}
|
||||
|
||||
// SetCookieOptions configures cookie parameters
|
||||
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
||||
sm.cookieMu.Lock()
|
||||
sm.cookieName = name
|
||||
sm.cookiePath = path
|
||||
sm.cookieDomain = domain
|
||||
sm.cookieSecure = secure
|
||||
sm.cookieHTTPOnly = httpOnly
|
||||
sm.cookieMaxAge = maxAge
|
||||
sm.cookieMu.Unlock()
|
||||
}
|
||||
|
||||
// GetSessionFromRequest extracts the session from a request
|
||||
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||||
sm.cookieMu.RLock()
|
||||
name := sm.cookieName
|
||||
sm.cookieMu.RUnlock()
|
||||
|
||||
if cookie := ctx.Request.Header.Cookie(name); len(cookie) > 0 {
|
||||
return sm.GetSession(string(cookie))
|
||||
}
|
||||
return sm.CreateSession()
|
||||
}
|
||||
|
||||
// ApplySessionCookie adds the session cookie to the response
|
||||
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||||
if session.IsDirty() {
|
||||
if data, err := session.Marshal(); err == nil {
|
||||
sm.cache.Set([]byte(session.ID), data)
|
||||
}
|
||||
session.ResetDirty()
|
||||
}
|
||||
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
|
||||
sm.cookieMu.RLock()
|
||||
cookie.SetKey(sm.cookieName)
|
||||
cookie.SetPath(sm.cookiePath)
|
||||
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
||||
cookie.SetMaxAge(sm.cookieMaxAge)
|
||||
if sm.cookieDomain != "" {
|
||||
cookie.SetDomain(sm.cookieDomain)
|
||||
}
|
||||
cookie.SetSecure(sm.cookieSecure)
|
||||
sm.cookieMu.RUnlock()
|
||||
|
||||
cookie.SetValue(session.ID)
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// CookieOptions returns the cookie options for this session manager
|
||||
func (sm *SessionManager) CookieOptions() map[string]any {
|
||||
sm.cookieMu.RLock()
|
||||
defer sm.cookieMu.RUnlock()
|
||||
|
||||
return map[string]any{
|
||||
"name": sm.cookieName,
|
||||
"path": sm.cookiePath,
|
||||
"domain": sm.cookieDomain,
|
||||
"secure": sm.cookieSecure,
|
||||
"http_only": sm.cookieHTTPOnly,
|
||||
"max_age": sm.cookieMaxAge,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCacheStats returns statistics about the session cache
|
||||
func (sm *SessionManager) GetCacheStats() map[string]uint64 {
|
||||
if sm == nil || sm.cache == nil {
|
||||
return map[string]uint64{}
|
||||
}
|
||||
|
||||
var stats fastcache.Stats
|
||||
sm.cache.UpdateStats(&stats)
|
||||
|
||||
return map[string]uint64{
|
||||
"entries": stats.EntriesCount,
|
||||
"bytes": stats.BytesSize,
|
||||
"max_bytes": stats.MaxBytesSize,
|
||||
"gets": stats.GetCalls,
|
||||
"sets": stats.SetCalls,
|
||||
"misses": stats.Misses,
|
||||
}
|
||||
}
|
||||
|
||||
// GlobalSessionManager is the default session manager instance
|
||||
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)
|
||||
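A hedged usage sketch of the manager inside a fasthttp handler. The import path Moonshark/sessions is an assumption based on the module name, and the "visits" counter is illustrative:

```go
package main

import (
	"fmt"

	"Moonshark/sessions" // assumed import path

	"github.com/valyala/fasthttp"
)

func main() {
	sm := sessions.NewSessionManager(0) // 0 falls back to DefaultMaxSessions
	defer sm.Stop()
	sm.SetCookieOptions("MoonsharkSID", "/", "", false, true, 3600)

	handler := func(ctx *fasthttp.RequestCtx) {
		sess := sm.GetSessionFromRequest(ctx)

		visits, _ := sess.Get("visits").(int64)
		sess.Set("visits", visits+1)

		// Persists the session (if dirty) and sets the cookie on the response.
		sm.ApplySessionCookie(ctx, sess)
		fmt.Fprintf(ctx, "visit #%d\n", visits+1)
	}

	if err := fasthttp.ListenAndServe(":8080", handler); err != nil {
		panic(err)
	}
}
```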
434
sessions/session.go
Normal file
@ -0,0 +1,434 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/deneonet/benc"
|
||||
bstd "github.com/deneonet/benc/std"
|
||||
)
|
||||
|
||||
// Session stores data for a single user session
|
||||
type Session struct {
|
||||
ID string
|
||||
Data map[string]any
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
LastUsed time.Time
|
||||
Expiry time.Time
|
||||
dirty bool
|
||||
}
|
||||
|
||||
var (
|
||||
sessionPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Session{Data: make(map[string]any, 8)}
|
||||
},
|
||||
}
|
||||
bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
|
||||
)
|
||||
|
||||
// NewSession creates a new session with the given ID
|
||||
func NewSession(id string, maxAge int) *Session {
|
||||
s := sessionPool.Get().(*Session)
|
||||
now := time.Now()
|
||||
*s = Session{
|
||||
ID: id,
|
||||
Data: s.Data, // Reuse map
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
LastUsed: now,
|
||||
Expiry: now.Add(time.Duration(maxAge) * time.Second),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Release returns the session to the pool
|
||||
func (s *Session) Release() {
|
||||
for k := range s.Data {
|
||||
delete(s.Data, k)
|
||||
}
|
||||
sessionPool.Put(s)
|
||||
}
|
||||
|
||||
// Get returns a deep copy of a value
|
||||
func (s *Session) Get(key string) any {
|
||||
if v, ok := s.Data[key]; ok {
|
||||
return deepCopy(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTable returns a value as a table
|
||||
func (s *Session) GetTable(key string) map[string]any {
|
||||
if v := s.Get(key); v != nil {
|
||||
if t, ok := v.(map[string]any); ok {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a deep copy of all session data
|
||||
func (s *Session) GetAll() map[string]any {
|
||||
copy := make(map[string]any, len(s.Data))
|
||||
for k, v := range s.Data {
|
||||
copy[k] = deepCopy(v)
|
||||
}
|
||||
return copy
|
||||
}
|
||||
|
||||
// Set stores a value in the session
|
||||
func (s *Session) Set(key string, value any) {
|
||||
if existing, ok := s.Data[key]; ok && deepEqual(existing, value) {
|
||||
return // No change
|
||||
}
|
||||
s.Data[key] = value
|
||||
s.UpdatedAt = time.Now()
|
||||
s.dirty = true
|
||||
}
|
||||
|
||||
// SetSafe stores a value with validation
|
||||
func (s *Session) SetSafe(key string, value any) error {
|
||||
if err := validate(value); err != nil {
|
||||
return fmt.Errorf("session.SetSafe: %w", err)
|
||||
}
|
||||
s.Set(key, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTable is a convenience method for setting table data
|
||||
func (s *Session) SetTable(key string, table map[string]any) error {
|
||||
return s.SetSafe(key, table)
|
||||
}
|
||||
|
||||
// Delete removes a value from the session
|
||||
func (s *Session) Delete(key string) {
|
||||
delete(s.Data, key)
|
||||
s.UpdatedAt = time.Now()
|
||||
s.dirty = true
|
||||
}
|
||||
|
||||
// Clear removes all data from the session
|
||||
func (s *Session) Clear() {
|
||||
s.Data = make(map[string]any, 8)
|
||||
s.UpdatedAt = time.Now()
|
||||
s.dirty = true
|
||||
}
|
||||
|
||||
// IsExpired checks if the session has expired
|
||||
func (s *Session) IsExpired() bool {
|
||||
return time.Now().After(s.Expiry)
|
||||
}
|
||||
|
||||
// UpdateLastUsed updates the last used time
|
||||
func (s *Session) UpdateLastUsed() {
|
||||
now := time.Now()
|
||||
if now.Sub(s.LastUsed) > 5*time.Second {
|
||||
s.LastUsed = now
|
||||
}
|
||||
}
|
||||
|
||||
// IsDirty returns if the session has unsaved changes
|
||||
func (s *Session) IsDirty() bool {
|
||||
return s.dirty
|
||||
}
|
||||
|
||||
// ResetDirty marks the session as clean after saving
|
||||
func (s *Session) ResetDirty() {
|
||||
s.dirty = false
|
||||
}
|
||||
|
||||
// SizePlain calculates the size needed to marshal the session
|
||||
func (s *Session) SizePlain() int {
|
||||
return bstd.SizeString(s.ID) +
|
||||
bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) +
|
||||
bstd.SizeInt64()*4
|
||||
}
|
||||
|
||||
// MarshalPlain serializes the session to binary
|
||||
func (s *Session) MarshalPlain(n int, b []byte) int {
|
||||
n = bstd.MarshalString(n, b, s.ID)
|
||||
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
|
||||
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
|
||||
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
|
||||
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
|
||||
return bstd.MarshalInt64(n, b, s.Expiry.Unix())
|
||||
}
|
||||
|
||||
// UnmarshalPlain deserializes the session from binary
|
||||
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
|
||||
var err error
|
||||
n, s.ID, err = bstd.UnmarshalString(n, b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
var ts int64
|
||||
for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} {
|
||||
n, ts, err = bstd.UnmarshalInt64(n, b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
*t = time.Unix(ts, 0)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Marshal serializes the session using benc
|
||||
func (s *Session) Marshal() ([]byte, error) {
|
||||
return bufPool.Marshal(s.SizePlain(), func(b []byte) int {
|
||||
return s.MarshalPlain(0, b)
|
||||
})
|
||||
}
|
||||
|
||||
// Unmarshal deserializes a session using benc
|
||||
func Unmarshal(data []byte) (*Session, error) {
|
||||
s := sessionPool.Get().(*Session)
|
||||
if _, err := s.UnmarshalPlain(0, data); err != nil {
|
||||
s.Release()
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Type identifiers
|
||||
const (
|
||||
typeNull byte = 0
|
||||
typeString byte = 1
|
||||
typeInt byte = 2
|
||||
typeFloat byte = 3
|
||||
typeBool byte = 4
|
||||
typeBytes byte = 5
|
||||
typeTable byte = 6
|
||||
typeArray byte = 7
|
||||
)
|
||||
|
||||
// sizeAny calculates the size needed for any value
|
||||
func sizeAny(v any) int {
|
||||
if v == nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
size := 1 // type byte
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
size += bstd.SizeString(v)
|
||||
case int:
|
||||
size += bstd.SizeInt64()
|
||||
case int64:
|
||||
size += bstd.SizeInt64()
|
||||
case float64:
|
||||
size += bstd.SizeFloat64()
|
||||
case bool:
|
||||
size += bstd.SizeBool()
|
||||
case []byte:
|
||||
size += bstd.SizeBytes(v)
|
||||
case map[string]any:
|
||||
size += bstd.SizeMap(v, bstd.SizeString, sizeAny)
|
||||
case []any:
|
||||
size += bstd.SizeSlice(v, sizeAny)
|
||||
default:
|
||||
size += bstd.SizeString("unknown")
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// marshalAny serializes any value
|
||||
func marshalAny(n int, b []byte, v any) int {
|
||||
if v == nil {
|
||||
b[n] = typeNull
|
||||
return n + 1
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
b[n] = typeString
|
||||
return bstd.MarshalString(n+1, b, v)
|
||||
case int:
|
||||
b[n] = typeInt
|
||||
return bstd.MarshalInt64(n+1, b, int64(v))
|
||||
case int64:
|
||||
b[n] = typeInt
|
||||
return bstd.MarshalInt64(n+1, b, v)
|
||||
case float64:
|
||||
b[n] = typeFloat
|
||||
return bstd.MarshalFloat64(n+1, b, v)
|
||||
case bool:
|
||||
b[n] = typeBool
|
||||
return bstd.MarshalBool(n+1, b, v)
|
||||
case []byte:
|
||||
b[n] = typeBytes
|
||||
return bstd.MarshalBytes(n+1, b, v)
|
||||
case map[string]any:
|
||||
b[n] = typeTable
|
||||
return bstd.MarshalMap(n+1, b, v, bstd.MarshalString, marshalAny)
|
||||
case []any:
|
||||
b[n] = typeArray
|
||||
return bstd.MarshalSlice(n+1, b, v, marshalAny)
|
||||
default:
|
||||
b[n] = typeString
|
||||
return bstd.MarshalString(n+1, b, "unknown")
|
||||
}
|
||||
}
|
||||
|
||||
// unmarshalAny deserializes any value
|
||||
func unmarshalAny(n int, b []byte) (int, any, error) {
|
||||
if len(b) <= n {
|
||||
return n, nil, benc.ErrBufTooSmall
|
||||
}
|
||||
|
||||
switch b[n] {
|
||||
case typeNull:
|
||||
return n + 1, nil, nil
|
||||
case typeString:
|
||||
return bstd.UnmarshalString(n+1, b)
|
||||
case typeInt:
|
||||
n, v, err := bstd.UnmarshalInt64(n+1, b)
|
||||
return n, v, err
|
||||
case typeFloat:
|
||||
return bstd.UnmarshalFloat64(n+1, b)
|
||||
case typeBool:
|
||||
return bstd.UnmarshalBool(n+1, b)
|
||||
case typeBytes:
|
||||
return bstd.UnmarshalBytesCopied(n+1, b)
|
||||
case typeTable:
|
||||
return bstd.UnmarshalMap[string, any](n+1, b, bstd.UnmarshalString, unmarshalAny)
|
||||
case typeArray:
|
||||
return bstd.UnmarshalSlice[any](n+1, b, unmarshalAny)
|
||||
default:
|
||||
return n + 1, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// deepCopy creates a deep copy of any value
|
||||
func deepCopy(v any) any {
|
||||
switch v := v.(type) {
|
||||
case map[string]any:
|
||||
cp := make(map[string]any, len(v))
|
||||
for k, val := range v {
|
||||
cp[k] = deepCopy(val)
|
||||
}
|
||||
return cp
|
||||
case []any:
|
||||
cp := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
cp[i] = deepCopy(val)
|
||||
}
|
||||
return cp
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// validate ensures a value can be safely serialized
|
||||
func validate(v any) error {
|
||||
switch v := v.(type) {
|
||||
case nil, string, int, int64, float64, bool, []byte:
|
||||
return nil
|
||||
case map[string]any:
|
||||
for k, val := range v {
|
||||
if err := validate(val); err != nil {
|
||||
return fmt.Errorf("invalid value for key %q: %w", k, err)
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for i, val := range v {
|
||||
if err := validate(val); err != nil {
|
||||
return fmt.Errorf("invalid value at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %T", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deepEqual efficiently compares two values for deep equality
|
||||
func deepEqual(a, b any) bool {
|
||||
if a == b {
|
||||
return true
|
||||
}
|
||||
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch va := a.(type) {
|
||||
case string:
|
||||
if vb, ok := b.(string); ok {
|
||||
return va == vb
|
||||
}
|
||||
case int:
|
||||
if vb, ok := b.(int); ok {
|
||||
return va == vb
|
||||
}
|
||||
if vb, ok := b.(int64); ok {
|
||||
return int64(va) == vb
|
||||
}
|
||||
case int64:
|
||||
if vb, ok := b.(int64); ok {
|
||||
return va == vb
|
||||
}
|
||||
if vb, ok := b.(int); ok {
|
||||
return va == int64(vb)
|
||||
}
|
||||
case float64:
|
||||
if vb, ok := b.(float64); ok {
|
||||
return va == vb
|
||||
}
|
||||
case bool:
|
||||
if vb, ok := b.(bool); ok {
|
||||
return va == vb
|
||||
}
|
||||
case []byte:
|
||||
if vb, ok := b.([]byte); ok {
|
||||
if len(va) != len(vb) {
|
||||
return false
|
||||
}
|
||||
for i, v := range va {
|
||||
if v != vb[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
case map[string]any:
|
||||
if vb, ok := b.(map[string]any); ok {
|
||||
if len(va) != len(vb) {
|
||||
return false
|
||||
}
|
||||
for k, v := range va {
|
||||
if bv, exists := vb[k]; !exists || !deepEqual(v, bv) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
case []any:
|
||||
if vb, ok := b.([]any); ok {
|
||||
if len(va) != len(vb) {
|
||||
return false
|
||||
}
|
||||
for i, v := range va {
|
||||
if !deepEqual(v, vb[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the session has no data
|
||||
func (s *Session) IsEmpty() bool {
|
||||
return len(s.Data) == 0
|
||||
}
|
||||
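A same-package round-trip sketch of the serialization above; the keys and values are illustrative:

```go
// Sketch: store a few supported value types, marshal, and restore.
func exampleSessionRoundTrip() error {
	s := NewSession("demo-id", 3600)
	defer s.Release()

	if err := s.SetSafe("user", map[string]any{"name": "sky", "admin": true}); err != nil {
		return err
	}
	s.Set("count", int64(3))

	data, err := s.Marshal()
	if err != nil {
		return err
	}

	restored, err := Unmarshal(data)
	if err != nil {
		return err
	}
	defer restored.Release()

	fmt.Printf("id=%s count=%v user=%v\n",
		restored.ID, restored.Get("count"), restored.GetTable("user"))
	return nil
}
```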
@ -1,93 +0,0 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// BytecodeEntry holds compiled bytecode with metadata
|
||||
type BytecodeEntry struct {
|
||||
Bytecode []byte
|
||||
Name string
|
||||
Hash [32]byte
|
||||
}
|
||||
|
||||
// Global bytecode cache
|
||||
var (
|
||||
bytecodeCache = struct {
|
||||
sync.RWMutex
|
||||
entries map[string]*BytecodeEntry
|
||||
}{
|
||||
entries: make(map[string]*BytecodeEntry),
|
||||
}
|
||||
)
|
||||
|
||||
// CompileAndCache compiles code to bytecode and stores it globally
|
||||
func CompileAndCache(state *luajit.State, code, name string) (*BytecodeEntry, error) {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
cacheKey := fmt.Sprintf("%x", hash)
|
||||
|
||||
// Check cache first
|
||||
bytecodeCache.RLock()
|
||||
if entry, exists := bytecodeCache.entries[cacheKey]; exists {
|
||||
bytecodeCache.RUnlock()
|
||||
return entry, nil
|
||||
}
|
||||
bytecodeCache.RUnlock()
|
||||
|
||||
// Compile bytecode
|
||||
bytecode, err := state.CompileBytecode(code, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
entry := &BytecodeEntry{
|
||||
Bytecode: bytecode,
|
||||
Name: name,
|
||||
Hash: hash,
|
||||
}
|
||||
|
||||
bytecodeCache.Lock()
|
||||
bytecodeCache.entries[cacheKey] = entry
|
||||
bytecodeCache.Unlock()
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// GetCached retrieves bytecode from cache by code hash
|
||||
func GetCached(code string) (*BytecodeEntry, bool) {
|
||||
hash := sha256.Sum256([]byte(code))
|
||||
cacheKey := fmt.Sprintf("%x", hash)
|
||||
|
||||
bytecodeCache.RLock()
|
||||
defer bytecodeCache.RUnlock()
|
||||
|
||||
entry, exists := bytecodeCache.entries[cacheKey]
|
||||
return entry, exists
|
||||
}
|
||||
|
||||
// ClearCache removes all cached bytecode entries
|
||||
func ClearCache() {
|
||||
bytecodeCache.Lock()
|
||||
defer bytecodeCache.Unlock()
|
||||
|
||||
bytecodeCache.entries = make(map[string]*BytecodeEntry)
|
||||
}
|
||||
|
||||
// CacheStats returns cache statistics
|
||||
func CacheStats() (int, int64) {
|
||||
bytecodeCache.RLock()
|
||||
defer bytecodeCache.RUnlock()
|
||||
|
||||
count := len(bytecodeCache.entries)
|
||||
var totalSize int64
|
||||
for _, entry := range bytecodeCache.entries {
|
||||
totalSize += int64(len(entry.Bytecode))
|
||||
}
|
||||
|
||||
return count, totalSize
|
||||
}
|
||||
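For reference, a same-package sketch of how this (now-removed) cache was driven; the chunk and its name are illustrative:

```go
// Sketch: compile once, hit the cache on identical source, then run the bytecode.
func exampleBytecodeCache(L *luajit.State) error {
	code := `local answer = 40 + 2`

	entry, err := CompileAndCache(L, code, "answer")
	if err != nil {
		return err
	}

	// Identical source hashes to the same SHA-256 key, so this is a cache hit.
	if cached, ok := GetCached(code); ok && cached == entry {
		count, size := CacheStats()
		fmt.Printf("bytecode cache: %d entries, %d bytes\n", count, size)
	}

	return L.LoadAndRunBytecode(entry.Bytecode, entry.Name)
}
```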
327
state/state.go
@ -1,327 +0,0 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"Moonshark/modules"
|
||||
"Moonshark/modules/http"
|
||||
"Moonshark/modules/kv"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// State wraps luajit.State with enhanced functionality
|
||||
type State struct {
|
||||
*luajit.State
|
||||
initialized bool
|
||||
isWorker bool
|
||||
scriptDir string
|
||||
}
|
||||
|
||||
// Config holds state initialization options
|
||||
type Config struct {
|
||||
OpenLibs bool
|
||||
InstallModules bool
|
||||
ScriptDir string
|
||||
IsWorker bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns default configuration
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
OpenLibs: true,
|
||||
InstallModules: true,
|
||||
IsWorker: false,
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new enhanced Lua state
|
||||
func New(config ...Config) (*State, error) {
|
||||
cfg := DefaultConfig()
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
|
||||
// Create base state
|
||||
baseState := luajit.New(cfg.OpenLibs)
|
||||
if baseState == nil {
|
||||
return nil, fmt.Errorf("failed to create Lua state")
|
||||
}
|
||||
|
||||
state := &State{
|
||||
State: baseState,
|
||||
isWorker: cfg.IsWorker,
|
||||
scriptDir: cfg.ScriptDir,
|
||||
}
|
||||
|
||||
// Set worker global flag if this is a worker state
|
||||
if cfg.IsWorker {
|
||||
state.PushBoolean(true)
|
||||
state.SetGlobal("__IS_WORKER")
|
||||
}
|
||||
|
||||
// Install module system if requested
|
||||
if cfg.InstallModules {
|
||||
if err := state.initializeModules(); err != nil {
|
||||
state.Close()
|
||||
return nil, fmt.Errorf("failed to install modules: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set script directory if provided
|
||||
if cfg.ScriptDir != "" {
|
||||
if err := state.SetScriptDirectory(cfg.ScriptDir); err != nil {
|
||||
state.Close()
|
||||
return nil, fmt.Errorf("failed to set script directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
state.initialized = true
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// NewFromScript creates a state configured for a specific script file
|
||||
func NewFromScript(scriptPath string, config ...Config) (*State, error) {
|
||||
cfg := DefaultConfig()
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
|
||||
// Set script directory from file path
|
||||
scriptDir := filepath.Dir(scriptPath)
|
||||
absScriptDir, err := filepath.Abs(scriptDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
cfg.ScriptDir = absScriptDir
|
||||
|
||||
return New(cfg)
|
||||
}
|
||||
|
||||
// NewWorker creates a new worker state with __IS_WORKER flag set
|
||||
func NewWorker(config ...Config) (*State, error) {
|
||||
cfg := DefaultConfig()
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
cfg.IsWorker = true
|
||||
|
||||
return New(cfg)
|
||||
}
|
||||
|
||||
// NewWorkerFromScript creates a worker state configured for a specific script
|
||||
func NewWorkerFromScript(scriptPath string, config ...Config) (*State, error) {
|
||||
cfg := DefaultConfig()
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
cfg.IsWorker = true
|
||||
|
||||
scriptDir := filepath.Dir(scriptPath)
|
||||
absScriptDir, err := filepath.Abs(scriptDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
cfg.ScriptDir = absScriptDir
|
||||
|
||||
return New(cfg)
|
||||
}
|
||||
|
||||
// Store main state initialization for worker replication
|
||||
var (
|
||||
mainStateScriptDir string
|
||||
mainStateScript string
|
||||
mainStateScriptName string
|
||||
)
|
||||
|
||||
// initializeModules sets up the module system
|
||||
func (s *State) initializeModules() error {
|
||||
// Initialize global registry if needed
|
||||
if modules.Global == nil {
|
||||
if err := modules.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Install modules first
|
||||
if err := modules.Global.InstallInState(s.State); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupStateCreator sets up the state creator after script is loaded
|
||||
func (s *State) SetupStateCreator() {
|
||||
if s.isWorker {
|
||||
return
|
||||
}
|
||||
|
||||
http.SetStateCreator(func() (*luajit.State, error) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.IsWorker = true
|
||||
cfg.ScriptDir = mainStateScriptDir
|
||||
|
||||
workerState, err := New(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the same script as main state to get identical environment
|
||||
if mainStateScript != "" {
|
||||
if err := workerState.ExecuteString(mainStateScript, mainStateScriptName); err != nil {
|
||||
workerState.Close()
|
||||
return nil, fmt.Errorf("failed to execute script in worker: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return workerState.State, nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetScriptDirectory adds a directory to Lua's package.path
|
||||
func (s *State) SetScriptDirectory(dir string) error {
|
||||
packagePath := filepath.Join(dir, "?.lua")
|
||||
return s.AddPackagePath(packagePath)
|
||||
}
|
||||
|
||||
// ExecuteFile compiles and runs a Lua script file
|
||||
func (s *State) ExecuteFile(scriptPath string) error {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("script file '%s' not found", scriptPath)
|
||||
}
|
||||
|
||||
// Read script content
|
||||
scriptContent, err := os.ReadFile(scriptPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read script file '%s': %w", scriptPath, err)
|
||||
}
|
||||
|
||||
// Store for worker replication if this is main state
|
||||
if !s.isWorker {
|
||||
mainStateScript = string(scriptContent)
|
||||
mainStateScriptName = scriptPath
|
||||
mainStateScriptDir = s.scriptDir
|
||||
// Set up state creator now that we have the script
|
||||
s.SetupStateCreator()
|
||||
}
|
||||
|
||||
return s.ExecuteString(string(scriptContent), scriptPath)
|
||||
}
|
||||
|
||||
// ExecuteString compiles and runs Lua code with bytecode caching
|
||||
func (s *State) ExecuteString(code, name string) error {
|
||||
entry, err := CompileAndCache(s.State, code, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compilation error in '%s': %w", name, err)
|
||||
}
|
||||
|
||||
if err := s.LoadAndRunBytecode(entry.Bytecode, name); err != nil {
|
||||
return fmt.Errorf("execution error in '%s': %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteFileWithResults executes a script and returns results
|
||||
func (s *State) ExecuteFileWithResults(scriptPath string) ([]any, error) {
|
||||
if _, err := os.Stat(scriptPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("script file '%s' not found", scriptPath)
|
||||
}
|
||||
|
||||
scriptContent, err := os.ReadFile(scriptPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read script file '%s': %w", scriptPath, err)
|
||||
}
|
||||
|
||||
return s.ExecuteStringWithResults(string(scriptContent), scriptPath)
|
||||
}
|
||||
|
||||
// ExecuteStringWithResults executes code and returns all results with bytecode caching
|
||||
func (s *State) ExecuteStringWithResults(code, name string) ([]any, error) {
|
||||
baseTop := s.GetTop()
|
||||
defer s.SetTop(baseTop)
|
||||
|
||||
entry, err := CompileAndCache(s.State, code, name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compilation error in '%s': %w", name, err)
|
||||
}
|
||||
|
||||
if err := s.LoadAndRunBytecodeWithResults(entry.Bytecode, name, -1); err != nil {
|
||||
return nil, fmt.Errorf("execution error in '%s': %w", name, err)
|
||||
}
|
||||
|
||||
nresults := s.GetTop() - baseTop
|
||||
if nresults == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results := make([]any, nresults)
|
||||
for i := range nresults {
|
||||
val, err := s.ToValue(baseTop + i + 1)
|
||||
if err != nil {
|
||||
results[i] = nil
|
||||
} else {
|
||||
results[i] = val
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// IsInitialized returns whether the state has been properly initialized
|
||||
func (s *State) IsInitialized() bool {
|
||||
return s.initialized
|
||||
}
|
||||
|
||||
// IsWorker returns whether this state is a worker state
|
||||
func (s *State) IsWorker() bool {
|
||||
return s.isWorker
|
||||
}
|
||||
|
||||
// Close cleans up the state and releases resources
|
||||
func (s *State) Close() {
|
||||
if s.State != nil {
|
||||
if !s.isWorker {
|
||||
kv.CloseAllStores()
|
||||
}
|
||||
|
||||
s.Cleanup()
|
||||
s.State.Close()
|
||||
s.State = nil
|
||||
}
|
||||
s.initialized = false
|
||||
}
|
||||
|
||||
// Quick creates a minimal state for one-off script execution
|
||||
func Quick() (*State, error) {
|
||||
return New(Config{
|
||||
OpenLibs: true,
|
||||
InstallModules: true,
|
||||
})
|
||||
}
|
||||
|
||||
// QuickExecute creates a state, executes code, and cleans up
|
||||
func QuickExecute(code, name string) error {
|
||||
state, err := Quick()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
return state.ExecuteString(code, name)
|
||||
}
|
||||
|
||||
// QuickExecuteWithResults executes code and returns results
|
||||
func QuickExecuteWithResults(code, name string) ([]any, error) {
|
||||
state, err := Quick()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
return state.ExecuteStringWithResults(code, name)
|
||||
}
|
||||
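A same-package sketch of the removed one-shot helpers; the chunks are illustrative:

```go
// Sketch: run throwaway chunks without managing state lifetime by hand.
func exampleQuickExecute() {
	if err := QuickExecute(`print("hello from Lua")`, "hello"); err != nil {
		fmt.Println("exec failed:", err)
		return
	}

	results, err := QuickExecuteWithResults(`return 1 + 2, "three"`, "calc")
	if err != nil {
		fmt.Println("exec failed:", err)
		return
	}
	fmt.Println(results) // e.g. [3 three]
}
```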
507
tests/crypto.lua
@ -1,507 +0,0 @@
|
||||
require("tests")
|
||||
local crypto = require("crypto")
|
||||
|
||||
-- Test data
|
||||
local test_data = "Hello, World!"
|
||||
local test_key = "secret-key-123"
|
||||
|
||||
-- ======================================================================
|
||||
-- ENCODING/DECODING TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Base64 Encoding/Decoding", function()
|
||||
local encoded = crypto.base64_encode(test_data)
|
||||
assert_equal(type(encoded), "string")
|
||||
assert(#encoded > 0, "encoded string should not be empty")
|
||||
|
||||
local decoded = crypto.base64_decode(encoded)
|
||||
assert_equal(decoded, test_data)
|
||||
end)
|
||||
|
||||
test("Base64 URL Encoding/Decoding", function()
|
||||
local encoded = crypto.base64_url_encode(test_data)
|
||||
assert_equal(type(encoded), "string")
|
||||
|
||||
local decoded = crypto.base64_url_decode(encoded)
|
||||
assert_equal(decoded, test_data)
|
||||
end)
|
||||
|
||||
test("Hex Encoding/Decoding", function()
|
||||
local encoded = crypto.hex_encode(test_data)
|
||||
assert_equal(type(encoded), "string")
|
||||
assert(encoded:match("^[0-9a-f]+$"), "hex should only contain hex characters")
|
||||
|
||||
local decoded = crypto.hex_decode(encoded)
|
||||
assert_equal(decoded, test_data)
|
||||
end)
|
||||
|
||||
test("Encoding Chain", function()
|
||||
local chain_encoded = crypto.encode_chain(test_data, {"hex", "base64"})
|
||||
local chain_decoded = crypto.decode_chain(chain_encoded, {"hex", "base64"})
|
||||
assert_equal(chain_decoded, test_data)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- HASHING TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("MD5 Hash", function()
|
||||
local hash = crypto.md5(test_data)
|
||||
assert_equal(type(hash), "string")
|
||||
assert_equal(#hash, 32) -- MD5 is 32 hex characters
|
||||
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
|
||||
|
||||
-- Same input should produce same hash
|
||||
assert_equal(crypto.md5(test_data), hash)
|
||||
end)
|
||||
|
||||
test("SHA1 Hash", function()
|
||||
local hash = crypto.sha1(test_data)
|
||||
assert_equal(type(hash), "string")
|
||||
assert_equal(#hash, 40) -- SHA1 is 40 hex characters
|
||||
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
|
||||
end)
|
||||
|
||||
test("SHA256 Hash", function()
|
||||
local hash = crypto.sha256(test_data)
|
||||
assert_equal(type(hash), "string")
|
||||
assert_equal(#hash, 64) -- SHA256 is 64 hex characters
|
||||
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
|
||||
end)
|
||||
|
||||
test("SHA512 Hash", function()
|
||||
local hash = crypto.sha512(test_data)
|
||||
assert_equal(type(hash), "string")
|
||||
assert_equal(#hash, 128) -- SHA512 is 128 hex characters
|
||||
assert(hash:match("^[0-9a-f]+$"), "hash should be hex")
|
||||
end)
|
||||
|
||||
test("Hash Multiple Inputs", function()
|
||||
local hash1 = crypto.hash_multiple({"hello", "world"})
|
||||
local hash2 = crypto.hash_multiple({"hello", "world"})
|
||||
local hash3 = crypto.hash_multiple({"world", "hello"})
|
||||
|
||||
assert_equal(hash1, hash2)
|
||||
assert(hash1 ~= hash3, "different order should produce different hash")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- HMAC TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("HMAC SHA1", function()
|
||||
local hmac = crypto.hmac_sha1(test_data, test_key)
|
||||
assert_equal(type(hmac), "string")
|
||||
assert_equal(#hmac, 40) -- SHA1 HMAC is 40 hex characters
|
||||
assert(hmac:match("^[0-9a-f]+$"), "hmac should be hex")
|
||||
end)
|
||||
|
||||
test("HMAC SHA256", function()
|
||||
local hmac = crypto.hmac_sha256(test_data, test_key)
|
||||
assert_equal(type(hmac), "string")
|
||||
assert_equal(#hmac, 64) -- SHA256 HMAC is 64 hex characters
|
||||
assert(hmac:match("^[0-9a-f]+$"), "hmac should be hex")
|
||||
end)
|
||||
|
||||
test("MAC Functions", function()
|
||||
local mac = crypto.mac(test_data, test_key)
|
||||
assert_equal(type(mac), "string")
|
||||
assert(#mac > 0, "mac should not be empty")
|
||||
|
||||
local valid = crypto.verify_mac(test_data, test_key, mac)
|
||||
assert_equal(valid, true)
|
||||
|
||||
local invalid = crypto.verify_mac("different data", test_key, mac)
|
||||
assert_equal(invalid, false)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- UUID TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("UUID Generation", function()
|
||||
local uuid1 = crypto.uuid()
|
||||
local uuid2 = crypto.uuid_v4()
|
||||
|
||||
assert_equal(type(uuid1), "string")
|
||||
assert_equal(type(uuid2), "string")
|
||||
assert_equal(#uuid1, 36) -- Standard UUID length
|
||||
assert_equal(#uuid2, 36)
|
||||
assert(uuid1 ~= uuid2, "UUIDs should be unique")
|
||||
|
||||
assert_equal(crypto.is_uuid(uuid1), true)
|
||||
assert_equal(crypto.is_uuid(uuid2), true)
|
||||
assert_equal(crypto.is_uuid("not-a-uuid"), false)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- RANDOM GENERATION TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Random Bytes", function()
|
||||
local bytes1 = crypto.random_bytes(16)
|
||||
local bytes2 = crypto.random_bytes(16)
|
||||
|
||||
assert_equal(type(bytes1), "string")
|
||||
assert_equal(#bytes1, 16)
|
||||
assert(bytes1 ~= bytes2, "random bytes should be different")
|
||||
end)
|
||||
|
||||
test("Random Hex", function()
|
||||
local hex1 = crypto.random_hex(8)
|
||||
local hex2 = crypto.random_hex(8)
|
||||
|
||||
assert_equal(type(hex1), "string")
|
||||
assert_equal(#hex1, 16) -- 8 bytes = 16 hex characters
|
||||
assert(hex1:match("^[0-9a-f]+$"), "should be hex")
|
||||
assert(hex1 ~= hex2, "random hex should be different")
|
||||
end)
|
||||
|
||||
test("Random String", function()
|
||||
local str1 = crypto.random_string(10)
|
||||
local str2 = crypto.random_string(10)
|
||||
|
||||
assert_equal(type(str1), "string")
|
||||
assert_equal(#str1, 10)
|
||||
assert(str1 ~= str2, "random strings should be different")
|
||||
|
||||
local custom = crypto.random_string(5, "abc")
|
||||
assert_equal(#custom, 5)
|
||||
assert(custom:match("^[abc]+$"), "should only contain specified characters")
|
||||
end)
|
||||
|
||||
test("Random Alphanumeric", function()
|
||||
local str = crypto.random_alphanumeric(20)
|
||||
assert_equal(#str, 20)
|
||||
assert(str:match("^[a-zA-Z0-9]+$"), "should be alphanumeric")
|
||||
end)
|
||||
|
||||
test("Random Password", function()
|
||||
local pass1 = crypto.random_password(12)
|
||||
local pass2 = crypto.random_password(12, true) -- with symbols
|
||||
|
||||
assert_equal(#pass1, 12)
|
||||
assert_equal(#pass2, 12)
|
||||
assert(pass1 ~= pass2, "passwords should be different")
|
||||
end)
|
||||
|
||||
test("Token Generation", function()
|
||||
local token1 = crypto.token(32)
|
||||
local token2 = crypto.token(32)
|
||||
|
||||
assert_equal(#token1, 64) -- 32 bytes = 64 hex characters
|
||||
assert(token1:match("^[0-9a-f]+$"), "token should be hex")
|
||||
assert(token1 ~= token2, "tokens should be unique")
|
||||
end)
|
||||
|
||||
test("Nonce Generation", function()
|
||||
local nonce1 = crypto.nonce()
|
||||
local nonce2 = crypto.nonce(32)
|
||||
|
||||
assert_equal(#nonce1, 32) -- default 16 bytes = 32 hex
|
||||
assert_equal(#nonce2, 64) -- 32 bytes = 64 hex
|
||||
assert(nonce1 ~= nonce2, "nonces should be unique")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Secure Compare", function()
|
||||
local str1 = "hello"
|
||||
local str2 = "hello"
|
||||
local str3 = "world"
|
||||
|
||||
assert_equal(crypto.secure_compare(str1, str2), true)
|
||||
assert_equal(crypto.secure_compare(str1, str3), false)
|
||||
end)
|
||||
|
||||
test("Checksum Functions", function()
|
||||
local checksum = crypto.checksum(test_data)
|
||||
assert_equal(type(checksum), "string")
|
||||
|
||||
local valid = crypto.verify_checksum(test_data, checksum)
|
||||
assert_equal(valid, true)
|
||||
|
||||
local invalid = crypto.verify_checksum("different", checksum)
|
||||
assert_equal(invalid, false)
|
||||
end)
|
||||
|
||||
test("XOR Encryption", function()
|
||||
local key = "mykey"
|
||||
local encrypted = crypto.xor_encrypt(test_data, key)
|
||||
local decrypted = crypto.xor_decrypt(encrypted, key)
|
||||
|
||||
assert_equal(decrypted, test_data)
|
||||
assert(encrypted ~= test_data, "encrypted should be different")
|
||||
end)
|
||||
|
||||
test("Hash Chain", function()
|
||||
local chain1 = crypto.hash_chain(test_data, 100)
|
||||
local chain2 = crypto.hash_chain(test_data, 100)
|
||||
local chain3 = crypto.hash_chain(test_data, 101)
|
||||
|
||||
assert_equal(chain1, chain2)
|
||||
assert(chain1 ~= chain3, "different iterations should produce different results")
|
||||
end)
|
||||
|
||||
test("Key Derivation", function()
|
||||
local derived1, salt1 = crypto.derive_key("password", "salt", 1000)
|
||||
local derived2, salt2 = crypto.derive_key("password", salt1, 1000)
|
||||
|
||||
assert_equal(derived1, derived2)
|
||||
assert_equal(salt1, salt2)
|
||||
assert_equal(type(derived1), "string")
|
||||
assert(#derived1 > 0, "derived key should not be empty")
|
||||
end)
|
||||
|
||||
test("Fingerprint", function()
|
||||
local data = {name = "test", value = 42}
|
||||
local fp1 = crypto.fingerprint(data)
|
||||
local fp2 = crypto.fingerprint(data)
|
||||
|
||||
assert_equal(fp1, fp2)
|
||||
assert_equal(type(fp1), "string")
|
||||
assert(#fp1 > 0, "fingerprint should not be empty")
|
||||
end)
|
||||
|
||||
test("Integrity Check", function()
|
||||
local check = crypto.integrity_check(test_data)
|
||||
|
||||
assert_equal(check.data, test_data)
|
||||
assert_equal(type(check.hash), "string")
|
||||
assert_equal(type(check.timestamp), "number")
|
||||
assert_equal(type(check.uuid), "string")
|
||||
|
||||
local valid = crypto.verify_integrity(check)
|
||||
assert_equal(valid, true)
|
||||
|
||||
-- Tamper with data
|
||||
check.data = "tampered"
|
||||
local invalid = crypto.verify_integrity(check)
|
||||
assert_equal(invalid, false)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- ERROR HANDLING TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Error Handling", function()
|
||||
-- Invalid base64
|
||||
local success, err = pcall(crypto.base64_decode, "invalid===base64")
|
||||
assert_equal(success, false)
|
||||
|
||||
-- Invalid hex
|
||||
local success2, err2 = pcall(crypto.hex_decode, "invalid_hex")
|
||||
assert_equal(success2, false)
|
||||
|
||||
-- Invalid UUID validation (returns boolean, doesn't throw)
|
||||
assert_equal(crypto.is_uuid("not-a-uuid"), false)
|
||||
assert_equal(crypto.is_uuid(""), false)
|
||||
assert_equal(crypto.is_uuid("12345"), false)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PASSWORD TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Password Hash and Verification", function()
|
||||
local password = "hubba-ba-loo117!@#"
|
||||
local hash = crypto.hash_password(password)
|
||||
local hash_fast = crypto.hash_password_fast(password)
|
||||
local hash_strong = crypto.hash_password_strong(password)
|
||||
|
||||
assert(crypto.verify_password(password, hash))
|
||||
assert(crypto.verify_password(password, hash_fast))
|
||||
assert(crypto.verify_password(password, hash_strong))
|
||||
|
||||
assert(not crypto.verify_password("failure", hash))
|
||||
assert(not crypto.verify_password("failure", hash_fast))
|
||||
assert(not crypto.verify_password("failure", hash_strong))
|
||||
end)
|
||||
|
||||
test("Algorithm-Specific Password Hashing", function()
|
||||
local password = "test123!@#"
|
||||
|
||||
-- Test each algorithm individually
|
||||
local argon2_hash = crypto.hash_password(password, "argon2id")
|
||||
local bcrypt_hash = crypto.hash_password(password, "bcrypt")
|
||||
local scrypt_hash = crypto.hash_password(password, "scrypt")
|
||||
local pbkdf2_hash = crypto.hash_password(password, "pbkdf2")
|
||||
|
||||
assert(crypto.verify_password(password, argon2_hash))
|
||||
assert(crypto.verify_password(password, bcrypt_hash))
|
||||
assert(crypto.verify_password(password, scrypt_hash))
|
||||
assert(crypto.verify_password(password, pbkdf2_hash))
|
||||
|
||||
-- Verify wrong passwords fail
|
||||
assert(not crypto.verify_password("wrong", argon2_hash))
|
||||
assert(not crypto.verify_password("wrong", bcrypt_hash))
|
||||
assert(not crypto.verify_password("wrong", scrypt_hash))
|
||||
assert(not crypto.verify_password("wrong", pbkdf2_hash))
|
||||
end)
|
||||
|
||||
test("Algorithm Detection", function()
|
||||
local password = "detectme123"
|
||||
|
||||
local argon2_hash = crypto.hash_password(password, "argon2id")
|
||||
local bcrypt_hash = crypto.hash_password(password, "bcrypt")
|
||||
local scrypt_hash = crypto.hash_password(password, "scrypt")
|
||||
local pbkdf2_hash = crypto.hash_password(password, "pbkdf2")
|
||||
|
||||
assert(crypto.detect_algorithm(argon2_hash) == "argon2id")
|
||||
assert(crypto.detect_algorithm(bcrypt_hash) == "bcrypt")
|
||||
assert(crypto.detect_algorithm(scrypt_hash) == "scrypt")
|
||||
assert(crypto.detect_algorithm(pbkdf2_hash) == "pbkdf2")
|
||||
assert(crypto.detect_algorithm("invalid$format") == "unknown")
|
||||
end)
|
||||
|
||||
test("Custom Algorithm Options", function()
|
||||
local password = "custom123"
|
||||
|
||||
-- Test custom argon2id options
|
||||
local custom_argon2 = crypto.hash_password(password, "argon2id", {
|
||||
time = 2,
|
||||
memory = 32768,
|
||||
threads = 2
|
||||
})
|
||||
assert(crypto.verify_password(password, custom_argon2))
|
||||
|
||||
-- Test custom bcrypt cost
|
||||
local custom_bcrypt = crypto.hash_password(password, "bcrypt", {cost = 10})
|
||||
assert(crypto.verify_password(password, custom_bcrypt))
|
||||
|
||||
-- Test custom scrypt parameters
|
||||
local custom_scrypt = crypto.hash_password(password, "scrypt", {
|
||||
N = 16384,
|
||||
r = 4,
|
||||
p = 2
|
||||
})
|
||||
assert(crypto.verify_password(password, custom_scrypt))
|
||||
|
||||
-- Test custom pbkdf2 iterations
|
||||
local custom_pbkdf2 = crypto.hash_password(password, "pbkdf2", {
|
||||
iterations = 50000
|
||||
})
|
||||
assert(crypto.verify_password(password, custom_pbkdf2))
|
||||
end)
|
||||
|
||||
test("Direct Algorithm Functions", function()
|
||||
local password = "direct123"
|
||||
|
||||
-- Test direct algorithm calls
|
||||
local argon2_direct = crypto.argon2_hash(password)
|
||||
local bcrypt_direct = crypto.bcrypt_hash(password)
|
||||
local scrypt_direct = crypto.scrypt_hash(password)
|
||||
local pbkdf2_direct = crypto.pbkdf2_hash(password)
|
||||
|
||||
assert(crypto.argon2_verify(password, argon2_direct))
|
||||
assert(crypto.bcrypt_verify(password, bcrypt_direct))
|
||||
assert(crypto.scrypt_verify(password, scrypt_direct))
|
||||
assert(crypto.pbkdf2_verify(password, pbkdf2_direct))
|
||||
|
||||
-- Test with custom options
|
||||
local argon2_custom = crypto.argon2_hash(password, {time = 1, memory = 16384})
|
||||
local scrypt_custom = crypto.scrypt_hash(password, {N = 8192})
|
||||
|
||||
assert(crypto.argon2_verify(password, argon2_custom))
|
||||
assert(crypto.scrypt_verify(password, scrypt_custom))
|
||||
end)
|
||||
|
||||
test("Security Level Presets", function()
|
||||
local password = "preset123"
|
||||
local algorithms = {"argon2id", "bcrypt", "scrypt", "pbkdf2"}
|
||||
|
||||
for _, algo in ipairs(algorithms) do
|
||||
local fast_hash = crypto.hash_password_fast(password, algo)
|
||||
local strong_hash = crypto.hash_password_strong(password, algo)
|
||||
|
||||
assert(crypto.verify_password(password, fast_hash))
|
||||
assert(crypto.verify_password(password, strong_hash))
|
||||
|
||||
-- Verify algorithm detection
|
||||
assert(crypto.detect_algorithm(fast_hash) == algo)
|
||||
assert(crypto.detect_algorithm(strong_hash) == algo)
|
||||
end
|
||||
end)
|
||||
|
||||
test("Edge Cases and Error Handling", function()
|
||||
-- Test empty password
|
||||
local empty_hash = crypto.hash_password("")
|
||||
assert(crypto.verify_password("", empty_hash))
|
||||
assert(not crypto.verify_password("not-empty", empty_hash))
|
||||
|
||||
-- Test long password
|
||||
local long_password = string.rep("a", 1000)
|
||||
local long_hash = crypto.hash_password(long_password)
|
||||
assert(crypto.verify_password(long_password, long_hash))
|
||||
|
||||
-- Test unicode password
|
||||
local unicode_password = "🔐password123🔑"
|
||||
local unicode_hash = crypto.hash_password(unicode_password)
|
||||
assert(crypto.verify_password(unicode_password, unicode_hash))
|
||||
|
||||
-- Test invalid hash formats
|
||||
assert(not crypto.verify_password("test", "invalid-hash"))
|
||||
assert(not crypto.verify_password("test", "$invalid$format$"))
|
||||
|
||||
-- Test unsupported algorithm error
|
||||
local success, err = pcall(crypto.hash_password, "test", "invalid-algo")
|
||||
assert(not success)
|
||||
assert(string.find(err, "unsupported algorithm"))
|
||||
end)
|
||||
|
||||
test("Cross-Algorithm Verification", function()
|
||||
local password = "cross123"
|
||||
|
||||
-- Create hashes with different algorithms
|
||||
local hashes = {
|
||||
crypto.hash_password(password, "argon2id"),
|
||||
crypto.hash_password(password, "bcrypt"),
|
||||
crypto.hash_password(password, "scrypt"),
|
||||
crypto.hash_password(password, "pbkdf2")
|
||||
}
|
||||
|
||||
-- Each hash should only verify with correct password
|
||||
for _, hash in ipairs(hashes) do
|
||||
assert(crypto.verify_password(password, hash))
|
||||
assert(not crypto.verify_password("wrong", hash))
|
||||
end
|
||||
|
||||
-- Hashes should be different from each other
|
||||
for i = 1, #hashes do
|
||||
for j = i + 1, #hashes do
|
||||
assert(hashes[i] ~= hashes[j])
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PERFORMANCE TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Performance Test", function()
|
||||
local large_data = string.rep("test data for performance ", 1000)
|
||||
|
||||
local start = os.clock()
|
||||
local hash = crypto.sha256(large_data)
|
||||
local hash_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local encoded = crypto.base64_encode(large_data)
|
||||
local encode_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local decoded = crypto.base64_decode(encoded)
|
||||
local decode_time = os.clock() - start
|
||||
|
||||
print(string.format(" SHA256 of %d bytes: %.3fs", #large_data, hash_time))
|
||||
print(string.format(" Base64 encode: %.3fs", encode_time))
|
||||
print(string.format(" Base64 decode: %.3fs", decode_time))
|
||||
|
||||
assert_equal(decoded, large_data)
|
||||
assert_equal(type(hash), "string")
|
||||
end)
|
||||
|
||||
summary()
|
||||
test_exit()
|
||||
456
tests/fs.lua
456
tests/fs.lua
@ -1,456 +0,0 @@
|
||||
require("tests")
|
||||
local fs = require("fs")
|
||||
|
||||
-- Test data
|
||||
local test_content = "Hello, filesystem!\nThis is a test file.\n"
|
||||
local test_dir = "test_fs_dir"
|
||||
local test_file = fs.join(test_dir, "test.txt")
|
||||
|
||||
-- Clean up function
|
||||
local function cleanup()
|
||||
if fs.exists(test_file) then fs.remove(test_file) end
|
||||
if fs.exists(test_dir) then fs.rmdir(test_dir) end
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- SETUP AND CLEANUP
|
||||
-- ======================================================================
|
||||
|
||||
-- Clean up before tests
|
||||
cleanup()
|
||||
|
||||
-- ======================================================================
|
||||
-- BASIC FILE OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("File Write and Read", function()
|
||||
fs.mkdir(test_dir)
|
||||
|
||||
fs.write(test_file, test_content)
|
||||
assert_equal(fs.exists(test_file), true)
|
||||
assert_equal(fs.is_file(test_file), true)
|
||||
assert_equal(fs.is_dir(test_file), false)
|
||||
|
||||
local content = fs.read(test_file)
|
||||
assert_equal(content, test_content)
|
||||
end)
|
||||
|
||||
test("File Size", function()
|
||||
local size = fs.size(test_file)
|
||||
assert_equal(size, #test_content)
|
||||
assert_equal(fs.size("nonexistent.txt"), nil)
|
||||
end)
|
||||
|
||||
test("File Append", function()
|
||||
local additional = "Appended content.\n"
|
||||
fs.append(test_file, additional)
|
||||
|
||||
local content = fs.read(test_file)
|
||||
assert_equal(content, test_content .. additional)
|
||||
end)
|
||||
|
||||
test("File Copy", function()
|
||||
local copy_file = fs.join(test_dir, "copy.txt")
|
||||
fs.copy(test_file, copy_file)
|
||||
|
||||
assert_equal(fs.exists(copy_file), true)
|
||||
assert_equal(fs.read(copy_file), fs.read(test_file))
|
||||
|
||||
fs.remove(copy_file)
|
||||
end)
|
||||
|
||||
test("File Move", function()
|
||||
local move_file = fs.join(test_dir, "moved.txt")
|
||||
local original_content = fs.read(test_file)
|
||||
|
||||
fs.move(test_file, move_file)
|
||||
|
||||
assert_equal(fs.exists(test_file), false)
|
||||
assert_equal(fs.exists(move_file), true)
|
||||
assert_equal(fs.read(move_file), original_content)
|
||||
|
||||
-- Move back for other tests
|
||||
fs.move(move_file, test_file)
|
||||
end)
|
||||
|
||||
test("File Lines", function()
|
||||
local lines = fs.lines(test_file)
|
||||
assert_equal(type(lines), "table")
|
||||
assert(#lines >= 2, "should have multiple lines")
|
||||
assert(string.find(lines[1], "Hello"), "first line should contain 'Hello'")
|
||||
end)
|
||||
|
||||
test("File Touch", function()
|
||||
local touch_file = fs.join(test_dir, "touched.txt")
|
||||
|
||||
fs.touch(touch_file)
|
||||
assert_equal(fs.exists(touch_file), true)
|
||||
assert_equal(fs.size(touch_file), 0)
|
||||
|
||||
fs.remove(touch_file)
|
||||
end)
|
||||
|
||||
test("File Modification Time", function()
|
||||
local mtime = fs.mtime(test_file)
|
||||
assert_equal(type(mtime), "number")
|
||||
assert(mtime > 0, "mtime should be positive")
|
||||
|
||||
-- Touch should update mtime
|
||||
local old_mtime = mtime
|
||||
fs.touch(test_file)
|
||||
local new_mtime = fs.mtime(test_file)
|
||||
assert(new_mtime >= old_mtime, "mtime should be updated")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- DIRECTORY OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("Directory Creation and Removal", function()
|
||||
local nested_dir = fs.join(test_dir, "nested", "deep")
|
||||
|
||||
fs.mkdir(nested_dir)
|
||||
assert_equal(fs.exists(nested_dir), true)
|
||||
assert_equal(fs.is_dir(nested_dir), true)
|
||||
|
||||
-- Clean up nested directories
|
||||
fs.rmdir(fs.join(test_dir, "nested"))
|
||||
end)
|
||||
|
||||
test("Directory Listing", function()
|
||||
-- Create some test files
|
||||
fs.write(fs.join(test_dir, "file1.txt"), "content1")
|
||||
fs.write(fs.join(test_dir, "file2.log"), "content2")
|
||||
fs.mkdir(fs.join(test_dir, "subdir"))
|
||||
|
||||
local entries = fs.list(test_dir)
|
||||
assert_equal(type(entries), "table")
|
||||
assert(#entries >= 3, "should have at least 3 entries")
|
||||
|
||||
-- Check entry structure
|
||||
local found_file = false
|
||||
for _, entry in ipairs(entries) do
|
||||
assert_equal(type(entry.name), "string")
|
||||
assert_equal(type(entry.is_dir), "boolean")
|
||||
if entry.name == "file1.txt" then
|
||||
found_file = true
|
||||
assert_equal(entry.is_dir, false)
|
||||
assert_equal(type(entry.size), "number")
|
||||
end
|
||||
end
|
||||
assert_equal(found_file, true)
|
||||
|
||||
-- Test filtered listings
|
||||
local files = fs.list_files(test_dir)
|
||||
local dirs = fs.list_dirs(test_dir)
|
||||
local names = fs.list_names(test_dir)
|
||||
|
||||
assert_equal(type(files), "table")
|
||||
assert_equal(type(dirs), "table")
|
||||
assert_equal(type(names), "table")
|
||||
assert(#files > 0, "should have files")
|
||||
assert(#dirs > 0, "should have directories")
|
||||
|
||||
-- Clean up
|
||||
fs.remove(fs.join(test_dir, "file1.txt"))
|
||||
fs.remove(fs.join(test_dir, "file2.log"))
|
||||
fs.rmdir(fs.join(test_dir, "subdir"))
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PATH OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("Path Join", function()
|
||||
local path = fs.join("a", "b", "c", "file.txt")
|
||||
assert(string.find(path, "file.txt"), "should contain filename")
|
||||
|
||||
local empty_path = fs.join()
|
||||
assert_equal(type(empty_path), "string")
|
||||
end)
|
||||
|
||||
test("Path Components", function()
|
||||
local test_path = fs.join("home", "user", "documents", "file.txt")
|
||||
|
||||
local dir = fs.dirname(test_path)
|
||||
local base = fs.basename(test_path)
|
||||
local ext = fs.ext(test_path)
|
||||
|
||||
assert_equal(base, "file.txt")
|
||||
assert_equal(ext, ".txt")
|
||||
assert(string.find(dir, "documents"), "dirname should contain 'documents'")
|
||||
end)
|
||||
|
||||
test("Path Split Extension", function()
|
||||
local test_path = fs.join("home", "user", "file.tar.gz")
|
||||
local dir, name, ext = fs.splitext(test_path)
|
||||
|
||||
assert_equal(name, "file.tar")
|
||||
assert_equal(ext, ".gz")
|
||||
assert(string.find(dir, "user"), "dir should contain 'user'")
|
||||
end)
|
||||
|
||||
test("Path Absolute", function()
|
||||
local abs_path = fs.abs(".")
|
||||
assert_equal(type(abs_path), "string")
|
||||
assert(#abs_path > 1, "absolute path should not be empty")
|
||||
end)
|
||||
|
||||
test("Path Clean", function()
|
||||
local messy_path = "./test/../test/./file.txt"
|
||||
local clean_path = fs.clean(messy_path)
|
||||
assert_equal(type(clean_path), "string")
|
||||
assert(not string.find(clean_path, "%.%."), "should not contain '..'")
|
||||
end)
|
||||
|
||||
test("Path Split", function()
|
||||
local test_path = fs.join("home", "user", "file.txt")
|
||||
local dir, file = fs.split(test_path)
|
||||
|
||||
assert_equal(file, "file.txt")
|
||||
assert(string.find(dir, "user"), "dir should contain 'user'")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- WORKING DIRECTORY
|
||||
-- ======================================================================
|
||||
|
||||
test("Working Directory", function()
|
||||
local original_cwd = fs.getcwd()
|
||||
assert_equal(type(original_cwd), "string")
|
||||
assert(#original_cwd > 0, "cwd should not be empty")
|
||||
|
||||
-- Test directory change
|
||||
fs.chdir(test_dir)
|
||||
local new_cwd = fs.getcwd()
|
||||
assert(string.find(new_cwd, test_dir), "cwd should contain test_dir")
|
||||
|
||||
-- Change back
|
||||
fs.chdir(original_cwd)
|
||||
assert_equal(fs.getcwd(), original_cwd)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- TEMPORARY FILES
|
||||
-- ======================================================================
|
||||
|
||||
test("Temporary Files", function()
|
||||
local temp_file = fs.tempfile("test_")
|
||||
local temp_dir = fs.tempdir("test_")
|
||||
|
||||
assert_equal(type(temp_file), "string")
|
||||
assert_equal(type(temp_dir), "string")
|
||||
assert_equal(fs.exists(temp_file), true)
|
||||
assert_equal(fs.exists(temp_dir), true)
|
||||
assert_equal(fs.is_dir(temp_dir), true)
|
||||
|
||||
-- Clean up
|
||||
fs.remove(temp_file)
|
||||
fs.rmdir(temp_dir)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PATTERN MATCHING
|
||||
-- ======================================================================
|
||||
|
||||
test("Glob Patterns", function()
|
||||
-- Create test files for globbing
|
||||
fs.write(fs.join(test_dir, "test1.txt"), "content")
|
||||
fs.write(fs.join(test_dir, "test2.txt"), "content")
|
||||
fs.write(fs.join(test_dir, "other.log"), "content")
|
||||
|
||||
local pattern = fs.join(test_dir, "*.txt")
|
||||
local matches = fs.glob(pattern)
|
||||
|
||||
assert_equal(type(matches), "table")
|
||||
assert(#matches >= 2, "should match txt files")
|
||||
|
||||
-- Clean up
|
||||
fs.remove(fs.join(test_dir, "test1.txt"))
|
||||
fs.remove(fs.join(test_dir, "test2.txt"))
|
||||
fs.remove(fs.join(test_dir, "other.log"))
|
||||
end)
|
||||
|
||||
test("Walk Directory", function()
|
||||
-- Create nested structure
|
||||
fs.mkdir(fs.join(test_dir, "sub1"))
|
||||
fs.mkdir(fs.join(test_dir, "sub2"))
|
||||
fs.write(fs.join(test_dir, "root.txt"), "content")
|
||||
fs.write(fs.join(test_dir, "sub1", "nested.txt"), "content")
|
||||
|
||||
local files = fs.walk(test_dir)
|
||||
assert_equal(type(files), "table")
|
||||
assert(#files > 3, "should find multiple files and directories")
|
||||
|
||||
-- Clean up
|
||||
fs.remove(fs.join(test_dir, "root.txt"))
|
||||
fs.remove(fs.join(test_dir, "sub1", "nested.txt"))
|
||||
fs.rmdir(fs.join(test_dir, "sub1"))
|
||||
fs.rmdir(fs.join(test_dir, "sub2"))
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("File Extension Functions", function()
|
||||
local path = "document.pdf"
|
||||
assert_equal(fs.extension(path), "pdf")
|
||||
|
||||
local new_path = fs.change_ext(path, "txt")
|
||||
assert_equal(new_path, "document.txt")
|
||||
|
||||
local new_path2 = fs.change_ext(path, ".docx")
|
||||
assert_equal(new_path2, "document.docx")
|
||||
end)
|
||||
|
||||
test("Ensure Directory", function()
|
||||
local ensure_dir = fs.join(test_dir, "ensure_test")
|
||||
|
||||
fs.ensure_dir(ensure_dir)
|
||||
assert_equal(fs.exists(ensure_dir), true)
|
||||
assert_equal(fs.is_dir(ensure_dir), true)
|
||||
|
||||
-- Should not error if already exists
|
||||
fs.ensure_dir(ensure_dir)
|
||||
|
||||
fs.rmdir(ensure_dir)
|
||||
end)
|
||||
|
||||
test("Human Readable Size", function()
|
||||
local small_file = fs.join(test_dir, "small.txt")
|
||||
fs.write(small_file, "test")
|
||||
|
||||
local size_str = fs.size_human(small_file)
|
||||
assert_equal(type(size_str), "string")
|
||||
assert(string.find(size_str, "B"), "should contain byte unit")
|
||||
|
||||
fs.remove(small_file)
|
||||
end)
|
||||
|
||||
test("Safe Path Check", function()
|
||||
assert_equal(fs.is_safe_path("safe/path.txt"), true)
|
||||
assert_equal(fs.is_safe_path("../dangerous"), false)
|
||||
assert_equal(fs.is_safe_path("/absolute/path"), false)
|
||||
assert_equal(fs.is_safe_path("~/home/path"), false)
|
||||
end)
|
||||
|
||||
test("Copy Tree", function()
|
||||
-- Create source structure
|
||||
local src_dir = fs.join(test_dir, "src")
|
||||
local dst_dir = fs.join(test_dir, "dst")
|
||||
|
||||
fs.mkdir(src_dir)
|
||||
fs.mkdir(fs.join(src_dir, "subdir"))
|
||||
fs.write(fs.join(src_dir, "file1.txt"), "content1")
|
||||
fs.write(fs.join(src_dir, "subdir", "file2.txt"), "content2")
|
||||
|
||||
fs.copytree(src_dir, dst_dir)
|
||||
|
||||
assert_equal(fs.exists(dst_dir), true)
|
||||
assert_equal(fs.exists(fs.join(dst_dir, "file1.txt")), true)
|
||||
assert_equal(fs.exists(fs.join(dst_dir, "subdir", "file2.txt")), true)
|
||||
assert_equal(fs.read(fs.join(dst_dir, "file1.txt")), "content1")
|
||||
|
||||
-- Clean up
|
||||
fs.rmdir(src_dir)
|
||||
fs.rmdir(dst_dir)
|
||||
end)
|
||||
|
||||
test("Find Files", function()
|
||||
-- Create test files
|
||||
fs.write(fs.join(test_dir, "find1.txt"), "content")
|
||||
fs.write(fs.join(test_dir, "find2.txt"), "content")
|
||||
fs.write(fs.join(test_dir, "other.log"), "content")
|
||||
fs.mkdir(fs.join(test_dir, "subdir"))
|
||||
fs.write(fs.join(test_dir, "subdir", "find3.txt"), "content")
|
||||
|
||||
local txt_files = fs.find(test_dir, "%.txt$", true)
|
||||
assert_equal(type(txt_files), "table")
|
||||
assert(#txt_files >= 3, "should find txt files recursively")
|
||||
|
||||
local txt_files_flat = fs.find(test_dir, "%.txt$", false)
|
||||
assert(#txt_files_flat < #txt_files, "non-recursive should find fewer files")
|
||||
|
||||
-- Clean up
|
||||
fs.remove(fs.join(test_dir, "find1.txt"))
|
||||
fs.remove(fs.join(test_dir, "find2.txt"))
|
||||
fs.remove(fs.join(test_dir, "other.log"))
|
||||
fs.remove(fs.join(test_dir, "subdir", "find3.txt"))
|
||||
fs.rmdir(fs.join(test_dir, "subdir"))
|
||||
end)
|
||||
|
||||
test("Directory Tree", function()
|
||||
-- Create test structure
|
||||
fs.mkdir(fs.join(test_dir, "tree_test"))
|
||||
fs.write(fs.join(test_dir, "tree_test", "file.txt"), "content")
|
||||
fs.mkdir(fs.join(test_dir, "tree_test", "subdir"))
|
||||
|
||||
local tree = fs.tree(test_dir)
|
||||
assert_equal(type(tree), "table")
|
||||
assert_equal(tree.is_dir, true)
|
||||
assert_equal(type(tree.children), "table")
|
||||
assert(#tree.children > 0, "should have children")
|
||||
|
||||
-- Clean up
|
||||
fs.remove(fs.join(test_dir, "tree_test", "file.txt"))
|
||||
fs.rmdir(fs.join(test_dir, "tree_test", "subdir"))
|
||||
fs.rmdir(fs.join(test_dir, "tree_test"))
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- ERROR HANDLING
|
||||
-- ======================================================================
|
||||
|
||||
test("Error Handling", function()
|
||||
-- Reading non-existent file
|
||||
local success, err = pcall(fs.read, "nonexistent.txt")
|
||||
assert_equal(false, success)
|
||||
|
||||
-- Writing to invalid path
|
||||
local success2, err2 = pcall(fs.write, "/invalid/path/file.txt", "content")
|
||||
assert_equal(false, success2)
|
||||
|
||||
-- Listing non-existent directory
|
||||
local success3, err3 = pcall(fs.list, "nonexistent_dir")
|
||||
assert_equal(false, success3)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PERFORMANCE TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Performance Test", function()
|
||||
local large_content = string.rep("performance test data\n", 1000)
|
||||
local perf_file = fs.join(test_dir, "performance.txt")
|
||||
|
||||
local start = os.clock()
|
||||
fs.write(perf_file, large_content)
|
||||
local write_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local read_content = fs.read(perf_file)
|
||||
local read_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local lines = fs.lines(perf_file)
|
||||
local lines_time = os.clock() - start
|
||||
|
||||
print(string.format(" Write %d bytes: %.3fs", #large_content, write_time))
|
||||
print(string.format(" Read %d bytes: %.3fs", #read_content, read_time))
|
||||
print(string.format(" Parse %d lines: %.3fs", #lines, lines_time))
|
||||
|
||||
assert_equal(read_content, large_content)
|
||||
assert_equal(#lines, 1000)
|
||||
|
||||
fs.remove(perf_file)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- CLEANUP
|
||||
-- ======================================================================
|
||||
|
||||
cleanup()
|
||||
|
||||
summary()
|
||||
test_exit()
|
||||
188
tests/json.lua
188
tests/json.lua
@ -1,188 +0,0 @@
|
||||
require("tests")
|
||||
--local json = require("json")
|
||||
|
||||
-- Test data
|
||||
local test_data = {
|
||||
name = "John Doe",
|
||||
age = 30,
|
||||
active = true,
|
||||
scores = {85, 92, 78, 90},
|
||||
address = {
|
||||
street = "123 Main St",
|
||||
city = "Springfield",
|
||||
zip = "12345"
|
||||
},
|
||||
tags = {"developer", "golang", "lua"}
|
||||
}
|
||||
|
||||
-- Test 1: Basic encoding
|
||||
test("Basic JSON Encoding", function()
|
||||
local encoded = json.encode(test_data)
|
||||
assert_equal(type(encoded), "string")
|
||||
assert(string.find(encoded, "John Doe"), "should contain name")
|
||||
assert(string.find(encoded, "30"), "should contain age")
|
||||
end)
|
||||
|
||||
-- Test 2: Basic decoding
|
||||
test("Basic JSON Decoding", function()
|
||||
local encoded = json.encode(test_data)
|
||||
local decoded = json.decode(encoded)
|
||||
assert_equal(decoded.name, "John Doe")
|
||||
assert_equal(decoded.age, 30)
|
||||
assert_equal(decoded.active, true)
|
||||
assert_equal(#decoded.scores, 4)
|
||||
end)
|
||||
|
||||
-- Test 3: Round-trip encoding/decoding
|
||||
test("Round-trip Encoding/Decoding", function()
|
||||
local encoded = json.encode(test_data)
|
||||
local decoded = json.decode(encoded)
|
||||
local re_encoded = json.encode(decoded)
|
||||
local re_decoded = json.decode(re_encoded)
|
||||
|
||||
assert_equal(re_decoded.name, test_data.name)
|
||||
assert_equal(re_decoded.address.city, test_data.address.city)
|
||||
end)
|
||||
|
||||
-- Test 4: Pretty printing
|
||||
test("Pretty Printing", function()
|
||||
local pretty = json.pretty(test_data)
|
||||
assert_equal(type(pretty), "string")
|
||||
assert(string.find(pretty, "\n"), "pretty should contain newlines")
|
||||
assert(string.find(pretty, " "), "pretty should contain indentation")
|
||||
|
||||
-- Should still be valid JSON
|
||||
local decoded = json.decode(pretty)
|
||||
assert_equal(decoded.name, test_data.name)
|
||||
end)
|
||||
|
||||
-- Test 5: Object merging
|
||||
test("Object Merging", function()
|
||||
local obj1 = {a = 1, b = 2}
|
||||
local obj2 = {b = 3, c = 4}
|
||||
local obj3 = {d = 5}
|
||||
|
||||
local merged = json.merge(obj1, obj2, obj3)
|
||||
assert_equal(merged.a, 1)
|
||||
assert_equal(merged.b, 3) -- later wins
|
||||
assert_equal(merged.c, 4)
|
||||
assert_equal(merged.d, 5)
|
||||
end)
|
||||
|
||||
-- Test 6: Data extraction
|
||||
test("Data Extraction", function()
|
||||
local name = json.extract(test_data, "name")
|
||||
assert_equal(name, "John Doe")
|
||||
|
||||
local city = json.extract(test_data, "address.city")
|
||||
assert_equal(city, "Springfield")
|
||||
|
||||
local first_score = json.extract(test_data, "scores.[0]")
|
||||
assert_equal(first_score, 85)
|
||||
|
||||
local missing = json.extract(test_data, "nonexistent.field")
|
||||
assert_equal(missing, nil)
|
||||
end)
|
||||
|
||||
-- Test 7: Schema validation
|
||||
test("Schema Validation", function()
|
||||
local schema = {
|
||||
type = "table",
|
||||
properties = {
|
||||
name = {type = "string"},
|
||||
age = {type = "number"},
|
||||
active = {type = "boolean"}
|
||||
},
|
||||
required = {name = true, age = true}
|
||||
}
|
||||
|
||||
local valid, err = json.validate(test_data, schema)
|
||||
assert_equal(valid, true)
|
||||
|
||||
local invalid_data = {name = "John", age = "not_a_number"}
|
||||
local invalid, err2 = json.validate(invalid_data, schema)
|
||||
assert_equal(invalid, false)
|
||||
assert_equal(type(err2), "string")
|
||||
end)
|
||||
|
||||
-- Test 8: File operations
|
||||
test("File Save/Load", function()
|
||||
local filename = "test_output.json"
|
||||
|
||||
-- Save to file
|
||||
json.save_file(filename, test_data, true) -- pretty format
|
||||
|
||||
-- Check file exists
|
||||
assert(file_exists(filename), "file should exist after save")
|
||||
|
||||
-- Load from file
|
||||
local loaded = json.load_file(filename)
|
||||
assert_equal(loaded.name, test_data.name)
|
||||
assert_equal(loaded.address.zip, test_data.address.zip)
|
||||
|
||||
-- Clean up
|
||||
os.remove(filename)
|
||||
end)
|
||||
|
||||
-- Test 9: Error handling
|
||||
test("Error Handling", function()
|
||||
-- Invalid JSON should throw error
|
||||
local success, err = pcall(json.decode, '{"invalid": json}')
|
||||
assert_equal(success, false)
|
||||
assert_equal(type(err), "string")
|
||||
|
||||
-- Missing file should throw error
|
||||
local success2, err2 = pcall(json.load_file, "nonexistent_file.json")
|
||||
assert_equal(success2, false)
|
||||
assert_equal(type(err2), "string")
|
||||
end)
|
||||
|
||||
-- Test 10: Edge cases
|
||||
test("Edge Cases", function()
|
||||
-- Empty objects
|
||||
local empty_obj = {}
|
||||
local encoded_empty = json.encode(empty_obj)
|
||||
local decoded_empty = json.decode(encoded_empty)
|
||||
assert_equal(type(decoded_empty), "table")
|
||||
|
||||
-- Special numbers
|
||||
local special = {
|
||||
zero = 0,
|
||||
negative = -42,
|
||||
decimal = 3.14159
|
||||
}
|
||||
local encoded_special = json.encode(special)
|
||||
local decoded_special = json.decode(encoded_special)
|
||||
assert_equal(decoded_special.zero, 0)
|
||||
assert_equal(decoded_special.negative, -42)
|
||||
assert_close(decoded_special.decimal, 3.14159, 0.00001)
|
||||
end)
|
||||
|
||||
-- Performance test
|
||||
test("Performance Test", function()
|
||||
local large_data = {}
|
||||
for i = 1, 1000 do
|
||||
large_data[i] = {
|
||||
id = i,
|
||||
name = "User " .. i,
|
||||
data = {x = i * 2, y = i * 3, z = i * 4}
|
||||
}
|
||||
end
|
||||
|
||||
local start = os.clock()
|
||||
local encoded = json.encode(large_data)
|
||||
local encode_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local decoded = json.decode(encoded)
|
||||
local decode_time = os.clock() - start
|
||||
|
||||
print(string.format(" Encoded 1000 objects in %.3f seconds", encode_time))
|
||||
print(string.format(" Decoded 1000 objects in %.3f seconds", decode_time))
|
||||
|
||||
assert_equal(#decoded, 1000)
|
||||
assert_equal(decoded[500].name, "User 500")
|
||||
end)
|
||||
|
||||
summary()
|
||||
test_exit()
|
||||
362
tests/kv.lua
362
tests/kv.lua
@ -1,362 +0,0 @@
|
||||
require("tests")
|
||||
local kv = require("kv")
|
||||
|
||||
-- Clean up any existing test files
|
||||
os.remove("test_store.json")
|
||||
os.remove("test_store.txt")
|
||||
os.remove("test_oop.json")
|
||||
os.remove("test_temp.json")
|
||||
|
||||
-- ======================================================================
|
||||
-- BASIC OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("Store creation and opening", function()
|
||||
assert(kv.open("test", "test_store.json"))
|
||||
assert(kv.open("memory_only", ""))
|
||||
assert(not kv.open("", "test.json")) -- Empty name should fail
|
||||
end)
|
||||
|
||||
test("Set and get operations", function()
|
||||
assert(kv.set("test", "key1", "value1"))
|
||||
assert_equal("value1", kv.get("test", "key1"))
|
||||
assert_equal("default", kv.get("test", "nonexistent", "default"))
|
||||
assert_equal(nil, kv.get("test", "nonexistent"))
|
||||
|
||||
-- Test with special characters
|
||||
assert(kv.set("test", "special:key", "value with spaces & symbols!"))
|
||||
assert_equal("value with spaces & symbols!", kv.get("test", "special:key"))
|
||||
end)
|
||||
|
||||
test("Key existence and deletion", function()
|
||||
kv.set("test", "temp_key", "temp_value")
|
||||
assert(kv.has("test", "temp_key"))
|
||||
assert(not kv.has("test", "missing_key"))
|
||||
|
||||
assert(kv.delete("test", "temp_key"))
|
||||
assert(not kv.has("test", "temp_key"))
|
||||
assert(not kv.delete("test", "missing_key"))
|
||||
end)
|
||||
|
||||
test("Store size tracking", function()
|
||||
kv.clear("test")
|
||||
assert_equal(0, kv.size("test"))
|
||||
|
||||
kv.set("test", "k1", "v1")
|
||||
kv.set("test", "k2", "v2")
|
||||
kv.set("test", "k3", "v3")
|
||||
assert_equal(3, kv.size("test"))
|
||||
|
||||
kv.delete("test", "k2")
|
||||
assert_equal(2, kv.size("test"))
|
||||
end)
|
||||
|
||||
test("Keys and values retrieval", function()
|
||||
kv.clear("test")
|
||||
kv.set("test", "a", "apple")
|
||||
kv.set("test", "b", "banana")
|
||||
kv.set("test", "c", "cherry")
|
||||
|
||||
local keys = kv.keys("test")
|
||||
local values = kv.values("test")
|
||||
|
||||
assert_equal(3, #keys)
|
||||
assert_equal(3, #values)
|
||||
|
||||
-- Check all keys exist
|
||||
local key_set = {}
|
||||
for _, k in ipairs(keys) do
|
||||
key_set[k] = true
|
||||
end
|
||||
assert(key_set["a"] and key_set["b"] and key_set["c"])
|
||||
|
||||
-- Check all values exist
|
||||
local value_set = {}
|
||||
for _, v in ipairs(values) do
|
||||
value_set[v] = true
|
||||
end
|
||||
assert(value_set["apple"] and value_set["banana"] and value_set["cherry"])
|
||||
end)
|
||||
|
||||
test("Clear store", function()
|
||||
kv.set("test", "temp1", "value1")
|
||||
kv.set("test", "temp2", "value2")
|
||||
assert(kv.size("test") > 0)
|
||||
|
||||
assert(kv.clear("test"))
|
||||
assert_equal(0, kv.size("test"))
|
||||
assert(not kv.has("test", "temp1"))
|
||||
end)
|
||||
|
||||
test("Save and close operations", function()
|
||||
-- Ensure store is properly opened with filename
|
||||
kv.close("test") -- Close if already open
|
||||
assert(kv.open("test", "test_store.json"))
|
||||
|
||||
assert(kv.set("test", "persistent", "data"))
|
||||
assert(kv.save("test"))
|
||||
assert(kv.close("test"))
|
||||
|
||||
-- Reopen and verify data persists
|
||||
assert(kv.open("test", "test_store.json"))
|
||||
assert_equal("data", kv.get("test", "persistent"))
|
||||
end)
|
||||
|
||||
test("Invalid store operations", function()
|
||||
assert(not kv.set("nonexistent", "key", "value"))
|
||||
assert_equal(nil, kv.get("nonexistent", "key"))
|
||||
assert(not kv.has("nonexistent", "key"))
|
||||
assert_equal(0, kv.size("nonexistent"))
|
||||
assert(not kv.delete("nonexistent", "key"))
|
||||
assert(not kv.clear("nonexistent"))
|
||||
assert(not kv.save("nonexistent"))
|
||||
assert(not kv.close("nonexistent"))
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- OBJECT-ORIENTED INTERFACE
|
||||
-- ======================================================================
|
||||
|
||||
test("OOP store creation", function()
|
||||
local store = kv.create("oop_test", "test_oop.json")
|
||||
assert_equal("oop_test", store.name)
|
||||
|
||||
local memory_store = kv.create("memory_oop")
|
||||
assert_equal("memory_oop", memory_store.name)
|
||||
end)
|
||||
|
||||
test("OOP basic operations", function()
|
||||
local store = kv.create("oop_basic")
|
||||
|
||||
assert(store:set("foo", "bar"))
|
||||
assert_equal("bar", store:get("foo"))
|
||||
assert_equal("default", store:get("missing", "default"))
|
||||
assert(store:has("foo"))
|
||||
assert(not store:has("missing"))
|
||||
assert_equal(1, store:size())
|
||||
|
||||
assert(store:delete("foo"))
|
||||
assert(not store:has("foo"))
|
||||
assert_equal(0, store:size())
|
||||
end)
|
||||
|
||||
test("OOP collections", function()
|
||||
local store = kv.create("oop_collections")
|
||||
|
||||
store:set("a", "apple")
|
||||
store:set("b", "banana")
|
||||
store:set("c", "cherry")
|
||||
|
||||
local keys = store:keys()
|
||||
local values = store:values()
|
||||
|
||||
assert_equal(3, #keys)
|
||||
assert_equal(3, #values)
|
||||
|
||||
store:clear()
|
||||
assert_equal(0, store:size())
|
||||
|
||||
store:close()
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("Increment operations", function()
|
||||
kv.open("util_test")
|
||||
|
||||
-- Increment non-existent key
|
||||
assert_equal(1, kv.increment("util_test", "counter"))
|
||||
assert_equal("1", kv.get("util_test", "counter"))
|
||||
|
||||
-- Increment existing key
|
||||
assert_equal(6, kv.increment("util_test", "counter", 5))
|
||||
assert_equal("6", kv.get("util_test", "counter"))
|
||||
|
||||
-- Decrement
|
||||
assert_equal(4, kv.increment("util_test", "counter", -2))
|
||||
|
||||
kv.close("util_test")
|
||||
end)
|
||||
|
||||
test("Append operations", function()
|
||||
kv.open("append_test")
|
||||
|
||||
-- Append to non-existent key
|
||||
assert(kv.append("append_test", "list", "first"))
|
||||
assert_equal("first", kv.get("append_test", "list"))
|
||||
|
||||
-- Append with separator
|
||||
assert(kv.append("append_test", "list", "second", ","))
|
||||
assert_equal("first,second", kv.get("append_test", "list"))
|
||||
|
||||
assert(kv.append("append_test", "list", "third", ","))
|
||||
assert_equal("first,second,third", kv.get("append_test", "list"))
|
||||
|
||||
kv.close("append_test")
|
||||
end)
|
||||
|
||||
test("TTL and expiration", function()
|
||||
kv.open("ttl_test", "test_temp.json")
|
||||
|
||||
kv.set("ttl_test", "temp_key", "temp_value")
|
||||
assert(kv.expire("ttl_test", "temp_key", 1)) -- 1 second TTL
|
||||
|
||||
-- Key should still exist immediately
|
||||
assert(kv.has("ttl_test", "temp_key"))
|
||||
|
||||
-- Wait for expiration
|
||||
os.execute("sleep 2")
|
||||
local expired = kv.cleanup_expired("ttl_test")
|
||||
assert(expired >= 1)
|
||||
|
||||
-- Key should be gone
|
||||
assert(not kv.has("ttl_test", "temp_key"))
|
||||
|
||||
kv.close("ttl_test")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- FILE FORMAT TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("JSON file format", function()
|
||||
kv.open("json_test", "test.json")
|
||||
kv.set("json_test", "key1", "value1")
|
||||
kv.set("json_test", "key2", "value2")
|
||||
kv.save("json_test")
|
||||
kv.close("json_test")
|
||||
|
||||
-- Verify file exists and reload
|
||||
assert(file_exists("test.json"))
|
||||
kv.open("json_test", "test.json")
|
||||
assert_equal("value1", kv.get("json_test", "key1"))
|
||||
assert_equal("value2", kv.get("json_test", "key2"))
|
||||
kv.close("json_test")
|
||||
|
||||
os.remove("test.json")
|
||||
end)
|
||||
|
||||
test("Text file format", function()
|
||||
kv.open("txt_test", "test.txt")
|
||||
kv.set("txt_test", "setting1", "value1")
|
||||
kv.set("txt_test", "setting2", "value2")
|
||||
kv.save("txt_test")
|
||||
kv.close("txt_test")
|
||||
|
||||
-- Verify file exists and reload
|
||||
assert(file_exists("test.txt"))
|
||||
kv.open("txt_test", "test.txt")
|
||||
assert_equal("value1", kv.get("txt_test", "setting1"))
|
||||
assert_equal("value2", kv.get("txt_test", "setting2"))
|
||||
kv.close("txt_test")
|
||||
|
||||
os.remove("test.txt")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- EDGE CASES AND ERROR HANDLING
|
||||
-- ======================================================================
|
||||
|
||||
test("Empty values and keys", function()
|
||||
kv.open("edge_test")
|
||||
|
||||
-- Empty value
|
||||
assert(kv.set("edge_test", "empty", ""))
|
||||
assert_equal("", kv.get("edge_test", "empty"))
|
||||
|
||||
-- Unicode keys and values
|
||||
assert(kv.set("edge_test", "ключ", "значение"))
|
||||
assert_equal("значение", kv.get("edge_test", "ключ"))
|
||||
|
||||
kv.close("edge_test")
|
||||
end)
|
||||
|
||||
test("Special characters in data", function()
|
||||
kv.open("special_test")
|
||||
|
||||
local special_value = 'Special chars: "quotes", \'apostrophes\', \n newlines, \t tabs, \\ backslashes'
|
||||
assert(kv.set("special_test", "special", special_value))
|
||||
assert_equal(special_value, kv.get("special_test", "special"))
|
||||
|
||||
kv.close("special_test")
|
||||
end)
|
||||
|
||||
test("Large data handling", function()
|
||||
kv.open("large_test")
|
||||
|
||||
-- Large value
|
||||
local large_value = string.rep("x", 10000)
|
||||
assert(kv.set("large_test", "large", large_value))
|
||||
assert_equal(large_value, kv.get("large_test", "large"))
|
||||
|
||||
-- Many keys
|
||||
for i = 1, 100 do
|
||||
kv.set("large_test", "key" .. i, "value" .. i)
|
||||
end
|
||||
assert_equal(101, kv.size("large_test")) -- 100 + 1 large value
|
||||
|
||||
kv.close("large_test")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PERFORMANCE TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Performance test", function()
|
||||
kv.open("perf_test")
|
||||
|
||||
local start = os.clock()
|
||||
|
||||
-- Bulk insert
|
||||
for i = 1, 1000 do
|
||||
kv.set("perf_test", "key" .. i, "value" .. i)
|
||||
end
|
||||
local insert_time = os.clock() - start
|
||||
|
||||
-- Bulk read
|
||||
start = os.clock()
|
||||
for i = 1, 1000 do
|
||||
local value = kv.get("perf_test", "key" .. i)
|
||||
assert_equal("value" .. i, value)
|
||||
end
|
||||
local read_time = os.clock() - start
|
||||
|
||||
-- Check final size
|
||||
assert_equal(1000, kv.size("perf_test"))
|
||||
|
||||
print(string.format(" Insert 1000 items: %.3fs", insert_time))
|
||||
print(string.format(" Read 1000 items: %.3fs", read_time))
|
||||
|
||||
kv.close("perf_test")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- INTEGRATION TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Multiple store integration", function()
|
||||
local users = kv.create("users_int")
|
||||
local cache = kv.create("cache_int")
|
||||
|
||||
-- Simulate user data
|
||||
users:set("user:123", "john_doe")
|
||||
cache:set("user:123:last_seen", tostring(os.time()))
|
||||
|
||||
-- Verify data in stores
|
||||
assert_equal("john_doe", users:get("user:123"))
|
||||
assert(cache:has("user:123:last_seen"))
|
||||
|
||||
-- Clean up
|
||||
users:close()
|
||||
cache:close()
|
||||
end)
|
||||
|
||||
-- Clean up test files
|
||||
os.remove("test_store.json")
|
||||
os.remove("test_oop.json")
|
||||
os.remove("test_temp.json")
|
||||
|
||||
summary()
|
||||
test_exit()
|
||||
393
tests/math.lua
393
tests/math.lua
@ -1,393 +0,0 @@
|
||||
require("tests")
|
||||
|
||||
-- Constants tests
|
||||
test("math.pi", function()
|
||||
assert_close(math.pi, 3.14159265358979323846)
|
||||
assert(math.pi > 3.14 and math.pi < 3.15)
|
||||
end)
|
||||
|
||||
test("math.tau", function()
|
||||
assert_close(math.tau, 6.28318530717958647693)
|
||||
assert_close(math.tau, 2 * math.pi)
|
||||
end)
|
||||
|
||||
test("math.e", function()
|
||||
assert_close(math.e, 2.71828182845904523536)
|
||||
assert(math.e > 2.7 and math.e < 2.8)
|
||||
end)
|
||||
|
||||
test("math.phi", function()
|
||||
assert_close(math.phi, 1.61803398874989484820)
|
||||
assert_close(math.phi, (1 + math.sqrt(5)) / 2)
|
||||
end)
|
||||
|
||||
test("math.infinity", function()
|
||||
assert_equal(math.infinity, 1/0)
|
||||
assert(math.infinity > 0)
|
||||
end)
|
||||
|
||||
test("math.nan", function()
|
||||
assert_equal(math.isnan(math.nan), true)
|
||||
assert(math.nan ~= math.nan) -- NaN property
|
||||
end)
|
||||
|
||||
-- Extended functions tests
|
||||
test("math.cbrt", function()
|
||||
assert_close(math.cbrt(8), 2)
|
||||
assert_close(math.cbrt(-8), -2)
|
||||
end)
|
||||
|
||||
test("math.hypot", function()
|
||||
assert_close(math.hypot(3, 4), 5)
|
||||
assert_close(math.hypot(5, 12), 13)
|
||||
end)
|
||||
|
||||
test("math.isnan", function()
|
||||
assert_equal(math.isnan(0/0), true)
|
||||
assert_equal(math.isnan(5), false)
|
||||
end)
|
||||
|
||||
test("math.isfinite", function()
|
||||
assert_equal(math.isfinite(5), true)
|
||||
assert_equal(math.isfinite(math.infinity), false)
|
||||
end)
|
||||
|
||||
test("math.sign", function()
|
||||
assert_equal(math.sign(5), 1)
|
||||
assert_equal(math.sign(-5), -1)
|
||||
assert_equal(math.sign(0), 0)
|
||||
end)
|
||||
|
||||
test("math.clamp", function()
|
||||
assert_equal(math.clamp(5, 0, 3), 3)
|
||||
assert_equal(math.clamp(-1, 0, 3), 0)
|
||||
assert_equal(math.clamp(2, 0, 3), 2)
|
||||
end)
|
||||
|
||||
test("math.lerp", function()
|
||||
assert_close(math.lerp(0, 10, 0.5), 5)
|
||||
assert_close(math.lerp(2, 8, 0.25), 3.5)
|
||||
end)
|
||||
|
||||
test("math.smoothstep", function()
|
||||
assert_close(math.smoothstep(0, 1, 0.5), 0.5)
|
||||
assert_close(math.smoothstep(0, 10, 5), 0.5)
|
||||
end)
|
||||
|
||||
test("math.map", function()
|
||||
assert_close(math.map(5, 0, 10, 0, 100), 50)
|
||||
assert_close(math.map(2, 0, 4, 10, 20), 15)
|
||||
end)
|
||||
|
||||
test("math.round", function()
|
||||
assert_equal(math.round(2.7), 3)
|
||||
assert_equal(math.round(-2.7), -3)
|
||||
assert_equal(math.round(2.3), 2)
|
||||
end)
|
||||
|
||||
test("math.roundto", function()
|
||||
assert_close(math.roundto(3.14159, 2), 3.14)
|
||||
assert_close(math.roundto(123.456, 1), 123.5)
|
||||
end)
|
||||
|
||||
test("math.normalize_angle", function()
|
||||
assert_close(math.normalize_angle(math.pi * 2.5), math.pi * 0.5)
|
||||
assert_close(math.normalize_angle(-math.pi * 2.5), -math.pi * 0.5)
|
||||
end)
|
||||
|
||||
test("math.distance", function()
|
||||
assert_close(math.distance(0, 0, 3, 4), 5)
|
||||
assert_close(math.distance(1, 1, 4, 5), 5)
|
||||
end)
|
||||
|
||||
test("math.factorial", function()
|
||||
assert_equal(math.factorial(5), 120)
|
||||
assert_equal(math.factorial(0), 1)
|
||||
assert_equal(math.factorial(-1), nil)
|
||||
end)
|
||||
|
||||
test("math.gcd", function()
|
||||
assert_equal(math.gcd(48, 18), 6)
|
||||
assert_equal(math.gcd(100, 25), 25)
|
||||
end)
|
||||
|
||||
test("math.lcm", function()
|
||||
assert_equal(math.lcm(4, 6), 12)
|
||||
assert_equal(math.lcm(15, 20), 60)
|
||||
end)
|
||||
|
||||
-- Random functions tests
|
||||
test("math.randomf", function()
|
||||
local r1 = math.randomf(0, 1)
|
||||
local r2 = math.randomf(5, 10)
|
||||
assert(r1 >= 0 and r1 < 1)
|
||||
assert(r2 >= 5 and r2 < 10)
|
||||
end)
|
||||
|
||||
test("math.randint", function()
|
||||
local i1 = math.randint(1, 10)
|
||||
local i2 = math.randint(50, 60)
|
||||
assert(i1 >= 1 and i1 <= 10)
|
||||
assert(i2 >= 50 and i2 <= 60)
|
||||
end)
|
||||
|
||||
test("math.randboolean", function()
|
||||
local b1 = math.randboolean()
|
||||
local b2 = math.randboolean(0.8)
|
||||
assert_equal(type(b1), "boolean")
|
||||
assert_equal(type(b2), "boolean")
|
||||
end)
|
||||
|
||||
-- Statistics tests
|
||||
test("math.sum", function()
|
||||
assert_equal(math.sum({1, 2, 3, 4, 5}), 15)
|
||||
assert_equal(math.sum({10, 20, 30}), 60)
|
||||
end)
|
||||
|
||||
test("math.mean", function()
|
||||
assert_equal(math.mean({1, 2, 3, 4, 5}), 3)
|
||||
assert_equal(math.mean({10, 20, 30}), 20)
|
||||
end)
|
||||
|
||||
test("math.median", function()
|
||||
assert_equal(math.median({1, 2, 3, 4, 5}), 3)
|
||||
assert_equal(math.median({1, 2, 3, 4}), 2.5)
|
||||
end)
|
||||
|
||||
test("math.variance", function()
|
||||
assert_close(math.variance({1, 2, 3, 4, 5}), 2)
|
||||
assert_close(math.variance({10, 10, 10}), 0)
|
||||
end)
|
||||
|
||||
test("math.stdev", function()
|
||||
assert_close(math.stdev({1, 2, 3, 4, 5}), math.sqrt(2))
|
||||
assert_close(math.stdev({10, 10, 10}), 0)
|
||||
end)
|
||||
|
||||
test("math.mode", function()
|
||||
assert_equal(math.mode({1, 2, 2, 3}), 2)
|
||||
assert_equal(math.mode({5, 5, 4, 4, 4}), 4)
|
||||
end)
|
||||
|
||||
test("math.minmax", function()
|
||||
local min1, max1 = math.minmax({1, 2, 3, 4, 5})
|
||||
local min2, max2 = math.minmax({-5, 0, 10})
|
||||
assert_equal(min1, 1)
|
||||
assert_equal(max1, 5)
|
||||
assert_equal(min2, -5)
|
||||
assert_equal(max2, 10)
|
||||
end)
|
||||
|
||||
-- 2D Vector tests
|
||||
test("math.vec2.new", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local v2 = math.vec2.new()
|
||||
assert_equal(v1.x, 3)
|
||||
assert_equal(v1.y, 4)
|
||||
assert_equal(v2.x, 0)
|
||||
assert_equal(v2.y, 0)
|
||||
end)
|
||||
|
||||
test("math.vec2.add", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local v2 = math.vec2.new(1, 2)
|
||||
local result = math.vec2.add(v1, v2)
|
||||
assert_equal(result.x, 4)
|
||||
assert_equal(result.y, 6)
|
||||
end)
|
||||
|
||||
test("math.vec2.sub", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local v2 = math.vec2.new(1, 2)
|
||||
local result = math.vec2.sub(v1, v2)
|
||||
assert_equal(result.x, 2)
|
||||
assert_equal(result.y, 2)
|
||||
end)
|
||||
|
||||
test("math.vec2.mul", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local result1 = math.vec2.mul(v1, 2)
|
||||
local result2 = math.vec2.mul(v1, math.vec2.new(2, 3))
|
||||
assert_equal(result1.x, 6)
|
||||
assert_equal(result1.y, 8)
|
||||
assert_equal(result2.x, 6)
|
||||
assert_equal(result2.y, 12)
|
||||
end)
|
||||
|
||||
test("math.vec2.dot", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local v2 = math.vec2.new(1, 2)
|
||||
assert_equal(math.vec2.dot(v1, v2), 11)
|
||||
assert_equal(math.vec2.dot(math.vec2.new(1, 0), math.vec2.new(0, 1)), 0)
|
||||
end)
|
||||
|
||||
test("math.vec2.length", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local v2 = math.vec2.new(0, 5)
|
||||
assert_close(math.vec2.length(v1), 5)
|
||||
assert_close(math.vec2.length(v2), 5)
|
||||
end)
|
||||
|
||||
test("math.vec2.distance", function()
|
||||
local v1 = math.vec2.new(0, 0)
|
||||
local v2 = math.vec2.new(3, 4)
|
||||
assert_close(math.vec2.distance(v1, v2), 5)
|
||||
assert_close(math.vec2.distance(math.vec2.new(1, 1), math.vec2.new(4, 5)), 5)
|
||||
end)
|
||||
|
||||
test("math.vec2.normalize", function()
|
||||
local v1 = math.vec2.new(3, 4)
|
||||
local normalized = math.vec2.normalize(v1)
|
||||
assert_close(math.vec2.length(normalized), 1)
|
||||
assert_close(normalized.x, 0.6)
|
||||
assert_close(normalized.y, 0.8)
|
||||
end)
|
||||
|
||||
test("math.vec2.rotate", function()
|
||||
local v1 = math.vec2.new(1, 0)
|
||||
local rotated90 = math.vec2.rotate(v1, math.pi/2)
|
||||
local rotated180 = math.vec2.rotate(v1, math.pi)
|
||||
assert_close(rotated90.x, 0, 1e-10)
|
||||
assert_close(rotated90.y, 1)
|
||||
assert_close(rotated180.x, -1)
|
||||
assert_close(rotated180.y, 0, 1e-10)
|
||||
end)
|
||||
|
||||
-- 3D Vector tests
|
||||
test("math.vec3.new", function()
|
||||
local v1 = math.vec3.new(1, 2, 3)
|
||||
local v2 = math.vec3.new()
|
||||
assert_equal(v1.x, 1)
|
||||
assert_equal(v1.y, 2)
|
||||
assert_equal(v1.z, 3)
|
||||
assert_equal(v2.x, 0)
|
||||
assert_equal(v2.y, 0)
|
||||
assert_equal(v2.z, 0)
|
||||
end)
|
||||
|
||||
test("math.vec3.cross", function()
|
||||
local v1 = math.vec3.new(1, 0, 0)
|
||||
local v2 = math.vec3.new(0, 1, 0)
|
||||
local cross = math.vec3.cross(v1, v2)
|
||||
assert_equal(cross.x, 0)
|
||||
assert_equal(cross.y, 0)
|
||||
assert_equal(cross.z, 1)
|
||||
end)
|
||||
|
||||
test("math.vec3.length", function()
|
||||
local v1 = math.vec3.new(1, 2, 3)
|
||||
local v2 = math.vec3.new(0, 0, 5)
|
||||
assert_close(math.vec3.length(v1), math.sqrt(14))
|
||||
assert_close(math.vec3.length(v2), 5)
|
||||
end)
|
||||
|
||||
-- Matrix tests
|
||||
test("math.mat2.det", function()
|
||||
local m1 = math.mat2.new(1, 2, 3, 4)
|
||||
local m2 = math.mat2.new(2, 0, 0, 3)
|
||||
assert_equal(math.mat2.det(m1), -2)
|
||||
assert_equal(math.mat2.det(m2), 6)
|
||||
end)
|
||||
|
||||
test("math.mat2.mul", function()
|
||||
local m1 = math.mat2.new(1, 2, 3, 4)
|
||||
local m2 = math.mat2.new(2, 0, 1, 3)
|
||||
local product = math.mat2.mul(m1, m2)
|
||||
assert_equal(product[1][1], 4)
|
||||
assert_equal(product[1][2], 6)
|
||||
assert_equal(product[2][1], 10)
|
||||
assert_equal(product[2][2], 12)
|
||||
end)
|
||||
|
||||
test("math.mat2.rotation", function()
|
||||
local rot90 = math.mat2.rotation(math.pi/2)
|
||||
local rot180 = math.mat2.rotation(math.pi)
|
||||
assert_close(rot90[1][1], 0, 1e-10)
|
||||
assert_close(rot90[1][2], -1)
|
||||
assert_close(rot180[1][1], -1)
|
||||
assert_close(rot180[2][2], -1)
|
||||
end)
|
||||
|
||||
test("math.mat3.transform_point", function()
|
||||
local translation = math.mat3.translation(5, 10)
|
||||
local point1 = {x = 1, y = 2}
|
||||
local point2 = {x = 0, y = 0}
|
||||
local t1 = math.mat3.transform_point(translation, point1)
|
||||
local t2 = math.mat3.transform_point(translation, point2)
|
||||
assert_equal(t1.x, 6)
|
||||
assert_equal(t1.y, 12)
|
||||
assert_equal(t2.x, 5)
|
||||
assert_equal(t2.y, 10)
|
||||
end)
|
||||
|
||||
test("math.mat3.det", function()
|
||||
local identity = math.mat3.identity()
|
||||
local scale = math.mat3.scale(2, 3)
|
||||
assert_close(math.mat3.det(identity), 1)
|
||||
assert_close(math.mat3.det(scale), 6)
|
||||
end)
|
||||
|
||||
-- Geometry tests
|
||||
test("math.geometry.triangle_area", function()
|
||||
local area1 = math.geometry.triangle_area(0, 0, 4, 0, 0, 3)
|
||||
local area2 = math.geometry.triangle_area(0, 0, 2, 0, 1, 2)
|
||||
assert_close(area1, 6)
|
||||
assert_close(area2, 2)
|
||||
end)
|
||||
|
||||
test("math.geometry.point_in_triangle", function()
|
||||
local inside1 = math.geometry.point_in_triangle(1, 1, 0, 0, 4, 0, 0, 3)
|
||||
local inside2 = math.geometry.point_in_triangle(5, 5, 0, 0, 4, 0, 0, 3)
|
||||
assert_equal(inside1, true)
|
||||
assert_equal(inside2, false)
|
||||
end)
|
||||
|
||||
test("math.geometry.line_intersect", function()
|
||||
local intersects1, x1, y1 = math.geometry.line_intersect(0, 0, 2, 2, 0, 2, 2, 0)
|
||||
local intersects2, x2, y2 = math.geometry.line_intersect(0, 0, 1, 1, 2, 2, 3, 3)
|
||||
assert_equal(intersects1, true)
|
||||
assert_close(x1, 1)
|
||||
assert_close(y1, 1)
|
||||
assert_equal(intersects2, false)
|
||||
end)
|
||||
|
||||
test("math.geometry.closest_point_on_segment", function()
|
||||
local x1, y1 = math.geometry.closest_point_on_segment(1, 3, 0, 0, 4, 0)
|
||||
local x2, y2 = math.geometry.closest_point_on_segment(5, 1, 0, 0, 4, 0)
|
||||
assert_close(x1, 1)
|
||||
assert_close(y1, 0)
|
||||
assert_close(x2, 4)
|
||||
assert_close(y2, 0)
|
||||
end)
|
||||
|
||||
-- Interpolation tests
|
||||
test("math.interpolation.bezier", function()
|
||||
local result1 = math.interpolation.bezier(0.5, 0, 1, 2, 3)
|
||||
local result2 = math.interpolation.bezier(0, 0, 1, 2, 3)
|
||||
assert_close(result1, 1.5)
|
||||
assert_close(result2, 0)
|
||||
end)
|
||||
|
||||
test("math.interpolation.quadratic_bezier", function()
|
||||
local result1 = math.interpolation.quadratic_bezier(0.5, 0, 2, 4)
|
||||
local result2 = math.interpolation.quadratic_bezier(1, 0, 2, 4)
|
||||
assert_close(result1, 2)
|
||||
assert_close(result2, 4)
|
||||
end)
|
||||
|
||||
test("math.interpolation.smootherstep", function()
|
||||
local result1 = math.interpolation.smootherstep(0, 1, 0.5)
|
||||
local result2 = math.interpolation.smootherstep(0, 10, 5)
|
||||
assert_close(result1, 0.5)
|
||||
assert_close(result2, 0.5)
|
||||
end)
|
||||
|
||||
test("math.interpolation.catmull_rom", function()
|
||||
local result1 = math.interpolation.catmull_rom(0.5, 0, 1, 2, 3)
|
||||
local result2 = math.interpolation.catmull_rom(0, 0, 1, 2, 3)
|
||||
assert_close(result1, 1.5)
|
||||
assert_close(result2, 1)
|
||||
end)
|
||||
|
||||
summary()
|
||||
test_exit()
|
||||
@ -1,397 +0,0 @@
|
||||
require("tests")
|
||||
local sessions = require("sessions")
|
||||
|
||||
-- Clean up test files
|
||||
os.remove("test_sessions.json")
|
||||
os.remove("test_sessions2.json")
|
||||
|
||||
-- ======================================================================
|
||||
-- SESSION STORE INITIALIZATION
|
||||
-- ======================================================================
|
||||
|
||||
test("Session store initialization", function()
|
||||
assert(sessions.init("test_sessions", "test_sessions.json"))
|
||||
assert(sessions.init("memory_sessions"))
|
||||
|
||||
-- Test with explicit store name
|
||||
assert(sessions.init("named_sessions", "test_sessions2.json"))
|
||||
end)
|
||||
|
||||
test("Session ID generation", function()
|
||||
local id1 = sessions.generate_id()
|
||||
local id2 = sessions.generate_id()
|
||||
|
||||
assert_equal(32, #id1)
|
||||
assert_equal(32, #id2)
|
||||
assert(id1 ~= id2, "Session IDs should be unique")
|
||||
|
||||
-- Check character set (alphanumeric)
|
||||
assert(id1:match("^[a-zA-Z0-9]+$"), "Session ID should be alphanumeric")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- BASIC SESSION OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("Session creation and retrieval", function()
|
||||
sessions.init("basic_test")
|
||||
|
||||
local session_id = sessions.generate_id()
|
||||
local session_data = {
|
||||
user_id = 123,
|
||||
username = "testuser",
|
||||
role = "admin",
|
||||
permissions = {"read", "write"}
|
||||
}
|
||||
|
||||
assert(sessions.create(session_id, session_data))
|
||||
|
||||
local retrieved = sessions.get(session_id)
|
||||
assert_equal(123, retrieved.user_id)
|
||||
assert_equal("testuser", retrieved.username)
|
||||
assert_equal("admin", retrieved.role)
|
||||
assert_table_equal({"read", "write"}, retrieved.permissions)
|
||||
assert(retrieved._created ~= nil)
|
||||
assert(retrieved._last_accessed ~= nil)
|
||||
|
||||
-- Test non-existent session
|
||||
assert_equal(nil, sessions.get("nonexistent"))
|
||||
end)
|
||||
|
||||
test("Session updates", function()
|
||||
sessions.init("update_test")
|
||||
|
||||
local session_id = sessions.generate_id()
|
||||
sessions.create(session_id, {count = 1})
|
||||
|
||||
local session = sessions.get(session_id)
|
||||
local new_data = {
|
||||
count = 2,
|
||||
new_field = "added"
|
||||
}
|
||||
|
||||
assert(sessions.update(session_id, new_data))
|
||||
|
||||
local updated = sessions.get(session_id)
|
||||
assert_equal(2, updated.count)
|
||||
assert_equal("added", updated.new_field)
|
||||
assert(updated._created ~= nil)
|
||||
assert(updated._last_accessed ~= nil)
|
||||
end)
|
||||
|
||||
test("Session deletion", function()
|
||||
sessions.init("delete_test")
|
||||
|
||||
local session_id = sessions.generate_id()
|
||||
sessions.create(session_id, {temp = true})
|
||||
|
||||
assert(sessions.get(session_id) ~= nil)
|
||||
assert(sessions.delete(session_id))
|
||||
assert_equal(nil, sessions.get(session_id))
|
||||
|
||||
-- Delete non-existent session
|
||||
assert(not sessions.delete("nonexistent"))
|
||||
end)
|
||||
|
||||
test("Session existence check", function()
|
||||
sessions.init("exists_test")
|
||||
|
||||
local session_id = sessions.generate_id()
|
||||
assert(not sessions.exists(session_id))
|
||||
|
||||
sessions.create(session_id, {test = true})
|
||||
assert(sessions.exists(session_id))
|
||||
|
||||
sessions.delete(session_id)
|
||||
assert(not sessions.exists(session_id))
|
||||
end)
|
||||
|
||||

-- ======================================================================
-- MULTI-STORE OPERATIONS
-- ======================================================================

test("Multiple session stores", function()
assert(sessions.init("store1", "store1.json"))
assert(sessions.init("store2", "store2.json"))

local id1 = sessions.generate_id()
local id2 = sessions.generate_id()

-- Create sessions in different stores
assert(sessions.create(id1, {store = "store1"}, "store1"))
assert(sessions.create(id2, {store = "store2"}, "store2"))

-- Verify isolation
assert_equal(nil, sessions.get(id1, "store2"))
assert_equal(nil, sessions.get(id2, "store1"))

-- Verify correct retrieval
local s1 = sessions.get(id1, "store1")
local s2 = sessions.get(id2, "store2")
assert_equal("store1", s1.store)
assert_equal("store2", s2.store)

os.remove("store1.json")
os.remove("store2.json")
end)

test("Default store behavior", function()
sessions.reset()
sessions.init("default_store")

local session_id = sessions.generate_id()
sessions.create(session_id, {test = "default"})

-- Should work without specifying store
local retrieved = sessions.get(session_id)
assert_equal("default", retrieved.test)
end)

-- ======================================================================
-- SESSION CLEANUP
-- ======================================================================

test("Session cleanup", function()
sessions.init("cleanup_test")

local old_session = sessions.generate_id()
local new_session = sessions.generate_id()

sessions.create(old_session, {test = "old"}, "cleanup_test")
sessions.create(new_session, {test = "new"}, "cleanup_test")

-- Wait to create age difference
os.execute("sleep 2")

-- Access new session to update timestamp
sessions.get(new_session, "cleanup_test")

-- Clean up sessions older than 1 second
local deleted = sessions.cleanup(1, "cleanup_test")
assert(deleted >= 1)

-- New session should remain (recently accessed)
assert(sessions.get(new_session, "cleanup_test") ~= nil)
end)

-- ======================================================================
-- SESSION LISTING AND COUNTING
-- ======================================================================

test("Session listing and counting", function()
sessions.init("list_test")

local ids = {}
for i = 1, 5 do
local id = sessions.generate_id()
sessions.create(id, {index = i}, "list_test")
table.insert(ids, id)
end

local session_list = sessions.list("list_test")
assert_equal(5, #session_list)

local count = sessions.count("list_test")
assert_equal(5, count)

-- Verify all IDs are in the list
local id_set = {}
for _, id in ipairs(session_list) do
id_set[id] = true
end

for _, id in ipairs(ids) do
assert(id_set[id], "Session ID should be in list")
end
end)

-- ======================================================================
-- OBJECT-ORIENTED INTERFACE
-- ======================================================================

test("OOP session store", function()
local store = sessions.create_store("oop_sessions", "oop_test.json")

local session_id = sessions.generate_id()
local data = {user = "oop_test", role = "admin"}

assert(store:create(session_id, data))

local retrieved = store:get(session_id)
assert_equal("oop_test", retrieved.user)
assert_equal("admin", retrieved.role)

retrieved.last_action = "login"
assert(store:update(session_id, retrieved))

local updated = store:get(session_id)
assert_equal("login", updated.last_action)

assert(store:exists(session_id))
assert_equal(1, store:count())

local session_list = store:list()
assert_equal(1, #session_list)
assert_equal(session_id, session_list[1])

assert(store:delete(session_id))
assert(not store:exists(session_id))

store:close()
os.remove("oop_test.json")
end)

-- ======================================================================
-- ERROR HANDLING
-- ======================================================================

test("Session error handling", function()
sessions.reset()

-- Try operations without initialization
local success1, err1 = pcall(sessions.create, "test_id", {})
assert(not success1)

local success2, err2 = pcall(sessions.get, "test_id")
assert(not success2)

-- Initialize and test invalid inputs
sessions.init("error_test")

-- Invalid session ID type
local success3, err3 = pcall(sessions.create, 123, {})
assert(not success3)

-- Invalid data type
local success4, err4 = pcall(sessions.create, "test", "not_a_table")
assert(not success4)

-- Invalid store name
local success5, err5 = pcall(sessions.get, "test", 123)
assert(not success5)
end)

-- ======================================================================
-- DATA PERSISTENCE
-- ======================================================================

test("Session persistence", function()
sessions.init("persist_test", "persist_sessions.json")

local session_id = sessions.generate_id()
local data = {
user_id = 789,
settings = {theme = "dark", lang = "en"},
cart = {items = {"item1", "item2"}, total = 25.99}
}

sessions.create(session_id, data, "persist_test")
sessions.close("persist_test")

-- Reinitialize and verify data persists
sessions.init("persist_test", "persist_sessions.json")
local retrieved = sessions.get(session_id, "persist_test")

assert_equal(789, retrieved.user_id)
assert_equal("dark", retrieved.settings.theme)
assert_equal("en", retrieved.settings.lang)
assert_equal(2, #retrieved.cart.items)
assert_equal(25.99, retrieved.cart.total)

sessions.close("persist_test")
os.remove("persist_sessions.json")
end)

-- ======================================================================
-- COMPLEX DATA STRUCTURES
-- ======================================================================

test("Complex session data", function()
sessions.init("complex_test")

local session_id = sessions.generate_id()
local complex_data = {
user = {
id = 456,
profile = {
name = "Jane Doe",
email = "jane@example.com",
preferences = {
notifications = true,
privacy = "friends_only"
}
}
},
activity = {
pages_visited = {"home", "profile", "settings"},
actions = {
{type = "login", time = os.time()},
{type = "view_page", page = "profile", time = os.time()}
}
},
metadata = {
ip = "192.168.1.1",
user_agent = "Test Browser 1.0"
}
}

sessions.create(session_id, complex_data)
local retrieved = sessions.get(session_id)

assert_equal(456, retrieved.user.id)
assert_equal("Jane Doe", retrieved.user.profile.name)
assert_equal(true, retrieved.user.profile.preferences.notifications)
assert_equal(3, #retrieved.activity.pages_visited)
assert_equal("login", retrieved.activity.actions[1].type)
assert_equal("192.168.1.1", retrieved.metadata.ip)
end)

-- ======================================================================
-- WORKFLOW INTEGRATION
-- ======================================================================

test("Session workflow integration", function()
sessions.init("workflow_test")

-- Simulate user workflow
local session_id = sessions.generate_id()

-- User login
sessions.create(session_id, {
user_id = 999,
username = "workflow_user",
status = "logged_in"
})

-- User adds items to cart
local session = sessions.get(session_id)
session.cart = {"item1", "item2"}
session.cart_total = 19.99
sessions.update(session_id, session)

-- User proceeds to checkout
session = sessions.get(session_id)
session.checkout_step = "payment"
session.payment_method = "credit_card"
sessions.update(session_id, session)

-- Verify final state
local final_session = sessions.get(session_id)
assert_equal(999, final_session.user_id)
assert_equal("logged_in", final_session.status)
assert_equal(2, #final_session.cart)
assert_equal(19.99, final_session.cart_total)
assert_equal("payment", final_session.checkout_step)
assert_equal("credit_card", final_session.payment_method)

-- User completes order and logs out
sessions.delete(session_id)
assert_equal(nil, sessions.get(session_id))
end)

-- Clean up test files
os.remove("test_sessions.json")
os.remove("test_sessions2.json")

summary()
test_exit()
778
tests/string.lua
778
tests/string.lua
@ -1,778 +0,0 @@
require("tests")

-- Test data
local test_string = "Hello, World!"
local multi_line = "Line 1\nLine 2\nLine 3"
local padded_string = " Hello World "
local unicode_string = "café naïve résumé"
local mixed_case = "hELLo WoRLd"

-- ======================================================================
-- BASIC STRING OPERATIONS
-- ======================================================================

test("string.split", function()
-- Basic split
local parts = string.split("a,b,c,d", ",")
assert_equal(4, #parts)
assert_equal("a", parts[1])
assert_equal("d", parts[4])

-- Empty delimiter (character split)
local chars = string.split("abc", "")
assert_equal(3, #chars)
assert_equal("a", chars[1])
assert_equal("c", chars[3])

-- Empty string
local empty_parts = string.split("", ",")
assert_equal(1, #empty_parts)
assert_equal("", empty_parts[1])

-- No matches
local no_match = string.split("hello", ",")
assert_equal(1, #no_match)
assert_equal("hello", no_match[1])

-- Multiple consecutive delimiters
local multiple = string.split("a,,b,", ",")
assert_equal(4, #multiple)
assert_equal("a", multiple[1])
assert_equal("", multiple[2])
assert_equal("b", multiple[3])
assert_equal("", multiple[4])
end)

test("string.join", function()
assert_equal("a-b-c", string.join({"a", "b", "c"}, "-"))
assert_equal("abc", string.join({"a", "b", "c"}, ""))
assert_equal("", string.join({}, ","))
assert_equal("a", string.join({"a"}, ","))

-- Mixed types (should convert to string)
assert_equal("1,2,3", string.join({1, 2, 3}, ","))
end)

test("string.trim", function()
assert_equal("Hello World", string.trim(padded_string))
assert_equal("", string.trim(""))
assert_equal("", string.trim(" "))
assert_equal("hello", string.trim("hello"))
assert_equal("a b", string.trim(" a b "))
assert_equal("a b", string.trim("xxxa bxxx", "x"))
end)

test("string.trim_left", function()
assert_equal("Hello World ", string.trim_left(padded_string))
assert_equal("hello", string.trim_left("hello"))
assert_equal("", string.trim_left(""))

-- Custom cutset
assert_equal("Helloxxx", string.trim_left("xxxHelloxxx", "x"))
assert_equal("yHelloxxx", string.trim_left("xyHelloxxx", "x"))
assert_equal("", string.trim_left("xxxx", "x"))
end)

test("string.trim_right", function()
assert_equal(" Hello World", string.trim_right(padded_string))
assert_equal("hello", string.trim_right("hello"))
assert_equal("", string.trim_right(""))

-- Custom cutset
assert_equal("xxxHello", string.trim_right("xxxHelloxxx", "x"))
assert_equal("", string.trim_right("xxxx", "x"))
end)

test("string.upper", function()
assert_equal("HELLO", string.upper("hello"))
assert_equal("HELLO123!", string.upper("Hello123!"))
assert_equal("", string.upper(""))
assert_equal("ABC", string.upper("abc"))
end)

test("string.lower", function()
assert_equal("hello", string.lower("HELLO"))
assert_equal("hello123!", string.lower("HELLO123!"))
assert_equal("", string.lower(""))
assert_equal("abc", string.lower("ABC"))
end)

test("string.title", function()
assert_equal("Hello World", string.title("hello world"))
assert_equal("Hello World", string.title("HELLO WORLD"))
assert_equal("", string.title(""))
assert_equal("A", string.title("a"))
assert_equal("Test_Case", string.title("test_case"))
end)

test("string.contains", function()
assert_equal(true, string.contains(test_string, "World"))
assert_equal(false, string.contains(test_string, "world"))
assert_equal(true, string.contains(test_string, ""))
assert_equal(false, string.contains("", "a"))
assert_equal(true, string.contains("abc", "b"))
end)

test("string.starts_with", function()
assert_equal(true, string.starts_with(test_string, "Hello"))
assert_equal(false, string.starts_with(test_string, "hello"))
assert_equal(true, string.starts_with(test_string, ""))
assert_equal(false, string.starts_with("", "a"))
assert_equal(true, string.starts_with("hello", "h"))
end)

test("string.ends_with", function()
assert_equal(true, string.ends_with(test_string, "!"))
assert_equal(false, string.ends_with(test_string, "?"))
assert_equal(true, string.ends_with(test_string, ""))
assert_equal(false, string.ends_with("", "a"))
assert_equal(true, string.ends_with("hello", "o"))
end)

test("string.replace", function()
assert_equal("hi world hi", string.replace("hello world hello", "hello", "hi"))
assert_equal("hello", string.replace("hello", "xyz", "abc"))
assert_equal("", string.replace("hello", "hello", ""))
assert_equal("xyzxyz", string.replace("abcabc", "abc", "xyz"))

-- Special characters
assert_equal("h*llo", string.replace("hello", "e", "*"))
end)

test("string.replace_n", function()
assert_equal("hi world hello", string.replace_n("hello world hello", "hello", "hi", 1))
assert_equal("hi world hi", string.replace_n("hello world hello", "hello", "hi", 2))
assert_equal("hello world hello", string.replace_n("hello world hello", "hello", "hi", 0))
assert_equal("hello world hello", string.replace_n("hello world hello", "xyz", "hi", 5))
end)
test("string.index", function()
|
||||
assert_equal(7, string.index("hello world", "world"))
|
||||
assert_equal(nil, string.index("hello world", "xyz"))
|
||||
assert_equal(1, string.index("hello", "h"))
|
||||
assert_equal(5, string.index("hello", "o"))
|
||||
assert_equal(1, string.index("hello", ""))
|
||||
end)
|
||||
|
||||
test("string.last_index", function()
|
||||
assert_equal(7, string.last_index("hello hello", "hello"))
|
||||
assert_equal(nil, string.last_index("hello world", "xyz"))
|
||||
assert_equal(1, string.last_index("hello", "h"))
|
||||
assert_equal(9, string.last_index("hello o o", "o"))
|
||||
end)
|
||||
|
||||
test("string.count", function()
|
||||
assert_equal(3, string.count("hello hello hello", "hello"))
|
||||
assert_equal(0, string.count("hello world", "xyz"))
|
||||
assert_equal(2, string.count("hello", "l"))
|
||||
assert_equal(6, string.count("hello", ""))
|
||||
end)
|
||||
|
||||
test("string.repeat_", function()
|
||||
assert_equal("abcabcabc", string.repeat_("abc", 3))
|
||||
assert_equal("", string.repeat_("x", 0))
|
||||
assert_equal("x", string.repeat_("x", 1))
|
||||
assert_equal("", string.repeat_("", 5))
|
||||
end)
|
||||
|
||||
test("string.reverse", function()
|
||||
assert_equal("olleh", string.reverse("hello"))
|
||||
assert_equal("", string.reverse(""))
|
||||
assert_equal("a", string.reverse("a"))
|
||||
assert_equal("dcba", string.reverse("abcd"))
|
||||
|
||||
-- Test with Go fallback for longer strings
|
||||
local long_str = string.rep("abc", 50)
|
||||
local reversed = string.reverse(long_str)
|
||||
assert_equal(string.length(long_str), string.length(reversed))
|
||||
end)
|
||||
|
||||
test("string.length", function()
|
||||
assert_equal(5, string.length("hello"))
|
||||
assert_equal(0, string.length(""))
|
||||
assert_equal(1, string.length("a"))
|
||||
|
||||
-- Unicode length
|
||||
assert_equal(4, string.length("café"))
|
||||
assert_equal(17, string.length(unicode_string))
|
||||
end)
|
||||
|
||||
test("string.byte_length", function()
|
||||
assert_equal(5, string.byte_length("hello"))
|
||||
assert_equal(0, string.byte_length(""))
|
||||
assert_equal(1, string.byte_length("a"))
|
||||
|
||||
-- Unicode byte length (UTF-8)
|
||||
assert_equal(5, string.byte_length("café")) -- é is 2 bytes
|
||||
assert_equal(21, string.byte_length(unicode_string)) -- accented chars are 2 bytes each
|
||||
end)
|
||||
|
||||
test("string.lines", function()
|
||||
local lines = string.lines(multi_line)
|
||||
assert_equal(3, #lines)
|
||||
assert_equal("Line 1", lines[1])
|
||||
assert_equal("Line 3", lines[3])
|
||||
|
||||
-- Empty string
|
||||
assert_table_equal({""}, string.lines(""))
|
||||
|
||||
-- Different line endings
|
||||
assert_table_equal({"a", "b"}, string.lines("a\nb"))
|
||||
assert_table_equal({"a", "b"}, string.lines("a\r\nb"))
|
||||
assert_table_equal({"a", "b"}, string.lines("a\rb"))
|
||||
|
||||
-- Trailing newline
|
||||
assert_table_equal({"a", "b"}, string.lines("a\nb\n"))
|
||||
end)
|
||||
|
||||
test("string.words", function()
|
||||
local words = string.words("Hello world test")
|
||||
assert_equal(3, #words)
|
||||
assert_equal("Hello", words[1])
|
||||
assert_equal("test", words[3])
|
||||
|
||||
-- Extra whitespace
|
||||
assert_table_equal({"Hello", "world"}, string.words(" Hello world "))
|
||||
|
||||
-- Empty string
|
||||
assert_table_equal({}, string.words(""))
|
||||
assert_table_equal({}, string.words(" "))
|
||||
|
||||
-- Single word
|
||||
assert_table_equal({"hello"}, string.words("hello"))
|
||||
end)
|
||||
|
||||
test("string.pad_left", function()
|
||||
assert_equal(" hi", string.pad_left("hi", 5))
|
||||
assert_equal("000hi", string.pad_left("hi", 5, "0"))
|
||||
assert_equal("hello", string.pad_left("hello", 3))
|
||||
assert_equal("hi", string.pad_left("hi", 2))
|
||||
assert_equal("hi", string.pad_left("hi", 0))
|
||||
|
||||
-- Unicode padding
|
||||
assert_equal(" café", string.pad_left("café", 6))
|
||||
end)
|
||||
|
||||
test("string.pad_right", function()
|
||||
assert_equal("hi ", string.pad_right("hi", 5))
|
||||
assert_equal("hi***", string.pad_right("hi", 5, "*"))
|
||||
assert_equal("hello", string.pad_right("hello", 3))
|
||||
assert_equal("hi", string.pad_right("hi", 2))
|
||||
assert_equal("hi", string.pad_right("hi", 0))
|
||||
|
||||
-- Unicode padding
|
||||
assert_equal("café ", string.pad_right("café", 6))
|
||||
end)
|
||||
|
||||
test("string.slice", function()
|
||||
assert_equal("ell", string.slice("hello", 2, 4))
|
||||
assert_equal("ello", string.slice("hello", 2))
|
||||
assert_equal("", string.slice("hello", 10))
|
||||
assert_equal("h", string.slice("hello", 1, 1))
|
||||
assert_equal("", string.slice("hello", 3, 2))
|
||||
|
||||
-- Negative indices
|
||||
assert_equal("lo", string.slice("hello", 4, -1))
|
||||
|
||||
-- Unicode slicing
|
||||
assert_equal("afé", string.slice("café", 2, 4))
|
||||
end)
|
||||
|
||||

-- ======================================================================
-- REGULAR EXPRESSIONS
-- ======================================================================

test("string.match", function()
assert_equal(true, string.match("hello123", "%d+") ~= nil)
assert_equal(false, string.match("hello", "%d+") ~= nil)
assert_equal(true, string.match("hello", "^[a-z]+$") ~= nil)
assert_equal(false, string.match("Hello", "^[a-z]+$") ~= nil)
assert_equal(true, string.match("testing", "test") ~= nil)
end)

test("string.find_match", function()
assert_equal("123", string.find_match("hello123world", "%d+"))
assert_equal(nil, string.find_match("hello", "%d+"))
assert_equal("test", string.find_match("testing", "test"))
assert_equal("world", string.find_match("hello world", "world"))
end)

test("string.find_all", function()
local matches = string.find_all("123 and 456 and 789", "%d+")
assert_equal(3, #matches)
assert_equal("123", matches[1])
assert_equal("789", matches[3])

-- No matches
assert_table_equal({}, string.find_all("hello", "%d+"))

-- Multiple matches
local overlaps = string.find_all("test test test", "test")
assert_equal(3, #overlaps)
end)

test("string.gsub", function()
assert_equal("helloXXXworldXXX", (string.gsub("hello123world456", "%d+", "XXX")))
assert_equal("hello world", (string.gsub("hello world", "%s+", " ")))
assert_equal("abc abc abc", (string.gsub("test abc test", "test", "abc")))

-- No matches
assert_equal("hello", (string.gsub("hello", "%d+", "XXX")))
end)

-- ======================================================================
-- TYPE CONVERSION & VALIDATION
-- ======================================================================

test("string.to_number", function()
assert_equal(123, string.to_number("123"))
assert_equal(123.45, string.to_number("123.45"))
assert_equal(-42, string.to_number("-42"))
assert_equal(nil, string.to_number("not_a_number"))
assert_equal(nil, string.to_number(""))
assert_equal(42, string.to_number(" 42 "))
end)

test("string.is_numeric", function()
assert_equal(true, string.is_numeric("123"))
assert_equal(true, string.is_numeric("123.45"))
assert_equal(true, string.is_numeric("-42"))
assert_equal(false, string.is_numeric("abc"))
assert_equal(false, string.is_numeric(""))
assert_equal(true, string.is_numeric(" 42 "))
end)

test("string.is_alpha", function()
assert_equal(true, string.is_alpha("hello"))
assert_equal(false, string.is_alpha("hello123"))
assert_equal(false, string.is_alpha(""))
assert_equal(false, string.is_alpha("hello!"))
assert_equal(true, string.is_alpha("ABC"))
end)

test("string.is_alphanumeric", function()
assert_equal(true, string.is_alphanumeric("hello123"))
assert_equal(false, string.is_alphanumeric("hello!"))
assert_equal(false, string.is_alphanumeric(""))
assert_equal(true, string.is_alphanumeric("ABC123"))
assert_equal(true, string.is_alphanumeric("hello"))
end)

test("string.is_empty", function()
assert_equal(true, string.is_empty(""))
assert_equal(true, string.is_empty(nil))
assert_equal(false, string.is_empty("hello"))
assert_equal(false, string.is_empty(" "))
end)

test("string.is_blank", function()
assert_equal(true, string.is_blank(""))
assert_equal(true, string.is_blank(" "))
assert_equal(true, string.is_blank(nil))
assert_equal(false, string.is_blank("hello"))
assert_equal(false, string.is_blank(" a "))
end)

test("string.is_utf8", function()
assert_equal(true, string.is_utf8("hello"))
assert_equal(true, string.is_utf8("café"))
assert_equal(true, string.is_utf8(""))
assert_equal(true, string.is_utf8(unicode_string))
end)

-- ======================================================================
-- ADVANCED STRING OPERATIONS
-- ======================================================================

test("string.capitalize", function()
assert_equal("Hello World", string.capitalize("hello world"))
assert_equal("Hello", string.capitalize("hello"))
assert_equal("", string.capitalize(""))
assert_equal("A", string.capitalize("a"))
end)

test("string.camel_case", function()
assert_equal("helloWorld", string.camel_case("hello world"))
assert_equal("hello", string.camel_case("hello"))
assert_equal("", string.camel_case(""))
assert_equal("testCaseExample", string.camel_case("test case example"))
end)

test("string.pascal_case", function()
assert_equal("HelloWorld", string.pascal_case("hello world"))
assert_equal("Hello", string.pascal_case("hello"))
assert_equal("", string.pascal_case(""))
assert_equal("TestCaseExample", string.pascal_case("test case example"))
end)

test("string.snake_case", function()
assert_equal("hello_world", string.snake_case("Hello World"))
assert_equal("hello", string.snake_case("hello"))
assert_equal("", string.snake_case(""))
assert_equal("test_case_example", string.snake_case("Test Case Example"))
end)

test("string.kebab_case", function()
assert_equal("hello-world", string.kebab_case("Hello World"))
assert_equal("hello", string.kebab_case("hello"))
assert_equal("", string.kebab_case(""))
assert_equal("test-case-example", string.kebab_case("Test Case Example"))
end)

test("string.screaming_snake_case", function()
assert_equal("HELLO_WORLD", string.screaming_snake_case("hello world"))
assert_equal("HELLO", string.screaming_snake_case("hello"))
assert_equal("", string.screaming_snake_case(""))
assert_equal("TEST_CASE", string.screaming_snake_case("test case"))
end)

test("string.center", function()
assert_equal(" hi ", string.center("hi", 6))
assert_equal("**hi***", string.center("hi", 7, "*"))
assert_equal("hello", string.center("hello", 3))
assert_equal("hi", string.center("hi", 2))
assert_equal(" hi ", string.center("hi", 4))
end)

test("string.truncate", function()
assert_equal("hello...", string.truncate("hello world", 8))
assert_equal("hello>>", string.truncate("hello world", 8, ">>"))
assert_equal("hi", string.truncate("hi", 10))
assert_equal("...", string.truncate("hello", 3))
assert_equal("h>", string.truncate("hello", 2, ">"))
end)

test("string.wrap", function()
local wrapped = string.wrap("The quick brown fox jumps over the lazy dog", 10)
assert_equal("table", type(wrapped))
assert(#wrapped > 1, "should wrap into multiple lines")

-- Each line should be within limit
for _, line in ipairs(wrapped) do
assert(string.length(line) <= 10, "line should be within width limit: " .. line)
end

-- Empty string
assert_table_equal({""}, string.wrap("", 10))

-- Single word longer than width
local long_word = string.wrap("supercalifragilisticexpialidocious", 10)
assert_equal(1, #long_word)
end)

test("string.dedent", function()
local indented = " line1\n line2\n line3"
local dedented = string.dedent(indented)
local lines = string.lines(dedented)

assert_equal("line1", lines[1])
assert_equal("line2", lines[2])
assert_equal("line3", lines[3])

-- Mixed indentation
local mixed = " a\n b\n c"
local mixed_result = string.dedent(mixed)
local mixed_lines = string.lines(mixed_result)
assert_equal("a", mixed_lines[1])
assert_equal(" b", mixed_lines[2])
assert_equal("c", mixed_lines[3])
end)
test("string.escape", function()
|
||||
assert_equal("hello%.world", string.escape("hello.world"))
|
||||
assert_equal("a%+b%*c%?", string.escape("a+b*c?"))
|
||||
assert_equal("%[%]%(%)", string.escape("[]()"))
|
||||
assert_equal("hello", string.escape("hello"))
|
||||
end)
|
||||
|
||||
test("string.shell_quote", function()
|
||||
assert_equal("'hello world'", string.shell_quote("hello world"))
|
||||
assert_equal("'it'\"'\"'s great'", string.shell_quote("it's great"))
|
||||
assert_equal("hello", string.shell_quote("hello"))
|
||||
assert_equal("hello-world.txt", string.shell_quote("hello-world.txt"))
|
||||
end)
|
||||
|
||||
test("string.url_encode", function()
|
||||
assert_equal("hello%20world", string.url_encode("hello world"))
|
||||
assert_equal("hello", string.url_encode("hello"))
|
||||
assert_equal("hello%21%40%23", string.url_encode("hello!@#"))
|
||||
end)
|
||||
|
||||
test("string.url_decode", function()
|
||||
assert_equal("hello world", string.url_decode("hello%20world"))
|
||||
assert_equal("hello world", string.url_decode("hello+world"))
|
||||
assert_equal("hello", string.url_decode("hello"))
|
||||
assert_equal("hello!@#", string.url_decode("hello%21%40%23"))
|
||||
|
||||
-- Round trip
|
||||
local original = "hello world!@#"
|
||||
local encoded = string.url_encode(original)
|
||||
assert_equal(original, string.url_decode(encoded))
|
||||
end)
|
||||
|
||||
test("string.template", function()
|
||||
local context = {
|
||||
user = {name = "Jane", role = "admin"},
|
||||
count = 5
|
||||
}
|
||||
local template = "User ${user.name} (${user.role}) has ${count} items"
|
||||
local result = string.template(template, context)
|
||||
assert_equal("User Jane (admin) has 5 items", result)
|
||||
|
||||
-- Missing variables
|
||||
local incomplete = string.template("Hello ${name} and ${unknown}", {name = "John"})
|
||||
assert_equal("Hello John and ", incomplete)
|
||||
|
||||
-- Missing nested property
|
||||
local missing = string.template("${user.missing}", context)
|
||||
assert_equal("", missing)
|
||||
|
||||
-- No variables
|
||||
assert_equal("Hello world", string.template("Hello world", {}))
|
||||
end)
|
||||
|
||||
test("string.random", function()
|
||||
local random1 = string.random(10)
|
||||
local random2 = string.random(10)
|
||||
|
||||
assert_equal(10, string.length(random1))
|
||||
assert_equal(10, string.length(random2))
|
||||
assert(random1 ~= random2, "random strings should be different")
|
||||
|
||||
-- Custom charset
|
||||
local custom = string.random(5, "abc")
|
||||
assert_equal(5, string.length(custom))
|
||||
|
||||
-- Zero length
|
||||
assert_equal("", string.random(0))
|
||||
end)
|
||||
|
||||
test("string.slug", function()
|
||||
assert_equal("hello-world", string.slug("Hello World"))
|
||||
assert_equal("cafe-restaurant", string.slug("Café & Restaurant"))
|
||||
assert_equal("specialcharacters", string.slug("Special!@#$%Characters"))
|
||||
assert_equal("", string.slug(""))
|
||||
assert_equal("test", string.slug("test"))
|
||||
end)
|
||||
|
||||
test("string.iequals", function()
|
||||
assert_equal(true, string.iequals("Hello", "HELLO"))
|
||||
assert_equal(true, string.iequals("hello", "hello"))
|
||||
assert_equal(false, string.iequals("Hello", "world"))
|
||||
assert_equal(true, string.iequals("", ""))
|
||||
end)
|
||||
|
||||
test("string.is_whitespace", function()
|
||||
assert_equal(true, string.is_whitespace(" "))
|
||||
assert_equal(true, string.is_whitespace(""))
|
||||
assert_equal(true, string.is_whitespace("\t\n\r "))
|
||||
assert_equal(false, string.is_whitespace("hello"))
|
||||
assert_equal(false, string.is_whitespace(" a "))
|
||||
end)
|
||||
|
||||
test("string.strip_whitespace", function()
|
||||
assert_equal("hello", string.strip_whitespace("h e l l o"))
|
||||
assert_equal("helloworld", string.strip_whitespace("hello world"))
|
||||
assert_equal("", string.strip_whitespace(" "))
|
||||
assert_equal("abc", string.strip_whitespace("a\tb\nc"))
|
||||
end)
|
||||
|
||||
test("string.normalize_whitespace", function()
|
||||
assert_equal("hello world test", string.normalize_whitespace("hello world test"))
|
||||
assert_equal("a b c", string.normalize_whitespace(" a b c "))
|
||||
assert_equal("", string.normalize_whitespace(" "))
|
||||
assert_equal("hello", string.normalize_whitespace("hello"))
|
||||
end)
|
||||
|
||||
test("string.extract_numbers", function()
|
||||
local numbers = string.extract_numbers("The price is $123.45 and tax is 8.5%")
|
||||
assert_equal(2, #numbers)
|
||||
assert_close(123.45, numbers[1])
|
||||
assert_close(8.5, numbers[2])
|
||||
|
||||
local negative = string.extract_numbers("Temperature: -15.5 degrees")
|
||||
assert_equal(1, #negative)
|
||||
assert_close(-15.5, negative[1])
|
||||
|
||||
-- No numbers
|
||||
assert_table_equal({}, string.extract_numbers("hello world"))
|
||||
end)
|
||||
|
||||
test("string.remove_accents", function()
|
||||
assert_equal("cafe", string.remove_accents("café"))
|
||||
assert_equal("resume", string.remove_accents("résumé"))
|
||||
assert_equal("naive", string.remove_accents("naïve"))
|
||||
assert_equal("hello", string.remove_accents("hello"))
|
||||
assert_equal("", string.remove_accents(""))
|
||||
|
||||
-- Mixed case
|
||||
assert_equal("Cafe", string.remove_accents("Café"))
|
||||
end)
|
||||
|
||||

-- ======================================================================
-- EDGE CASES AND ERROR HANDLING
-- ======================================================================

test("Empty String Handling", function()
assert_table_equal({""}, string.split("", ","))
assert_equal("", string.join({}, ","))
assert_equal("", string.trim(""))
assert_equal("", string.reverse(""))
assert_equal("", string.repeat_("", 5))
assert_table_equal({""}, string.lines(""))
assert_table_equal({}, string.words(""))
assert_equal("", string.slice("", 1, 5))
assert_equal("", string.slug(""))
end)

test("Large String Handling", function()
local large_string = string.rep("test ", 1000)

assert_equal(5000, string.length(large_string))
assert_equal(1000, string.count(large_string, "test"))

local words = string.words(large_string)
assert_equal(1000, #words)

local trimmed = string.trim(large_string)
assert_equal(true, string.ends_with(trimmed, "test"))

-- Should use Go for reverse on large strings
local reversed = string.reverse(large_string)
assert_equal(string.length(large_string), string.length(reversed))
end)

test("Unicode Handling", function()
local unicode_str = "Hello 🌍 World 🚀"

assert_equal(true, string.contains(unicode_str, "🌍"))
assert_equal(true, string.starts_with(unicode_str, "Hello"))
assert_equal(true, string.ends_with(unicode_str, "🚀"))

local parts = string.split(unicode_str, " ")
assert_equal(4, #parts)
assert_equal("🌍", parts[2])

-- Length should count Unicode characters, not bytes
assert_equal(15, string.length(unicode_str))
assert(string.byte_length(unicode_str) > 15, "byte length should be larger")
end)

test("Boundary Conditions", function()
-- Index boundary tests
assert_equal("", string.slice("hello", 6))
assert_equal("", string.slice("hello", 1, 0))
assert_equal("hello", string.slice("hello", 1, 100))

-- Padding boundary tests
assert_equal("hi", string.pad_left("hi", 0))
assert_equal("hi", string.pad_right("hi", 1))

-- Count boundary tests
assert_equal(0, string.count("", "a"))
assert_equal(1, string.count("", ""))

-- Replace boundary tests
assert_equal("", string.replace_n("hello", "hello", "", 1))
assert_equal("hello", string.replace_n("hello", "x", "y", 0))
end)

-- ======================================================================
-- INTEGRATION TESTS
-- ======================================================================

test("String Processing Pipeline", function()
local messy_input = " HELLO, world! How ARE you? "

-- Clean and normalize
local cleaned = string.normalize_whitespace(string.trim(messy_input))
local lowered = string.lower(cleaned)
local words = string.words(lowered)
local filtered = {}

for _, word in ipairs(words) do
local clean_word = string.gsub(word, "[%p]", "") -- Remove punctuation using Lua pattern
if string.length(clean_word) > 2 then
table.insert(filtered, clean_word)
end
end

local result = string.join(filtered, "-")
assert_equal("hello-world-how-are-you", result)
end)

test("Text Analysis", function()
local text = "The quick brown fox jumps over the lazy dog. The dog was sleeping."

local word_count = #string.words(text)
local sentence_count = string.count(text, ".")
local the_count = string.count(string.lower(text), "the")

assert_equal(13, word_count)
assert_equal(2, sentence_count)
assert_equal(3, the_count)

-- Template processing
local template = "Found ${word_count} words and ${the_count} instances of 'the'"
local vars = {word_count = word_count, the_count = the_count}
local summary = string.template(template, vars)

assert_equal("Found 13 words and 3 instances of 'the'", summary)
end)

test("Case Conversion Chain", function()
local original = "Hello World Test Case"

-- Test conversion chain
local snake = string.snake_case(original)
local camel = string.camel_case(original)
local pascal = string.pascal_case(original)
local kebab = string.kebab_case(original)
local screaming = string.screaming_snake_case(original)

assert_equal("hello_world_test_case", snake)
assert_equal("helloWorldTestCase", camel)
assert_equal("HelloWorldTestCase", pascal)
assert_equal("hello-world-test-case", kebab)
assert_equal("HELLO_WORLD_TEST_CASE", screaming)

-- Convert back should be similar
local back_to_words = string.split(snake, "_")
local rejoined = string.join(back_to_words, " ")
local capitalized = string.title(rejoined)

assert_equal("Hello World Test Case", capitalized)
end)

-- ======================================================================
-- PERFORMANCE TESTS
-- ======================================================================

test("Performance Characteristics", function()
local large_text = string.rep("The quick brown fox jumps over the lazy dog. ", 1000)

-- Test that operations complete in reasonable time
local start = os.clock()

local words = string.words(large_text)
local lines = string.lines(large_text)
local replaced = string.replace(large_text, "fox", "cat")
local parts = string.split(large_text, " ")
local reversed = string.reverse(large_text)

local elapsed = os.clock() - start

-- Verify results
assert(#words > 8000, "should extract many words")
assert(string.contains(replaced, "cat"), "replacement should work")
assert(#parts > 8000, "should split into many parts")
assert_equal(string.length(large_text), string.length(reversed))

print(string.format(" Processed %d characters in %.3fs", string.length(large_text), elapsed))

-- Performance should be reasonable (< 1 second for this test)
assert(elapsed < 1.0, "operations should complete quickly")
end)

summary()
test_exit()
911
tests/table.lua
911
tests/table.lua
@ -1,911 +0,0 @@
require("tests")

-- Test data
local simple_array = {1, 2, 3, 4, 5}
local simple_table = {a = 1, b = 2, c = 3}
local mixed_table = {1, 2, a = "hello", b = "world"}
local nested_table = {
a = {x = 1, y = 2},
b = {x = 3, y = 4},
c = {1, 2, 3}
}

-- ======================================================================
-- BUILT-IN TABLE FUNCTIONS
-- ======================================================================

test("Table Insert Operations", function()
local t = {1, 2, 3}

table.insert(t, 4)
assert_equal(4, #t)
assert_equal(4, t[4])

table.insert(t, 2, "inserted")
assert_equal(5, #t)
assert_equal("inserted", t[2])
assert_equal(2, t[3])
end)

test("Table Remove Operations", function()
local t = {1, 2, 3, 4, 5}

local removed = table.remove(t)
assert_equal(5, removed)
assert_equal(4, #t)

removed = table.remove(t, 2)
assert_equal(2, removed)
assert_equal(3, #t)
assert_equal(3, t[2])
end)

test("Table Concat", function()
local t = {"hello", "world", "test"}
assert_equal("helloworldtest", table.concat(t))
assert_equal("hello,world,test", table.concat(t, ","))
assert_equal("world,test", table.concat(t, ",", 2))
assert_equal("world", table.concat(t, ",", 2, 2))
end)

test("Table Sort", function()
local t = {3, 1, 4, 1, 5}
table.sort(t)
assert_table_equal({1, 1, 3, 4, 5}, t)

local t2 = {"c", "a", "b"}
table.sort(t2)
assert_table_equal({"a", "b", "c"}, t2)

local t3 = {3, 1, 4, 1, 5}
table.sort(t3, function(a, b) return a > b end)
assert_table_equal({5, 4, 3, 1, 1}, t3)
end)

-- ======================================================================
-- BASIC TABLE OPERATIONS
-- ======================================================================

test("Table Length and Size", function()
assert_equal(5, table.length(simple_array))
assert_equal(0, table.length({}))

assert_equal(3, table.size(simple_table))
assert_equal(4, table.size(mixed_table))
assert_equal(0, table.size({}))
end)

test("Table Empty Check", function()
assert_equal(true, table.is_empty({}))
assert_equal(false, table.is_empty(simple_array))
assert_equal(false, table.is_empty(simple_table))
end)

test("Table Array Check", function()
assert_equal(true, table.is_array(simple_array))
assert_equal(true, table.is_array({}))
assert_equal(false, table.is_array(simple_table))
assert_equal(false, table.is_array(mixed_table))

assert_equal(true, table.is_array({1, 2, 3}))
assert_equal(false, table.is_array({1, 2, nil, 4}))
assert_equal(false, table.is_array({[0] = 1, [1] = 2}))
end)

test("Table Clear", function()
local t = table.clone(simple_table)
table.clear(t)
assert_equal(true, table.is_empty(t))
end)

test("Table Clone", function()
local cloned = table.clone(simple_table)
assert_table_equal(simple_table, cloned)

-- Modify original shouldn't affect clone
simple_table.new_key = "test"
assert_equal(nil, cloned.new_key)
simple_table.new_key = nil
end)

test("Table Deep Copy", function()
local copied = table.deep_copy(nested_table)
assert_table_equal(nested_table, copied)

-- Modify nested part shouldn't affect copy
nested_table.a.z = 99
assert_equal(nil, copied.a.z)
nested_table.a.z = nil
end)

-- ======================================================================
-- SEARCHING AND FINDING
-- ======================================================================

test("Table Contains", function()
assert_equal(true, table.contains(simple_array, 3))
assert_equal(false, table.contains(simple_array, 6))
assert_equal(true, table.contains(simple_table, 2))
assert_equal(false, table.contains(simple_table, "hello"))
end)

test("Table Index Of", function()
assert_equal(3, table.index_of(simple_array, 3))
assert_equal(nil, table.index_of(simple_array, 6))
assert_equal("b", table.index_of(simple_table, 2))
assert_equal(nil, table.index_of(simple_table, "hello"))
end)

test("Table Find", function()
local value, key = table.find(simple_array, function(v) return v > 3 end)
assert_equal(4, value)
assert_equal(4, key)

local value2, key2 = table.find(simple_table, function(v, k) return k == "b" end)
assert_equal(2, value2)
assert_equal("b", key2)

local value3 = table.find(simple_array, function(v) return v > 10 end)
assert_equal(nil, value3)
end)

test("Table Find Index", function()
local idx = table.find_index(simple_array, function(v) return v > 3 end)
assert_equal(4, idx)

local idx2 = table.find_index(simple_table, function(v, k) return k == "c" end)
assert_equal("c", idx2)

local idx3 = table.find_index(simple_array, function(v) return v > 10 end)
assert_equal(nil, idx3)
end)

test("Table Count", function()
local arr = {1, 2, 3, 2, 4, 2}
assert_equal(3, table.count(arr, 2))
assert_equal(0, table.count(arr, 5))

assert_equal(2, table.count(arr, function(v) return v > 2 end))
assert_equal(2, table.count(arr, function(v) return v == 1 or v == 4 end))
end)

-- ======================================================================
-- FILTERING AND MAPPING
-- ======================================================================

test("Table Filter", function()
local evens = table.filter(simple_array, function(v) return v % 2 == 0 end)
assert_table_equal({2, 4}, evens)

local filtered_table = table.filter(simple_table, function(v) return v > 1 end)
assert_equal(2, table.size(filtered_table))
assert_equal(2, filtered_table.b)
assert_equal(3, filtered_table.c)
end)

test("Table Reject", function()
local odds = table.reject(simple_array, function(v) return v % 2 == 0 end)
assert_table_equal({1, 3, 5}, odds)
end)

test("Table Map", function()
local doubled = table.map(simple_array, function(v) return v * 2 end)
assert_table_equal({2, 4, 6, 8, 10}, doubled)

local mapped_table = table.map(simple_table, function(v) return v + 10 end)
assert_equal(11, mapped_table.a)
assert_equal(12, mapped_table.b)
assert_equal(13, mapped_table.c)
end)

test("Table Map Values", function()
local incremented = table.map_values(simple_table, function(v) return v + 1 end)
assert_equal(2, incremented.a)
assert_equal(3, incremented.b)
assert_equal(4, incremented.c)
end)

test("Table Map Keys", function()
local prefixed = table.map_keys(simple_table, function(k) return "key_" .. k end)
assert_equal(1, prefixed.key_a)
assert_equal(2, prefixed.key_b)
assert_equal(3, prefixed.key_c)
assert_equal(nil, prefixed.a)
end)

-- ======================================================================
-- REDUCING AND AGGREGATING
-- ======================================================================

test("Table Reduce", function()
local sum = table.reduce(simple_array, function(acc, v) return acc + v end)
assert_equal(15, sum)

local sum_with_initial = table.reduce(simple_array, function(acc, v) return acc + v end, 10)
assert_equal(25, sum_with_initial)

local product = table.reduce({2, 3, 4}, function(acc, v) return acc * v end)
assert_equal(24, product)
end)

test("Table Fold", function()
local sum = table.fold(simple_array, function(acc, v) return acc + v end, 0)
assert_equal(15, sum)

local concatenated = table.fold({"a", "b", "c"}, function(acc, v) return acc .. v end, "")
assert_equal("abc", concatenated)
end)

test("Table Math Operations", function()
assert_equal(15, table.sum(simple_array))
assert_equal(120, table.product(simple_array))
assert_equal(1, table.min(simple_array))
assert_equal(5, table.max(simple_array))
assert_equal(3, table.average(simple_array))

local floats = {1.5, 2.5, 3.0}
assert_close(7.0, table.sum(floats))
assert_close(2.33333, table.average(floats), 0.001)
end)

-- ======================================================================
-- BOOLEAN OPERATIONS
-- ======================================================================

test("Table All", function()
assert_equal(true, table.all({true, true, true}))
assert_equal(false, table.all({true, false, true}))
assert_equal(true, table.all({}))

assert_equal(true, table.all(simple_array, function(v) return v > 0 end))
assert_equal(false, table.all(simple_array, function(v) return v > 3 end))
end)

test("Table Any", function()
assert_equal(true, table.any({false, true, false}))
assert_equal(false, table.any({false, false, false}))
assert_equal(false, table.any({}))

assert_equal(true, table.any(simple_array, function(v) return v > 3 end))
assert_equal(false, table.any(simple_array, function(v) return v > 10 end))
end)

test("Table None", function()
assert_equal(true, table.none({false, false, false}))
assert_equal(false, table.none({false, true, false}))
assert_equal(true, table.none({}))

assert_equal(false, table.none(simple_array, function(v) return v > 3 end))
assert_equal(true, table.none(simple_array, function(v) return v > 10 end))
end)

-- ======================================================================
-- SET OPERATIONS
-- ======================================================================

test("Table Unique", function()
local duplicates = {1, 2, 2, 3, 3, 3, 4}
local unique = table.unique(duplicates)
assert_table_equal({1, 2, 3, 4}, unique)

local empty_unique = table.unique({})
assert_table_equal({}, empty_unique)
end)

test("Table Intersection", function()
local arr1 = {1, 2, 3, 4}
local arr2 = {3, 4, 5, 6}
local intersect = table.intersection(arr1, arr2)
assert_equal(2, #intersect)
assert_equal(true, table.contains(intersect, 3))
assert_equal(true, table.contains(intersect, 4))
end)

test("Table Union", function()
local arr1 = {1, 2, 3}
local arr2 = {3, 4, 5}
local union = table.union(arr1, arr2)
assert_equal(5, #union)
for i = 1, 5 do
assert_equal(true, table.contains(union, i))
end
end)

test("Table Difference", function()
local arr1 = {1, 2, 3, 4, 5}
local arr2 = {3, 4}
local diff = table.difference(arr1, arr2)
assert_table_equal({1, 2, 5}, diff)
end)

-- ======================================================================
-- ARRAY OPERATIONS
-- ======================================================================

test("Table Reverse", function()
local reversed = table.reverse(simple_array)
assert_table_equal({5, 4, 3, 2, 1}, reversed)

local single = table.reverse({42})
assert_table_equal({42}, single)

local empty = table.reverse({})
assert_table_equal({}, empty)
end)

test("Table Shuffle", function()
local shuffled = table.shuffle(simple_array)
assert_equal(5, #shuffled)

-- All original elements should still be present
for _, v in ipairs(simple_array) do
assert_equal(true, table.contains(shuffled, v))
end

-- Should be same length
assert_equal(#simple_array, #shuffled)
end)

test("Table Rotate", function()
local arr = {1, 2, 3, 4, 5}

local rotated_right = table.rotate(arr, 2)
assert_table_equal({4, 5, 1, 2, 3}, rotated_right)

local rotated_left = table.rotate(arr, -2)
assert_table_equal({3, 4, 5, 1, 2}, rotated_left)

local no_rotation = table.rotate(arr, 0)
assert_table_equal(arr, no_rotation)

local full_rotation = table.rotate(arr, 5)
assert_table_equal(arr, full_rotation)
end)

test("Table Slice", function()
local sliced = table.slice(simple_array, 2, 4)
assert_table_equal({2, 3, 4}, sliced)

local from_start = table.slice(simple_array, 1, 3)
assert_table_equal({1, 2, 3}, from_start)

local to_end = table.slice(simple_array, 3)
assert_table_equal({3, 4, 5}, to_end)

local negative_indices = table.slice(simple_array, -3, -1)
assert_table_equal({3, 4, 5}, negative_indices)
end)

test("Table Splice", function()
local arr = {1, 2, 3, 4, 5}

-- Remove elements
local removed = table.splice(arr, 2, 2)
assert_table_equal({2, 3}, removed)
assert_table_equal({1, 4, 5}, arr)

-- Insert elements
arr = {1, 2, 3, 4, 5}
removed = table.splice(arr, 3, 0, "a", "b")
assert_table_equal({}, removed)
assert_table_equal({1, 2, "a", "b", 3, 4, 5}, arr)

-- Replace elements
arr = {1, 2, 3, 4, 5}
removed = table.splice(arr, 2, 2, "x", "y", "z")
assert_table_equal({2, 3}, removed)
assert_table_equal({1, "x", "y", "z", 4, 5}, arr)
end)

-- ======================================================================
-- SORTING HELPERS
-- ======================================================================

test("Table Sort By", function()
local people = {
{name = "Alice", age = 30},
{name = "Bob", age = 25},
{name = "Charlie", age = 35}
}

local sorted_by_age = table.sort_by(people, function(p) return p.age end)
assert_equal("Bob", sorted_by_age[1].name)
assert_equal("Charlie", sorted_by_age[3].name)

local sorted_by_name = table.sort_by(people, function(p) return p.name end)
assert_equal("Alice", sorted_by_name[1].name)
assert_equal("Charlie", sorted_by_name[3].name)
end)

test("Table Is Sorted", function()
assert_equal(true, table.is_sorted({1, 2, 3, 4, 5}))
assert_equal(false, table.is_sorted({1, 3, 2, 4, 5}))
assert_equal(true, table.is_sorted({}))
assert_equal(true, table.is_sorted({42}))

assert_equal(true, table.is_sorted({5, 4, 3, 2, 1}, function(a, b) return a > b end))
assert_equal(false, table.is_sorted({1, 2, 3, 4, 5}, function(a, b) return a > b end))
end)

-- ======================================================================
-- UTILITY FUNCTIONS
-- ======================================================================

test("Table Keys and Values", function()
local keys = table.keys(simple_table)
assert_equal(3, #keys)
assert_equal(true, table.contains(keys, "a"))
assert_equal(true, table.contains(keys, "b"))
assert_equal(true, table.contains(keys, "c"))

local values = table.values(simple_table)
assert_equal(3, #values)
assert_equal(true, table.contains(values, 1))
assert_equal(true, table.contains(values, 2))
assert_equal(true, table.contains(values, 3))
end)

test("Table Pairs", function()
local pairs_list = table.pairs({a = 1, b = 2})
assert_equal(2, #pairs_list)

-- Should contain key-value pairs
local found_a, found_b = false, false
for _, pair in ipairs(pairs_list) do
if pair[1] == "a" and pair[2] == 1 then found_a = true end
if pair[1] == "b" and pair[2] == 2 then found_b = true end
end
assert_equal(true, found_a)
assert_equal(true, found_b)
end)

test("Table Merge", function()
local t1 = {a = 1, b = 2}
local t2 = {c = 3, d = 4}
local t3 = {b = 20, e = 5}

local merged = table.merge(t1, t2, t3)
assert_equal(5, table.size(merged))
assert_equal(1, merged.a)
assert_equal(20, merged.b) -- Last one wins
assert_equal(3, merged.c)
assert_equal(4, merged.d)
assert_equal(5, merged.e)
end)

test("Table Extend", function()
local t1 = {a = 1, b = 2}
local t2 = {c = 3, d = 4}

local extended = table.extend(t1, t2)
assert_equal(t1, extended) -- Should return t1
assert_equal(4, table.size(t1))
assert_equal(3, t1.c)
assert_equal(4, t1.d)
end)

test("Table Invert", function()
local inverted = table.invert(simple_table)
assert_equal("a", inverted[1])
assert_equal("b", inverted[2])
assert_equal("c", inverted[3])
end)

test("Table Pick and Omit", function()
local big_table = {a = 1, b = 2, c = 3, d = 4, e = 5}

local picked = table.pick(big_table, "a", "c", "e")
assert_equal(3, table.size(picked))
assert_equal(1, picked.a)
assert_equal(3, picked.c)
assert_equal(5, picked.e)
assert_equal(nil, picked.b)
assert_equal(nil, picked.d)

local omitted = table.omit(big_table, "b", "d")
assert_equal(3, table.size(omitted))
assert_equal(1, omitted.a)
assert_equal(3, omitted.c)
assert_equal(5, omitted.e)
assert_equal(nil, omitted.b)
assert_equal(nil, omitted.d)
end)

-- ======================================================================
-- DEEP OPERATIONS
-- ======================================================================

test("Table Deep Equals", function()
local t1 = {a = {x = 1, y = 2}, b = {1, 2, 3}}
local t2 = {a = {x = 1, y = 2}, b = {1, 2, 3}}
local t3 = {a = {x = 1, y = 3}, b = {1, 2, 3}}

assert_equal(true, table.deep_equals(t1, t2))
assert_equal(false, table.deep_equals(t1, t3))
assert_equal(true, table.deep_equals({}, {}))
assert_equal(false, table.deep_equals({a = 1}, {a = 1, b = 2}))
end)

test("Table Flatten", function()
local nested = {{1, 2}, {3, 4}, {5, {6, 7}}}
local flattened = table.flatten(nested)
assert_table_equal({1, 2, 3, 4, 5, {6, 7}}, flattened)

local deep_flattened = table.flatten(nested, 2)
|
||||
assert_table_equal({1, 2, 3, 4, 5, 6, 7}, deep_flattened)
|
||||
|
||||
local already_flat = table.flatten({1, 2, 3})
|
||||
assert_table_equal({1, 2, 3}, already_flat)
|
||||
end)
|
||||
|
||||
test("Table Deep Merge", function()
|
||||
local t1 = {a = {x = 1}, b = 2}
|
||||
local t2 = {a = {y = 3}, c = 4}
|
||||
|
||||
local merged = table.deep_merge(t1, t2)
|
||||
assert_equal(1, merged.a.x)
|
||||
assert_equal(3, merged.a.y)
|
||||
assert_equal(2, merged.b)
|
||||
assert_equal(4, merged.c)
|
||||
|
||||
-- Original tables should be unchanged
|
||||
assert_equal(nil, t1.a.y)
|
||||
assert_equal(nil, t1.c)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- ADVANCED OPERATIONS
|
||||
-- ======================================================================
|
||||
|
||||
test("Table Chunk", function()
|
||||
local chunks = table.chunk({1, 2, 3, 4, 5, 6, 7}, 3)
|
||||
assert_equal(3, #chunks)
|
||||
assert_table_equal({1, 2, 3}, chunks[1])
|
||||
assert_table_equal({4, 5, 6}, chunks[2])
|
||||
assert_table_equal({7}, chunks[3])
|
||||
|
||||
local exact_chunks = table.chunk({1, 2, 3, 4}, 2)
|
||||
assert_equal(2, #exact_chunks)
|
||||
assert_table_equal({1, 2}, exact_chunks[1])
|
||||
assert_table_equal({3, 4}, exact_chunks[2])
|
||||
end)
|
||||
|
||||
test("Table Partition", function()
|
||||
local evens, odds = table.partition(simple_array, function(v) return v % 2 == 0 end)
|
||||
assert_table_equal({2, 4}, evens)
|
||||
assert_table_equal({1, 3, 5}, odds)
|
||||
|
||||
local empty_true, all_false = table.partition({1, 3, 5}, function(v) return v % 2 == 0 end)
|
||||
assert_table_equal({}, empty_true)
|
||||
assert_table_equal({1, 3, 5}, all_false)
|
||||
end)
|
||||
|
||||
test("Table Group By", function()
|
||||
local people = {
|
||||
{name = "Alice", department = "engineering"},
|
||||
{name = "Bob", department = "sales"},
|
||||
{name = "Charlie", department = "engineering"},
|
||||
{name = "David", department = "sales"}
|
||||
}
|
||||
|
||||
local by_dept = table.group_by(people, function(person) return person.department end)
|
||||
assert_equal(2, table.size(by_dept))
|
||||
assert_equal(2, #by_dept.engineering)
|
||||
assert_equal(2, #by_dept.sales)
|
||||
assert_equal("Alice", by_dept.engineering[1].name)
|
||||
assert_equal("Bob", by_dept.sales[1].name)
|
||||
end)
|
||||
|
||||
test("Table Zip", function()
|
||||
local names = {"Alice", "Bob", "Charlie"}
|
||||
local ages = {25, 30, 35}
|
||||
local cities = {"NYC", "LA", "Chicago"}
|
||||
|
||||
local zipped = table.zip(names, ages, cities)
|
||||
assert_equal(3, #zipped)
|
||||
assert_table_equal({"Alice", 25, "NYC"}, zipped[1])
|
||||
assert_table_equal({"Bob", 30, "LA"}, zipped[2])
|
||||
assert_table_equal({"Charlie", 35, "Chicago"}, zipped[3])
|
||||
|
||||
-- Different lengths
|
||||
local short_zip = table.zip({1, 2, 3}, {"a", "b"})
|
||||
assert_equal(2, #short_zip)
|
||||
assert_table_equal({1, "a"}, short_zip[1])
|
||||
assert_table_equal({2, "b"}, short_zip[2])
|
||||
end)
|
||||
|
||||
test("Table Compact", function()
|
||||
local messy = {1, nil, false, 2, nil, 3, false}
|
||||
local compacted = table.compact(messy)
|
||||
assert_table_equal({1, 2, 3}, compacted)
|
||||
|
||||
local clean = {1, 2, 3}
|
||||
local unchanged = table.compact(clean)
|
||||
assert_table_equal(clean, unchanged)
|
||||
end)
|
||||
|
||||
test("Table Sample", function()
|
||||
local sample1 = table.sample(simple_array, 3)
|
||||
assert_equal(3, #sample1)
|
||||
|
||||
-- All sampled elements should be from original
|
||||
for _, v in ipairs(sample1) do
|
||||
assert_equal(true, table.contains(simple_array, v))
|
||||
end
|
||||
|
||||
local single_sample = table.sample(simple_array)
|
||||
assert_equal(1, #single_sample)
|
||||
assert_equal(true, table.contains(simple_array, single_sample[1]))
|
||||
|
||||
local oversample = table.sample({1, 2}, 5)
|
||||
assert_equal(2, #oversample)
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- EDGE CASES AND ERROR HANDLING
|
||||
-- ======================================================================
|
||||
|
||||
test("Empty Table Handling", function()
|
||||
local empty = {}
|
||||
|
||||
assert_equal(true, table.is_empty(empty))
|
||||
assert_equal(0, table.length(empty))
|
||||
assert_equal(0, table.size(empty))
|
||||
assert_equal(true, table.is_array(empty))
|
||||
|
||||
assert_table_equal({}, table.filter(empty, function() return true end))
|
||||
assert_table_equal({}, table.map(empty, function(v) return v * 2 end))
|
||||
assert_table_equal({}, table.keys(empty))
|
||||
assert_table_equal({}, table.values(empty))
|
||||
|
||||
assert_equal(true, table.all(empty))
|
||||
assert_equal(false, table.any(empty))
|
||||
assert_equal(true, table.none(empty))
|
||||
end)
|
||||
|
||||
test("Single Element Tables", function()
|
||||
local single = {42}
|
||||
|
||||
assert_equal(1, table.length(single))
|
||||
assert_equal(1, table.size(single))
|
||||
assert_equal(true, table.is_array(single))
|
||||
assert_equal(false, table.is_empty(single))
|
||||
|
||||
assert_equal(42, table.sum(single))
|
||||
assert_equal(42, table.product(single))
|
||||
assert_equal(42, table.min(single))
|
||||
assert_equal(42, table.max(single))
|
||||
assert_equal(42, table.average(single))
|
||||
|
||||
assert_table_equal({42}, table.reverse(single))
|
||||
assert_table_equal({84}, table.map(single, function(v) return v * 2 end))
|
||||
end)
|
||||
|
||||
test("Circular Reference Handling", function()
|
||||
local t1 = {a = 1}
|
||||
local t2 = {b = 2}
|
||||
t1.ref = t2
|
||||
t2.ref = t1
|
||||
|
||||
-- Deep copy should handle circular references
|
||||
local copied = table.deep_copy(t1)
|
||||
assert_equal(1, copied.a)
|
||||
assert_equal(2, copied.ref.b)
|
||||
assert_equal(copied, copied.ref.ref) -- Should maintain circular structure
|
||||
end)
|
||||
|
||||
test("Large Table Performance", function()
|
||||
local large = {}
|
||||
for i = 1, 10000 do
|
||||
large[i] = i
|
||||
end
|
||||
|
||||
assert_equal(10000, table.length(large))
|
||||
assert_equal(true, table.is_array(large))
|
||||
assert_equal(50005000, table.sum(large)) -- Sum of 1 to 10000
|
||||
|
||||
local evens = table.filter(large, function(v) return v % 2 == 0 end)
|
||||
assert_equal(5000, #evens)
|
||||
|
||||
local doubled = table.map(large, function(v) return v * 2 end)
|
||||
assert_equal(10000, #doubled)
|
||||
assert_equal(2, doubled[1])
|
||||
assert_equal(20000, doubled[10000])
|
||||
end)
|
||||
|
||||
test("Mixed Type Table Handling", function()
|
||||
local mixed = {1, "hello", true, {a = 1}, function() end}
|
||||
|
||||
assert_equal(5, table.length(mixed))
|
||||
assert_equal(true, table.is_array(mixed))
|
||||
assert_equal(true, table.contains(mixed, "hello"))
|
||||
assert_equal(true, table.contains(mixed, true))
|
||||
|
||||
local strings_only = table.filter(mixed, function(v) return type(v) == "string" end)
|
||||
assert_equal(1, #strings_only)
|
||||
assert_equal("hello", strings_only[1])
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- PERFORMANCE TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Performance Test", function()
|
||||
local large_array = {}
|
||||
for i = 1, 10000 do
|
||||
large_array[i] = math.random(1, 1000)
|
||||
end
|
||||
|
||||
local start = os.clock()
|
||||
local filtered = table.filter(large_array, function(v) return v > 500 end)
|
||||
local filter_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local mapped = table.map(large_array, function(v) return v * 2 end)
|
||||
local map_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local sum = table.sum(large_array)
|
||||
local sum_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local sorted = table.sort_by(large_array, function(v) return v end)
|
||||
local sort_time = os.clock() - start
|
||||
|
||||
start = os.clock()
|
||||
local unique = table.unique(large_array)
|
||||
local unique_time = os.clock() - start
|
||||
|
||||
print(string.format(" Filter %d elements: %.3fs", #filtered, filter_time))
|
||||
print(string.format(" Map %d elements: %.3fs", #mapped, map_time))
|
||||
print(string.format(" Sum %d elements: %.3fs", #large_array, sum_time))
|
||||
print(string.format(" Sort %d elements: %.3fs", #sorted, sort_time))
|
||||
print(string.format(" Unique from %d to %d: %.3fs", #large_array, #unique, unique_time))
|
||||
|
||||
assert(#filtered > 0, "should filter some elements")
|
||||
assert_equal(#large_array, #mapped)
|
||||
assert(sum > 0, "sum should be positive")
|
||||
assert_equal(#large_array, #sorted)
|
||||
assert(table.is_sorted(sorted), "should be sorted")
|
||||
end)
|
||||
|
||||
-- ======================================================================
|
||||
-- INTEGRATION TESTS
|
||||
-- ======================================================================
|
||||
|
||||
test("Data Processing Pipeline", function()
|
||||
local sales_data = {
|
||||
{product = "laptop", price = 1000, quantity = 2, category = "electronics"},
|
||||
{product = "mouse", price = 25, quantity = 10, category = "electronics"},
|
||||
{product = "book", price = 15, quantity = 5, category = "books"},
|
||||
{product = "phone", price = 800, quantity = 3, category = "electronics"},
|
||||
{product = "magazine", price = 5, quantity = 20, category = "books"}
|
||||
}
|
||||
|
||||
-- Calculate total revenue per item
|
||||
local with_revenue = table.map(sales_data, function(item)
|
||||
local new_item = table.clone(item)
|
||||
new_item.revenue = item.price * item.quantity
|
||||
return new_item
|
||||
end)
|
||||
|
||||
-- Filter high-value items (revenue >= 100)
|
||||
local high_value = table.filter(with_revenue, function(item)
|
||||
return item.revenue >= 100
|
||||
end)
|
||||
|
||||
-- Group by category
|
||||
local by_category = table.group_by(high_value, function(item)
|
||||
return item.category
|
||||
end)
|
||||
|
||||
-- Calculate total revenue by category
|
||||
local category_totals = table.map_values(by_category, function(items)
|
||||
return table.sum(table.map(items, function(item) return item.revenue end))
|
||||
end)
|
||||
|
||||
assert_equal(2, table.size(category_totals))
|
||||
assert_equal(4650, category_totals.electronics) -- laptop: 2000, mouse: 250, phone: 2400
|
||||
assert_equal(100, category_totals.books) -- magazine: 100
|
||||
end)
|
||||
|
||||
test("Complex Data Transformation", function()
|
||||
local users = {
|
||||
{id = 1, name = "Alice", age = 25, skills = {"lua", "python"}},
|
||||
{id = 2, name = "Bob", age = 30, skills = {"javascript", "lua"}},
|
||||
{id = 3, name = "Charlie", age = 35, skills = {"python", "java"}},
|
||||
{id = 4, name = "David", age = 28, skills = {"lua", "go"}}
|
||||
}
|
||||
|
||||
-- Find Lua developers
|
||||
local lua_devs = table.filter(users, function(user)
|
||||
return table.contains(user.skills, "lua")
|
||||
end)
|
||||
|
||||
-- Sort by age
|
||||
local sorted_lua_devs = table.sort_by(lua_devs, function(user) return user.age end)
|
||||
|
||||
-- Extract just names and ages
|
||||
local simplified = table.map(sorted_lua_devs, function(user)
|
||||
return {name = user.name, age = user.age}
|
||||
end)
|
||||
|
||||
assert_equal(3, #simplified)
|
||||
assert_equal("Alice", simplified[1].name) -- Youngest
|
||||
assert_equal("David", simplified[2].name)
|
||||
assert_equal("Bob", simplified[3].name) -- Oldest
|
||||
|
||||
-- Group all users by age ranges
|
||||
local age_groups = table.group_by(users, function(user)
|
||||
if user.age < 30 then return "young"
|
||||
else return "experienced" end
|
||||
end)
|
||||
|
||||
assert_equal(2, #age_groups.young) -- Alice, David
|
||||
assert_equal(2, #age_groups.experienced) -- Bob, Charlie
|
||||
end)
|
||||
|
||||
test("Statistical Analysis", function()
|
||||
local test_scores = {
|
||||
{student = "Alice", scores = {85, 92, 78, 95, 88}},
|
||||
{student = "Bob", scores = {72, 85, 90, 77, 82}},
|
||||
{student = "Charlie", scores = {95, 88, 92, 90, 85}},
|
||||
{student = "David", scores = {68, 75, 80, 72, 78}}
|
||||
}
|
||||
|
||||
-- Calculate average score for each student
|
||||
local with_averages = table.map(test_scores, function(student)
|
||||
local avg = table.average(student.scores)
|
||||
return {
|
||||
student = student.student,
|
||||
scores = student.scores,
|
||||
average = avg,
|
||||
max_score = table.max(student.scores),
|
||||
min_score = table.min(student.scores)
|
||||
}
|
||||
end)
|
||||
|
||||
-- Find top performer
|
||||
local top_student = table.reduce(with_averages, function(best, current)
|
||||
return current.average > best.average and current or best
|
||||
end)
|
||||
|
||||
-- Students above class average
|
||||
local all_averages = table.map(with_averages, function(s) return s.average end)
|
||||
local class_average = table.average(all_averages)
|
||||
local above_average = table.filter(with_averages, function(s)
|
||||
return s.average > class_average
|
||||
end)
|
||||
|
||||
assert_equal("Charlie", top_student.student)
|
||||
assert_equal(90, top_student.average)
|
||||
assert_equal(2, #above_average) -- Charlie and Alice
|
||||
assert_close(83.4, class_average, 0.1)
|
||||
end)
|
||||
|
||||
test("Table Metatable Method Chaining", function()
|
||||
local t = {1, 2, 3, 4, 5}
|
||||
|
||||
-- Check if methods are available directly on table instances
|
||||
assert_equal("function", type(t.filter), "table should have filter method via metatable")
|
||||
assert_equal("function", type(t.map), "table should have map method via metatable")
|
||||
assert_equal("function", type(t.length), "table should have length method via metatable")
|
||||
|
||||
-- Test actual method chaining
|
||||
local result = t:filter(function(v) return v > 2 end)
|
||||
:map(function(v) return v * 2 end)
|
||||
|
||||
assert_table_equal({6, 8, 10}, result)
|
||||
|
||||
-- Test chaining with method calls
|
||||
local nums = {1, 2, 3, 4, 5}
|
||||
local sum = nums:sum()
|
||||
assert_equal(15, sum)
|
||||
|
||||
local sizes = {a = 1, b = 2}
|
||||
local size = sizes:size()
|
||||
assert_equal(2, size)
|
||||
end)

summary()
test_exit()
192
tests/tests.lua
192
tests/tests.lua
@ -1,192 +0,0 @@
|
||||
local passed = 0
|
||||
local total = 0
|
||||
|
||||
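-- Replaces the built-in assert; on failure it reports the calling file and line via debug.getinfo.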
function assert(condition, message, level)
|
||||
if condition then
|
||||
return true
|
||||
end
|
||||
|
||||
level = level or 2
|
||||
local info = debug.getinfo(level, "Sl")
|
||||
local file = info.source
|
||||
|
||||
-- Extract filename from source or use generic name
|
||||
if file:sub(1,1) == "@" then
|
||||
file = file:sub(2) -- Remove @ prefix for files
|
||||
else
|
||||
file = "<script>" -- Generic name for inline scripts
|
||||
end
|
||||
|
||||
local line = info.currentline or "unknown"
|
||||
local error_msg = message or "assertion failed"
|
||||
local full_msg = string.format("%s:%s: %s", file, line, error_msg)
|
||||
|
||||
error(full_msg, 0)
|
||||
end
|
||||
|
||||
function assert_equal(expected, actual, message)
|
||||
if expected == actual then
|
||||
return true
|
||||
end
|
||||
|
||||
local msg = message or string.format("Expected %s, got %s", tostring(expected), tostring(actual))
|
||||
assert(false, msg, 3)
|
||||
end
|
||||
|
||||
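-- Recursively compares two tables, reporting the first mismatching key path.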
function assert_table_equal(expected, actual, message, path)
|
||||
path = path or "root"
|
||||
|
||||
if type(expected) ~= type(actual) then
|
||||
local msg = message or string.format("Type mismatch at %s: expected %s, got %s", path, type(expected), type(actual))
|
||||
assert(false, msg, 3)
|
||||
end
|
||||
|
||||
if type(expected) ~= "table" then
|
||||
if expected ~= actual then
|
||||
local msg = message or string.format("Value mismatch at %s: expected %s, got %s", path, tostring(expected), tostring(actual))
|
||||
assert(false, msg, 3)
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check all keys in a exist in b with same values
|
||||
for k, v in pairs(expected) do
|
||||
local new_path = path .. "." .. tostring(k)
|
||||
if actual[k] == nil then
|
||||
local msg = message or string.format("Missing key at %s", new_path)
|
||||
assert(false, msg, 3)
|
||||
end
|
||||
assert_table_equal(v, actual[k], message, new_path)
|
||||
end
|
||||
|
||||
-- Check all keys in b exist in a
|
||||
for k, v in pairs(actual) do
|
||||
if expected[k] == nil then
|
||||
local new_path = path .. "." .. tostring(k)
|
||||
local msg = message or string.format("Extra key at %s", new_path)
|
||||
assert(false, msg, 3)
|
||||
end
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
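-- Checks that two numbers differ by at most tolerance (default 1e-10).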
function assert_close(expected, actual, tolerance, message)
|
||||
tolerance = tolerance or 1e-10
|
||||
local diff = math.abs(expected - actual)
|
||||
if diff <= tolerance then
|
||||
return true
|
||||
end
|
||||
local msg = message or string.format("Expected %g, got %g (diff: %g > %g)", expected, actual, diff, tolerance)
|
||||
assert(false, msg, 3)
|
||||
end
|
||||
|
||||
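-- Runs a single named test under pcall, timing it and tallying pass/fail counts.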
function test(name, fn)
|
||||
print("Testing " .. name .. "...")
|
||||
total = total + 1
|
||||
|
||||
local start_time = os.clock()
|
||||
local ok, err = pcall(fn)
|
||||
local end_time = os.clock()
|
||||
local duration = end_time - start_time
|
||||
|
||||
if ok then
|
||||
passed = passed + 1
|
||||
print(string.format(" ✓ PASS (%.3fs)", duration))
|
||||
return true
|
||||
else
|
||||
print(" ✗ FAIL: " .. err)
|
||||
if duration > 0.001 then
|
||||
print(string.format(" (%.3fs)", duration))
|
||||
end
|
||||
return false
|
||||
end
|
||||
end
|
||||
|
||||
function run_tests(tests)
|
||||
print("Running test suite...")
|
||||
print("=" .. string.rep("=", 50))
|
||||
|
||||
for name, test_fn in pairs(tests) do
|
||||
test(name, test_fn)
|
||||
end
|
||||
|
||||
return summary()
|
||||
end
|
||||
|
||||
function reset_tests()
|
||||
passed = 0
|
||||
total = 0
|
||||
end
|
||||
|
||||
function test_stats()
|
||||
return {
|
||||
passed = passed,
|
||||
total = total,
|
||||
failed = total - passed,
|
||||
success_rate = total > 0 and (passed / total) or 0
|
||||
}
|
||||
end
|
||||
|
||||
function summary()
|
||||
print("=" .. string.rep("=", 50))
|
||||
print(string.format("Test Results: %d/%d passed", passed, total))
|
||||
|
||||
local success = passed == total
|
||||
if success then
|
||||
print("🎉 All tests passed!")
|
||||
else
|
||||
local failed = total - passed
|
||||
local rate = total > 0 and (passed / total * 100) or 0
|
||||
print(string.format("❌ %d test(s) failed! (%.1f%% success rate)", failed, rate))
|
||||
end
|
||||
|
||||
return success
|
||||
end
|
||||
|
||||
function test_exit()
|
||||
local success = passed == total
|
||||
os.exit(success and 0 or 1)
|
||||
end
|
||||
|
||||
function run_and_exit(tests)
|
||||
local success = run_tests(tests)
|
||||
os.exit(success and 0 or 1)
|
||||
end
|
||||
|
||||
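-- Times fn over the given number of iterations (default 1000) after a short warmup, printing total, average, and ops/sec.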
function benchmark(name, fn, iterations)
|
||||
iterations = iterations or 1000
|
||||
print("Benchmarking " .. name .. " (" .. iterations .. " iterations)...")
|
||||
|
||||
-- Warmup
|
||||
for i = 1, math.min(10, iterations) do
|
||||
fn()
|
||||
end
|
||||
|
||||
-- Actual benchmark
|
||||
local start = os.clock()
|
||||
for i = 1, iterations do
|
||||
fn()
|
||||
end
|
||||
local total_time = os.clock() - start
|
||||
local avg_time = total_time / iterations
|
||||
|
||||
print(string.format(" Total: %.3fs, Average: %.6fs, Rate: %.0f ops/sec",
|
||||
total_time, avg_time, 1/avg_time))
|
||||
|
||||
return {
|
||||
total_time = total_time,
|
||||
avg_time = avg_time,
|
||||
ops_per_sec = 1/avg_time,
|
||||
iterations = iterations
|
||||
}
|
||||
end
|
||||
|
||||
function file_exists(filename)
|
||||
local file = io.open(filename, "r")
|
||||
if file then
|
||||
file:close()
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
68
utils/color/color.go
Normal file
68
utils/color/color.go
Normal file
@ -0,0 +1,68 @@
package color
|
||||
|
||||
// ANSI color codes
|
||||
const (
|
||||
Reset = "\033[0m"
|
||||
Red = "\033[31m"
|
||||
Green = "\033[32m"
|
||||
Yellow = "\033[33m"
|
||||
Blue = "\033[34m"
|
||||
Purple = "\033[35m"
|
||||
Cyan = "\033[36m"
|
||||
White = "\033[37m"
|
||||
Gray = "\033[90m"
|
||||
)
|
||||
|
||||
var useColors = true
|
||||
|
||||
// SetColors enables or disables colors globally
|
||||
func SetColors(enabled bool) {
|
||||
useColors = enabled
|
||||
}
|
||||
|
||||
// ColorsEnabled returns current global color setting
|
||||
func ColorsEnabled() bool {
|
||||
return useColors
|
||||
}
|
||||
|
||||
// Apply adds color to text using global color setting
|
||||
func Apply(text, color string) string {
|
||||
if useColors {
|
||||
return color + text + Reset
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// ApplyIf adds color to text when enabled is true (for backward compatibility)
|
||||
func ApplyIf(text, color string, enabled bool) string {
|
||||
if enabled {
|
||||
return color + text + Reset
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// Set adds color to text (always applies color, ignores global setting)
|
||||
func Set(text, color string) string {
|
||||
return color + text + Reset
|
||||
}
|
||||
|
||||
// Strip removes ANSI color codes from a string
|
||||
func Strip(s string) string {
|
||||
result := ""
|
||||
inEscape := false
|
||||
|
||||
for _, c := range s {
|
||||
if inEscape {
|
||||
if c == 'm' {
|
||||
inEscape = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c == '\033' {
|
||||
inEscape = true
|
||||
continue
|
||||
}
|
||||
result += string(c)
|
||||
}
|
||||
return result
|
||||
}
290
utils/config/config.go
Normal file
290
utils/config/config.go
Normal file
@ -0,0 +1,290 @@
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Config represents a configuration loaded from a Lua file
|
||||
type Config struct {
|
||||
// Server settings
|
||||
Server struct {
|
||||
Port int
|
||||
Debug bool
|
||||
HTTPLogging bool
LogLevel string // referenced by the debug stats template
|
||||
StaticPrefix string
|
||||
}
|
||||
|
||||
// Runner settings
|
||||
Runner struct {
|
||||
PoolSize int
|
||||
}
|
||||
|
||||
// Directory paths
|
||||
Dirs struct {
|
||||
Routes string
|
||||
Static string
|
||||
FS string
|
||||
Data string
|
||||
Override string
|
||||
Libs []string
|
||||
}
|
||||
|
||||
// Raw values map for custom values
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// New creates a new configuration with default values
|
||||
func New() *Config {
|
||||
config := &Config{
|
||||
// Initialize values map
|
||||
values: make(map[string]any),
|
||||
}
|
||||
|
||||
// Server defaults
|
||||
config.Server.Port = 3117
|
||||
config.Server.Debug = false
|
||||
config.Server.HTTPLogging = false
config.Server.LogLevel = "info"
|
||||
config.Server.StaticPrefix = "static/"
|
||||
|
||||
// Runner defaults
|
||||
config.Runner.PoolSize = runtime.GOMAXPROCS(0)
|
||||
|
||||
// Dirs defaults
|
||||
config.Dirs.Routes = "routes"
|
||||
config.Dirs.Static = "public"
|
||||
config.Dirs.FS = "fs"
|
||||
config.Dirs.Data = "data"
|
||||
config.Dirs.Override = "override"
|
||||
config.Dirs.Libs = []string{"libs"}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// Load loads configuration from a Lua file
|
||||
func Load(filePath string) (*Config, error) {
|
||||
// Create Lua state
|
||||
state := luajit.New(true)
|
||||
if state == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
defer state.Cleanup()
|
||||
|
||||
// Create config with default values
|
||||
config := New()
|
||||
|
||||
// Execute the config file
|
||||
if err := state.DoFile(filePath); err != nil {
|
||||
return nil, fmt.Errorf("failed to load config file: %w", err)
|
||||
}
|
||||
|
||||
// Store values directly to the config
|
||||
config.values = make(map[string]any)
|
||||
|
||||
// Extract top-level tables
|
||||
tables := []string{"server", "runner", "dirs"}
|
||||
for _, table := range tables {
|
||||
state.GetGlobal(table)
|
||||
if state.IsTable(-1) {
|
||||
tableMap, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
config.values[table] = tableMap
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// Apply configuration values
|
||||
applyConfig(config, config.values)
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// applyConfig applies configuration values from the globals map
|
||||
func applyConfig(config *Config, values map[string]any) {
|
||||
// Apply server settings
|
||||
if serverTable, ok := values["server"].(map[string]any); ok {
|
||||
if v, ok := serverTable["port"].(float64); ok {
|
||||
config.Server.Port = int(v)
|
||||
}
|
||||
if v, ok := serverTable["debug"].(bool); ok {
|
||||
config.Server.Debug = v
|
||||
}
|
||||
if v, ok := serverTable["http_logging"].(bool); ok {
|
||||
config.Server.HTTPLogging = v
|
||||
}
if v, ok := serverTable["log_level"].(string); ok {
config.Server.LogLevel = v
}
|
||||
if v, ok := serverTable["static_prefix"].(string); ok {
|
||||
config.Server.StaticPrefix = v
|
||||
}
|
||||
}
|
||||
|
||||
// Apply runner settings
|
||||
if runnerTable, ok := values["runner"].(map[string]any); ok {
|
||||
if v, ok := runnerTable["pool_size"].(float64); ok && v != 0 {
|
||||
config.Runner.PoolSize = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply dirs settings
|
||||
if dirsTable, ok := values["dirs"].(map[string]any); ok {
|
||||
if v, ok := dirsTable["routes"].(string); ok {
|
||||
config.Dirs.Routes = v
|
||||
}
|
||||
if v, ok := dirsTable["static"].(string); ok {
|
||||
config.Dirs.Static = v
|
||||
}
|
||||
if v, ok := dirsTable["fs"].(string); ok {
|
||||
config.Dirs.FS = v
|
||||
}
|
||||
if v, ok := dirsTable["data"].(string); ok {
|
||||
config.Dirs.Data = v
|
||||
}
|
||||
if v, ok := dirsTable["override"].(string); ok {
|
||||
config.Dirs.Override = v
|
||||
}
|
||||
|
||||
// Handle libs array
|
||||
if libs, ok := dirsTable["libs"]; ok {
|
||||
if libsArray := extractStringArray(libs); len(libsArray) > 0 {
|
||||
config.Dirs.Libs = libsArray
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractStringArray extracts a string array from a Lua table
|
||||
func extractStringArray(value any) []string {
|
||||
// Direct array case
|
||||
if arr, ok := value.([]any); ok {
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
if str, ok := v.(string); ok {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Map with numeric keys case
|
||||
if tableMap, ok := value.(map[string]any); ok {
|
||||
result := make([]string, 0, len(tableMap))
|
||||
for _, v := range tableMap {
|
||||
if str, ok := v.(string); ok {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCustomValue returns any custom configuration value by key
|
||||
// Key can be a dotted path like "server.port"
|
||||
func (c *Config) GetCustomValue(key string) any {
|
||||
parts := strings.Split(key, ".")
|
||||
|
||||
if len(parts) == 1 {
|
||||
return c.values[key]
|
||||
}
|
||||
|
||||
current := c.values
|
||||
for _, part := range parts[:len(parts)-1] {
|
||||
next, ok := current[part].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
return current[parts[len(parts)-1]]
|
||||
}
|
||||
|
||||
// GetCustomString returns a custom string configuration value
|
||||
func (c *Config) GetCustomString(key string, defaultValue string) string {
|
||||
value := c.GetCustomValue(key)
|
||||
if value == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Convert to string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case float64:
|
||||
return fmt.Sprintf("%g", v)
|
||||
case int:
|
||||
return strconv.Itoa(v)
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
default:
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
// GetCustomInt returns a custom integer configuration value
|
||||
func (c *Config) GetCustomInt(key string, defaultValue int) int {
|
||||
value := c.GetCustomValue(key)
|
||||
if value == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Convert to int
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetCustomBool returns a custom boolean configuration value
|
||||
func (c *Config) GetCustomBool(key string, defaultValue bool) bool {
|
||||
value := c.GetCustomValue(key)
|
||||
if value == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Convert to bool
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
switch v {
|
||||
case "true", "yes", "1":
|
||||
return true
|
||||
case "false", "no", "0", "":
|
||||
return false
|
||||
}
|
||||
case int:
|
||||
return v != 0
|
||||
case float64:
|
||||
return v != 0
|
||||
default:
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
217
utils/debug.go
Normal file
217
utils/debug.go
Normal file
@ -0,0 +1,217 @@
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"Moonshark/utils/config"
|
||||
"Moonshark/utils/metadata"
|
||||
)
|
||||
|
||||
// ComponentStats holds stats from various system components
|
||||
type ComponentStats struct {
|
||||
RouteCount int // Number of routes
|
||||
BytecodeBytes int64 // Total size of bytecode in bytes
|
||||
ModuleCount int // Number of loaded modules
|
||||
SessionStats map[string]uint64 // Session cache statistics
|
||||
}
|
||||
|
||||
// SystemStats represents system statistics for debugging
|
||||
type SystemStats struct {
|
||||
Timestamp time.Time
|
||||
GoVersion string
|
||||
GoRoutines int
|
||||
Memory runtime.MemStats
|
||||
Components ComponentStats
|
||||
Version string
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// CollectSystemStats gathers basic system statistics
|
||||
func CollectSystemStats(cfg *config.Config) SystemStats {
|
||||
var stats SystemStats
|
||||
var mem runtime.MemStats
|
||||
|
||||
stats.Timestamp = time.Now()
|
||||
stats.GoVersion = runtime.Version()
|
||||
stats.GoRoutines = runtime.NumGoroutine()
|
||||
stats.Version = metadata.Version
|
||||
stats.Config = cfg
|
||||
|
||||
runtime.ReadMemStats(&mem)
|
||||
stats.Memory = mem
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// DebugStatsPage generates an HTML debug stats page
|
||||
func DebugStatsPage(stats SystemStats) string {
|
||||
const debugTemplate = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Moonshark</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
background-color: #333;
|
||||
color: white;
|
||||
}
|
||||
h1 {
|
||||
padding: 1rem;
|
||||
background-color: #4F5B93;
|
||||
box-shadow: 0 2px 4px 0px rgba(0, 0, 0, 0.2);
|
||||
margin-top: 0;
|
||||
}
|
||||
h2 { margin-top: 0; margin-bottom: 0.5rem; }
|
||||
table { width: 100%; border-collapse: collapse; }
|
||||
th { width: 1%; white-space: nowrap; border-right: 1px solid rgba(0, 0, 0, 0.1); }
|
||||
th, td { text-align: left; padding: 0.75rem 0.5rem; border-bottom: 1px solid #ddd; }
|
||||
tr:last-child th, tr:last-child td { border-bottom: none; }
|
||||
table tr:nth-child(even), tbody tr:nth-child(even) { background-color: rgba(0, 0, 0, 0.1); }
|
||||
.card {
|
||||
background: #F2F2F2;
|
||||
color: #333;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 1rem;
|
||||
box-shadow: 0 2px 4px 0px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
.timestamp { color: #999; font-size: 0.9em; margin-bottom: 1rem; }
|
||||
.section { margin-bottom: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Moonshark</h1>
|
||||
<div class="timestamp">Generated at: {{.Timestamp.Format "2006-01-02 15:04:05"}}</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Server</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Version</th><td>{{.Version}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>System</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Go Version</th><td>{{.GoVersion}}</td></tr>
|
||||
<tr><th>Goroutines</th><td>{{.GoRoutines}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Memory</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Allocated</th><td>{{ByteCount .Memory.Alloc}}</td></tr>
|
||||
<tr><th>Total Allocated</th><td>{{ByteCount .Memory.TotalAlloc}}</td></tr>
|
||||
<tr><th>System Memory</th><td>{{ByteCount .Memory.Sys}}</td></tr>
|
||||
<tr><th>GC Cycles</th><td>{{.Memory.NumGC}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Sessions</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Active Sessions</th><td>{{index .Components.SessionStats "entries"}}</td></tr>
|
||||
<tr><th>Cache Size</th><td>{{ByteCount (index .Components.SessionStats "bytes")}}</td></tr>
|
||||
<tr><th>Max Cache Size</th><td>{{ByteCount (index .Components.SessionStats "max_bytes")}}</td></tr>
|
||||
<tr><th>Cache Gets</th><td>{{index .Components.SessionStats "gets"}}</td></tr>
|
||||
<tr><th>Cache Sets</th><td>{{index .Components.SessionStats "sets"}}</td></tr>
|
||||
<tr><th>Cache Misses</th><td>{{index .Components.SessionStats "misses"}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>LuaRunner</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Interpreter</th><td>LuaJIT 2.1 (Lua 5.1)</td></tr>
|
||||
<tr><th>Active Routes</th><td>{{.Components.RouteCount}}</td></tr>
|
||||
<tr><th>Bytecode Size</th><td>{{ByteCount .Components.BytecodeBytes}}</td></tr>
|
||||
<tr><th>Loaded Modules</th><td>{{.Components.ModuleCount}}</td></tr>
|
||||
<tr><th>State Pool Size</th><td>{{.Config.Runner.PoolSize}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Config</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Port</th><td>{{.Config.Server.Port}}</td></tr>
|
||||
<tr><th>Pool Size</th><td>{{.Config.Runner.PoolSize}}</td></tr>
|
||||
<tr><th>Debug Mode</th><td>{{.Config.Server.Debug}}</td></tr>
|
||||
<tr><th>Log Level</th><td>{{.Config.Server.LogLevel}}</td></tr>
|
||||
<tr><th>HTTP Logging</th><td>{{.Config.Server.HTTPLogging}}</td></tr>
|
||||
<tr>
|
||||
<th>Directories</th>
|
||||
<td>
|
||||
<div>Routes: {{.Config.Dirs.Routes}}</div>
|
||||
<div>Static: {{.Config.Dirs.Static}}</div>
|
||||
<div>Override: {{.Config.Dirs.Override}}</div>
|
||||
<div>Libs: {{range .Config.Dirs.Libs}}{{.}}, {{end}}</div>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
// Create a template function map
|
||||
funcMap := template.FuncMap{
|
||||
"ByteCount": func(b any) string {
|
||||
var bytes uint64
|
||||
|
||||
switch v := b.(type) {
|
||||
case uint64:
|
||||
bytes = v
|
||||
case int64:
|
||||
bytes = uint64(v)
|
||||
case int:
|
||||
bytes = uint64(v)
|
||||
default:
|
||||
return fmt.Sprintf("%T: %v", b, b)
|
||||
}
|
||||
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := uint64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
},
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("debug").Funcs(funcMap).Parse(debugTemplate)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error parsing template: %v", err)
|
||||
}
|
||||
|
||||
// Execute the template
|
||||
var output strings.Builder
|
||||
if err := tmpl.Execute(&output, stats); err != nil {
|
||||
return fmt.Sprintf("Error executing template: %v", err)
|
||||
}
|
||||
|
||||
return output.String()
|
||||
}
257
utils/errorPages.go
Normal file
257
utils/errorPages.go
Normal file
@ -0,0 +1,257 @@
package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// ErrorPageConfig holds configuration for generating error pages
|
||||
type ErrorPageConfig struct {
|
||||
OverrideDir string // Directory where override templates are stored
|
||||
DebugMode bool // Whether to show debug information
|
||||
}
|
||||
|
||||
// ErrorType represents HTTP error types
|
||||
type ErrorType int
|
||||
|
||||
const (
|
||||
ErrorTypeNotFound ErrorType = 404
|
||||
ErrorTypeMethodNotAllowed ErrorType = 405
|
||||
ErrorTypeInternalError ErrorType = 500
|
||||
ErrorTypeForbidden ErrorType = 403 // CSRF/forbidden requests
|
||||
)
|
||||
|
||||
// ErrorPage generates an HTML error page based on the error type
|
||||
// It first checks for an override file, and if not found, generates a default page
|
||||
func ErrorPage(config ErrorPageConfig, errorType ErrorType, url string, errMsg string) string {
|
||||
// Check for override file
|
||||
if config.OverrideDir != "" {
|
||||
var filename string
|
||||
switch errorType {
|
||||
case ErrorTypeNotFound:
|
||||
filename = "404.html"
|
||||
case ErrorTypeMethodNotAllowed:
|
||||
filename = "405.html"
|
||||
case ErrorTypeInternalError:
|
||||
filename = "500.html"
|
||||
case ErrorTypeForbidden:
|
||||
filename = "403.html"
|
||||
}
|
||||
|
||||
if filename != "" {
|
||||
overridePath := filepath.Join(config.OverrideDir, filename)
|
||||
if content, err := os.ReadFile(overridePath); err == nil {
|
||||
return string(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No override found, generate default page
|
||||
switch errorType {
|
||||
case ErrorTypeNotFound:
|
||||
return generateNotFoundHTML(url)
|
||||
case ErrorTypeMethodNotAllowed:
|
||||
return generateMethodNotAllowedHTML(url)
|
||||
case ErrorTypeInternalError:
|
||||
return generateInternalErrorHTML(config.DebugMode, url, errMsg)
|
||||
case ErrorTypeForbidden:
|
||||
return generateForbiddenHTML(config.DebugMode, url, errMsg)
|
||||
default:
|
||||
// Fallback to internal error
|
||||
return generateInternalErrorHTML(config.DebugMode, url, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// NotFoundPage generates a 404 Not Found error page
|
||||
func NotFoundPage(config ErrorPageConfig, url string) string {
|
||||
return ErrorPage(config, ErrorTypeNotFound, url, "")
|
||||
}
|
||||
|
||||
// MethodNotAllowedPage generates a 405 Method Not Allowed error page
|
||||
func MethodNotAllowedPage(config ErrorPageConfig, url string) string {
|
||||
return ErrorPage(config, ErrorTypeMethodNotAllowed, url, "")
|
||||
}
|
||||
|
||||
// InternalErrorPage generates a 500 Internal Server Error page
|
||||
func InternalErrorPage(config ErrorPageConfig, url string, errMsg string) string {
|
||||
return ErrorPage(config, ErrorTypeInternalError, url, errMsg)
|
||||
}
|
||||
|
||||
// ForbiddenPage generates a 403 Forbidden error page
|
||||
func ForbiddenPage(config ErrorPageConfig, url string, errMsg string) string {
|
||||
return ErrorPage(config, ErrorTypeForbidden, url, errMsg)
|
||||
}
|
||||
|
||||
// generateInternalErrorHTML creates a 500 Internal Server Error page
|
||||
func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string {
|
||||
errorMessages := []string{
|
||||
"Oops! Something went wrong",
|
||||
"Oh no! The server choked",
|
||||
"Well, this is embarrassing...",
|
||||
"Houston, we have a problem",
|
||||
"Gremlins in the system",
|
||||
"The server is taking a coffee break",
|
||||
"Moonshark encountered a lunar eclipse",
|
||||
"Our code monkeys are working on it",
|
||||
"The server is feeling under the weather",
|
||||
"500 Brain Not Found",
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg)
|
||||
}
|
||||
|
||||
// generateForbiddenHTML creates a 403 Forbidden error page
|
||||
func generateForbiddenHTML(debugMode bool, url string, errMsg string) string {
|
||||
errorMessages := []string{
|
||||
"Access denied",
|
||||
"You shall not pass",
|
||||
"This area is off-limits",
|
||||
"Security check failed",
|
||||
"Invalid security token",
|
||||
"Request blocked for security reasons",
|
||||
"Permission denied",
|
||||
"Security violation detected",
|
||||
"This request was rejected",
|
||||
"Security first, access second",
|
||||
}
|
||||
|
||||
defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
|
||||
if errMsg == "" {
|
||||
errMsg = defaultMsg
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg)
|
||||
}
|
||||
|
||||
// generateNotFoundHTML creates a 404 Not Found error page
|
||||
func generateNotFoundHTML(url string) string {
|
||||
errorMessages := []string{
|
||||
"Nothing to see here",
|
||||
"This page is on vacation",
|
||||
"The page is missing in action",
|
||||
"This page has left the building",
|
||||
"This page is in another castle",
|
||||
"Sorry, we can't find that",
|
||||
"The page you're looking for doesn't exist",
|
||||
"Lost in space",
|
||||
"That's a 404",
|
||||
"Page not found",
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("404", randomMessage, "Page Not Found", false, url)
|
||||
}
|
||||
|
||||
// generateMethodNotAllowedHTML creates a 405 Method Not Allowed error page
|
||||
func generateMethodNotAllowedHTML(url string) string {
|
||||
errorMessages := []string{
|
||||
"That's not how this works",
|
||||
"Method not allowed",
|
||||
"Wrong way!",
|
||||
"This method is not supported",
|
||||
"You can't do that here",
|
||||
"Sorry, wrong door",
|
||||
"That method won't work here",
|
||||
"Try a different approach",
|
||||
"Access denied for this method",
|
||||
"Method mismatch",
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("405", randomMessage, "Method Not Allowed", false, url)
|
||||
}
|
||||
|
||||
// generateErrorHTML creates the common HTML structure for error pages
|
||||
func generateErrorHTML(errorCode, mainMessage, subMessage string, showDebugInfo bool, codeContent string) string {
|
||||
errorHTML := `<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>` + errorCode + `</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg-color: #2d2e2d;
|
||||
--bg-gradient: linear-gradient(to bottom, #2d2e2d 0%, #000 100%);
|
||||
--text-color: white;
|
||||
--code-bg: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: light) {
|
||||
:root {
|
||||
--bg-color: #f5f5f5;
|
||||
--bg-gradient: linear-gradient(to bottom, #f5f5f5 0%, #ddd 100%);
|
||||
--text-color: #333;
|
||||
--code-bg: rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background-color: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
background: var(--bg-gradient);
|
||||
}
|
||||
h1 {
|
||||
font-size: 4rem;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
p {
|
||||
font-size: 1.5rem;
|
||||
margin: 0.5rem 0;
|
||||
padding: 0;
|
||||
}
|
||||
.sub-message {
|
||||
font-size: 1.2rem;
|
||||
margin-bottom: 1rem;
|
||||
opacity: 0.8;
|
||||
}
|
||||
code {
|
||||
display: inline-block;
|
||||
font-size: 1rem;
|
||||
font-family: monospace;
|
||||
background-color: var(--code-bg);
|
||||
padding: 0.25em 0.5em;
|
||||
border-radius: 0.25em;
|
||||
margin-top: 1rem;
|
||||
max-width: 90vw;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-all;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<h1>` + errorCode + `</h1>
|
||||
<p>` + mainMessage + `</p>
|
||||
<div class="sub-message">` + subMessage + `</div>`
|
||||
|
||||
if codeContent != "" {
|
||||
errorHTML += `
|
||||
<code>` + codeContent + `</code>`
|
||||
}
|
||||
|
||||
// Add a note for debug mode
|
||||
if showDebugInfo {
|
||||
errorHTML += `
|
||||
<p style="font-size: 0.9rem; margin-top: 1rem;">
|
||||
An error occurred while processing your request.<br>
|
||||
Please check the server logs for details.
|
||||
</p>`
|
||||
}
|
||||
|
||||
errorHTML += `
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
return errorHTML
|
||||
}
352
utils/logger/logger.go
Normal file
352
utils/logger/logger.go
Normal file
@ -0,0 +1,352 @@
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ANSI color codes
|
||||
const (
|
||||
colorReset = "\033[0m"
|
||||
colorRed = "\033[31m"
|
||||
colorGreen = "\033[32m"
|
||||
colorYellow = "\033[33m"
|
||||
colorBlue = "\033[34m"
|
||||
colorPurple = "\033[35m"
|
||||
colorCyan = "\033[36m"
|
||||
colorGray = "\033[90m"
|
||||
)
|
||||
|
||||
// Log types
|
||||
const (
|
||||
TypeNone = iota
|
||||
TypeDebug
|
||||
TypeInfo
|
||||
TypeWarning
|
||||
TypeError
|
||||
TypeServer
|
||||
TypeFatal
|
||||
)
|
||||
|
||||
// Type properties
|
||||
var typeProps = map[int]struct {
|
||||
tag string
|
||||
color string
|
||||
}{
|
||||
TypeDebug: {"D", colorCyan},
|
||||
TypeInfo: {"I", colorBlue},
|
||||
TypeWarning: {"W", colorYellow},
|
||||
TypeError: {"E", colorRed},
|
||||
TypeServer: {"S", colorGreen},
|
||||
TypeFatal: {"F", colorPurple},
|
||||
}
|
||||
|
||||
const timeFormat = "15:04:05"
|
||||
|
||||
var (
|
||||
globalLogger *Logger
|
||||
globalLoggerOnce sync.Once
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
writer io.Writer
|
||||
useColors bool
|
||||
timeFormat string
|
||||
showTimestamp bool
|
||||
mu sync.Mutex
|
||||
debugMode atomic.Bool
|
||||
}
|
||||
|
||||
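// GetLogger returns the shared global logger, creating it with colors and timestamps enabled on first use.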
func GetLogger() *Logger {
|
||||
globalLoggerOnce.Do(func() {
|
||||
globalLogger = newLogger(true, true)
|
||||
})
|
||||
return globalLogger
|
||||
}
|
||||
|
||||
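// InitGlobalLogger replaces the global logger with one using the given color and timestamp settings.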
func InitGlobalLogger(useColors bool, showTimestamp bool) {
|
||||
globalLogger = newLogger(useColors, showTimestamp)
|
||||
}
|
||||
|
||||
func newLogger(useColors bool, showTimestamp bool) *Logger {
|
||||
return &Logger{
|
||||
writer: os.Stdout,
|
||||
useColors: useColors,
|
||||
timeFormat: timeFormat,
|
||||
showTimestamp: showTimestamp,
|
||||
}
|
||||
}
|
||||
|
||||
func New(useColors bool, showTimestamp bool) *Logger {
|
||||
return newLogger(useColors, showTimestamp)
|
||||
}
|
||||
|
||||
func (l *Logger) SetOutput(w io.Writer) {
|
||||
l.mu.Lock()
|
||||
l.writer = w
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func (l *Logger) TimeFormat() string {
|
||||
return l.timeFormat
|
||||
}
|
||||
|
||||
func (l *Logger) SetTimeFormat(format string) {
|
||||
l.mu.Lock()
|
||||
l.timeFormat = format
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func (l *Logger) EnableTimestamp() {
|
||||
l.showTimestamp = true
|
||||
}
|
||||
|
||||
func (l *Logger) DisableTimestamp() {
|
||||
l.showTimestamp = false
|
||||
}
|
||||
|
||||
func (l *Logger) EnableColors() {
|
||||
l.useColors = true
|
||||
}
|
||||
|
||||
func (l *Logger) DisableColors() {
|
||||
l.useColors = false
|
||||
}
|
||||
|
||||
func (l *Logger) EnableDebug() {
|
||||
l.debugMode.Store(true)
|
||||
}
|
||||
|
||||
func (l *Logger) DisableDebug() {
|
||||
l.debugMode.Store(false)
|
||||
}
|
||||
|
||||
func (l *Logger) IsDebugEnabled() bool {
|
||||
return l.debugMode.Load()
|
||||
}
|
||||
|
||||
func (l *Logger) applyColor(text, color string) string {
|
||||
if l.useColors {
|
||||
return color + text + colorReset
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func stripAnsiColors(s string) string {
|
||||
result := ""
|
||||
inEscape := false
|
||||
|
||||
for _, c := range s {
|
||||
if inEscape {
|
||||
if c == 'm' {
|
||||
inEscape = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c == '\033' {
|
||||
inEscape = true
|
||||
continue
|
||||
}
|
||||
result += string(c)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
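// writeMessage assembles the timestamp, level tag, and message and writes the line under the mutex; rawMode writes the message unformatted.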
func (l *Logger) writeMessage(logType int, message string, rawMode bool) {
|
||||
if rawMode {
|
||||
l.mu.Lock()
|
||||
_, _ = fmt.Fprint(l.writer, message+"\n")
|
||||
l.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
parts := []string{}
|
||||
|
||||
if l.showTimestamp {
|
||||
timestamp := time.Now().Format(l.timeFormat)
|
||||
if l.useColors {
|
||||
timestamp = l.applyColor(timestamp, colorGray)
|
||||
}
|
||||
parts = append(parts, timestamp)
|
||||
}
|
||||
|
||||
if logType != TypeNone {
|
||||
props := typeProps[logType]
|
||||
tag := "[" + props.tag + "]"
|
||||
if l.useColors {
|
||||
tag = l.applyColor(tag, props.color)
|
||||
}
|
||||
parts = append(parts, tag)
|
||||
}
|
||||
|
||||
parts = append(parts, message)
|
||||
logLine := strings.Join(parts, " ") + "\n"
|
||||
|
||||
l.mu.Lock()
|
||||
_, _ = fmt.Fprint(l.writer, logLine)
|
||||
if logType == TypeFatal {
|
||||
if f, ok := l.writer.(*os.File); ok {
|
||||
_ = f.Sync()
|
||||
}
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
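// log formats the message, skips debug entries when debug mode is off, and exits the process after writing a fatal entry.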
func (l *Logger) log(logType int, format string, args ...any) {
|
||||
// Only filter debug messages
|
||||
if logType == TypeDebug && !l.debugMode.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
var message string
|
||||
if len(args) > 0 {
|
||||
message = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
message = format
|
||||
}
|
||||
|
||||
l.writeMessage(logType, message, false)
|
||||
|
||||
if logType == TypeFatal {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) LogRaw(format string, args ...any) {
|
||||
var message string
|
||||
if len(args) > 0 {
|
||||
message = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
message = format
|
||||
}
|
||||
|
||||
if !l.useColors {
|
||||
message = stripAnsiColors(message)
|
||||
}
|
||||
|
||||
l.writeMessage(TypeInfo, message, true)
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
l.log(TypeDebug, format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
l.log(TypeInfo, format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Warning(format string, args ...any) {
|
||||
l.log(TypeWarning, format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
l.log(TypeError, format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(format string, args ...any) {
|
||||
l.log(TypeFatal, format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Server(format string, args ...any) {
|
||||
l.log(TypeServer, format, args...)
|
||||
}
|
||||
|
||||
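// LogRequest writes an access-log line with the status code colored by class and the duration formatted in µs, ms, or s.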
func (l *Logger) LogRequest(statusCode int, method, path string, duration time.Duration) {
|
||||
var statusColor string
|
||||
|
||||
switch {
|
||||
case statusCode < 300:
|
||||
statusColor = colorGreen
|
||||
case statusCode < 400:
|
||||
statusColor = colorCyan
|
||||
case statusCode < 500:
|
||||
statusColor = colorYellow
|
||||
default:
|
||||
statusColor = colorRed
|
||||
}
|
||||
|
||||
var durationStr string
|
||||
micros := duration.Microseconds()
|
||||
if micros < 1000 {
|
||||
durationStr = fmt.Sprintf("%.0fµs", float64(micros))
|
||||
} else if micros < 1000000 {
|
||||
durationStr = fmt.Sprintf("%.1fms", float64(micros)/1000)
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.2fs", duration.Seconds())
|
||||
}
|
||||
|
||||
message := fmt.Sprintf("%s %s %s %s",
|
||||
l.applyColor("["+method+"]", colorGray),
|
||||
l.applyColor(fmt.Sprintf("%d", statusCode), statusColor),
|
||||
l.applyColor(path, colorGray),
|
||||
l.applyColor(durationStr, colorGray),
|
||||
)
|
||||
|
||||
l.writeMessage(TypeNone, message, false)
|
||||
}
|
||||
|
||||
// Global functions
|
||||
func Debug(format string, args ...any) {
|
||||
GetLogger().Debug(format, args...)
|
||||
}
|
||||
|
||||
func Info(format string, args ...any) {
|
||||
GetLogger().Info(format, args...)
|
||||
}
|
||||
|
||||
func Warning(format string, args ...any) {
|
||||
GetLogger().Warning(format, args...)
|
||||
}
|
||||
|
||||
func Error(format string, args ...any) {
|
||||
GetLogger().Error(format, args...)
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
GetLogger().Fatal(format, args...)
|
||||
}
|
||||
|
||||
func Server(format string, args ...any) {
|
||||
GetLogger().Server(format, args...)
|
||||
}
|
||||
|
||||
func LogRaw(format string, args ...any) {
|
||||
GetLogger().LogRaw(format, args...)
|
||||
}
|
||||
|
||||
func SetOutput(w io.Writer) {
|
||||
GetLogger().SetOutput(w)
|
||||
}
|
||||
|
||||
func TimeFormat() string {
|
||||
return GetLogger().TimeFormat()
|
||||
}
|
||||
|
||||
func EnableDebug() {
|
||||
GetLogger().EnableDebug()
|
||||
}
|
||||
|
||||
func DisableDebug() {
|
||||
GetLogger().DisableDebug()
|
||||
}
|
||||
|
||||
func IsDebugEnabled() bool {
|
||||
return GetLogger().IsDebugEnabled()
|
||||
}
|
||||
|
||||
func EnableTimestamp() {
|
||||
GetLogger().EnableTimestamp()
|
||||
}
|
||||
|
||||
func DisableTimestamp() {
|
||||
GetLogger().DisableTimestamp()
|
||||
}
|
||||
|
||||
func LogRequest(statusCode int, method, path string, duration time.Duration) {
|
||||
GetLogger().LogRequest(statusCode, method, path, duration)
|
||||
}
|
||||
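For orientation, a minimal usage sketch of the package-level logger API shown above; the import path follows this changeset's module layout, and the surrounding `main` wiring is illustrative only.

```go
package main

import (
	"time"

	"Moonshark/utils/logger"
)

func main() {
	logger.EnableDebug()
	logger.Debug("debug output is only emitted while debug mode is on")
	logger.Info("server starting on port %d", 3117)

	start := time.Now()
	// ... handle a request ...
	logger.LogRequest(200, "GET", "/", time.Since(start))
}
```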
11
utils/metadata/metadata.go
Normal file
11
utils/metadata/metadata.go
Normal file
@ -0,0 +1,11 @@
package metadata

// Version holds the current Moonshark version
const Version = "1.0"

// Build time information
var (
	BuildTime = "unknown"
	GitCommit = "unknown"
	GoVersion = "unknown"
)
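The `unknown` defaults above follow the usual pattern for values stamped at link time. A hedged sketch of reading them, with the `-ldflags -X` override shown only as an assumed example in the comment:

```go
package main

import (
	"fmt"

	"Moonshark/utils/metadata"
)

// Assumed build invocation (not part of this diff):
//   go build -ldflags "-X Moonshark/utils/metadata.GitCommit=$(git rev-parse --short HEAD)" .
func main() {
	fmt.Printf("Moonshark %s (commit %s, built %s, %s)\n",
		metadata.Version, metadata.GitCommit, metadata.BuildTime, metadata.GoVersion)
}
```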
197
watcher.go
197
watcher.go
@ -1,197 +0,0 @@
package main

import (
	"fmt"
	"maps"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"
)

type FileWatcher struct {
	files      map[string]time.Time
	dirs       map[string]bool // Track watched directories
	mu         sync.RWMutex
	restartCh  chan bool
	stopCh     chan bool
	debounceMs int
	lastEvent  time.Time
	pollMs     int
}

func NewFileWatcher(debounceMs int) (*FileWatcher, error) {
	return &FileWatcher{
		files:      make(map[string]time.Time),
		dirs:       make(map[string]bool),
		restartCh:  make(chan bool, 1),
		stopCh:     make(chan bool, 1),
		debounceMs: debounceMs,
		pollMs:     250, // Poll every 250ms
	}, nil
}

func (fw *FileWatcher) AddFile(path string) error {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return err
	}

	info, err := os.Stat(absPath)
	if err != nil {
		return err
	}

	fw.mu.Lock()
	fw.files[absPath] = info.ModTime()
	fw.mu.Unlock()

	fmt.Printf("Watching: %s\n", absPath)
	return nil
}

func (fw *FileWatcher) AddDirectory(dir string) error {
	absDir, err := filepath.Abs(dir)
	if err != nil {
		return err
	}

	fw.mu.Lock()
	fw.dirs[absDir] = true
	fw.mu.Unlock()

	return filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		if !info.IsDir() && strings.HasSuffix(path, ".lua") {
			return fw.AddFile(path)
		}
		return nil
	})
}

func (fw *FileWatcher) Start() <-chan bool {
	go fw.pollLoop()
	return fw.restartCh
}

func (fw *FileWatcher) pollLoop() {
	ticker := time.NewTicker(time.Duration(fw.pollMs) * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-fw.stopCh:
			return
		case <-ticker.C:
			fw.checkFiles()
			fw.scanForNewFiles()
		}
	}
}

func (fw *FileWatcher) checkFiles() {
	fw.mu.RLock()
	files := make(map[string]time.Time, len(fw.files))
	maps.Copy(files, fw.files)
	fw.mu.RUnlock()

	changed := false
	for path, lastMod := range files {
		info, err := os.Stat(path)
		if err != nil {
			continue
		}

		if info.ModTime().After(lastMod) {
			fw.mu.Lock()
			fw.files[path] = info.ModTime()
			fw.mu.Unlock()

			now := time.Now()
			if now.Sub(fw.lastEvent) > time.Duration(fw.debounceMs)*time.Millisecond {
				fw.lastEvent = now
				// log.Printf("File changed: %s", path)
				changed = true
			}
		}
	}

	if changed {
		select {
		case fw.restartCh <- true:
		default:
		}
	}
}

func (fw *FileWatcher) scanForNewFiles() {
	fw.mu.RLock()
	dirs := make(map[string]bool, len(fw.dirs))
	maps.Copy(dirs, fw.dirs)
	existingFiles := make(map[string]bool, len(fw.files))
	for path := range fw.files {
		existingFiles[path] = true
	}
	fw.mu.RUnlock()

	newFilesFound := false
	for dir := range dirs {
		err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}

			if !info.IsDir() && strings.HasSuffix(path, ".lua") {
				absPath, err := filepath.Abs(path)
				if err != nil {
					return err
				}

				if !existingFiles[absPath] {
					fw.mu.Lock()
					fw.files[absPath] = info.ModTime()
					fw.mu.Unlock()

					fmt.Printf("New file detected: %s\n", absPath)
					newFilesFound = true
				}
			}
			return nil
		})
		if err != nil {
			continue
		}
	}

	if newFilesFound {
		now := time.Now()
		if now.Sub(fw.lastEvent) > time.Duration(fw.debounceMs)*time.Millisecond {
			fw.lastEvent = now
			select {
			case fw.restartCh <- true:
			default:
			}
		}
	}
}

func (fw *FileWatcher) Close() error {
	select {
	case fw.stopCh <- true:
	default:
	}
	return nil
}

func (fw *FileWatcher) DiscoverRequiredFiles(scriptPath string) error {
	if err := fw.AddFile(scriptPath); err != nil {
		return err
	}

	scriptDir := filepath.Dir(scriptPath)
	return fw.AddDirectory(scriptDir)
}
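Since `watcher.go` is deleted by this commit, a short sketch of how the old polling `FileWatcher` was driven may help when comparing it with the new `watchers` package; the script path is a placeholder.

```go
package main

import "log"

// watchExample shows the old flow: watch a script plus every .lua file
// in its directory, then act on each restart signal from the channel.
func watchExample() {
	fw, err := NewFileWatcher(300) // debounce in milliseconds
	if err != nil {
		log.Fatal(err)
	}
	defer fw.Close()

	if err := fw.DiscoverRequiredFiles("init.lua"); err != nil { // placeholder path
		log.Fatal(err)
	}

	for range fw.Start() {
		log.Println("change detected, restart requested")
	}
}
```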
84
watchers/api.go
Normal file
84
watchers/api.go
Normal file
@ -0,0 +1,84 @@
package watchers

import (
	"fmt"
	"strings"
	"sync"

	"Moonshark/router"
	"Moonshark/runner"
	"Moonshark/utils/color"
	"Moonshark/utils/logger"
)

// Global watcher manager instance with explicit creation
var (
	globalManager     *WatcherManager
	globalManagerOnce sync.Once
)

// GetWatcherManager returns the global watcher manager, creating it if needed
func GetWatcherManager() *WatcherManager {
	globalManagerOnce.Do(func() {
		globalManager = NewWatcherManager(DefaultPollInterval)
	})
	return globalManager
}

// ShutdownWatcherManager closes the global watcher manager if it exists
func ShutdownWatcherManager() {
	if globalManager != nil {
		globalManager.Close()
		globalManager = nil
	}
}

// WatchLuaRouter sets up a watcher for a LuaRouter's routes directory
func WatchLuaRouter(router *router.LuaRouter, runner *runner.Runner, routesDir string) (*DirectoryWatcher, error) {
	manager := GetWatcherManager()

	config := DirectoryWatcherConfig{
		Dir:       routesDir,
		Callback:  router.Refresh,
		Recursive: true,
	}

	watcher, err := manager.WatchDirectory(config)
	if err != nil {
		return nil, fmt.Errorf("failed to watch directory: %w", err)
	}

	logger.Info("Started watching Lua routes! %s", color.Apply(routesDir, color.Yellow))
	return watcher, nil
}

// WatchLuaModules sets up watchers for Lua module directories
func WatchLuaModules(luaRunner *runner.Runner, libDirs []string) ([]*DirectoryWatcher, error) {
	manager := GetWatcherManager()
	watchers := make([]*DirectoryWatcher, 0, len(libDirs))

	for _, dir := range libDirs {
		config := DirectoryWatcherConfig{
			Dir: dir,
			EnhancedCallback: func(changes []FileChange) error {
				for _, change := range changes {
					if !change.IsDeleted && strings.HasSuffix(change.Path, ".lua") {
						luaRunner.NotifyFileChanged(change.Path)
					}
				}
				return nil
			},
			Recursive: true,
		}

		watcher, err := manager.WatchDirectory(config)
		if err != nil {
			// Skip directories that cannot be watched instead of appending a nil watcher
			logger.Error("Failed to watch Lua module directory %s: %v", dir, err)
			continue
		}

		watchers = append(watchers, watcher)
		logger.Info("Started watching Lua modules! %s", color.Apply(dir, color.Yellow))
	}

	return watchers, nil
}
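A hedged sketch of wiring these helpers up at startup. The `*router.LuaRouter` and `*runner.Runner` values are assumed to come from the application's existing setup, and the directory names are placeholders:

```go
package main

import (
	"Moonshark/router"
	"Moonshark/runner"
	"Moonshark/watchers"
)

// startWatchers is illustrative only; "routes" and "libs" are placeholder paths.
func startWatchers(luaRouter *router.LuaRouter, luaRunner *runner.Runner) (func(), error) {
	if _, err := watchers.WatchLuaRouter(luaRouter, luaRunner, "routes"); err != nil {
		return nil, err
	}
	if _, err := watchers.WatchLuaModules(luaRunner, []string{"libs"}); err != nil {
		return nil, err
	}
	// The caller can defer the returned cleanup to stop the shared manager at shutdown.
	return watchers.ShutdownWatcherManager, nil
}
```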
212
watchers/dir.go
Normal file
212
watchers/dir.go
Normal file
@ -0,0 +1,212 @@
package watchers

import (
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"

	"Moonshark/utils/logger"
)

// Default debounce time between detected change and callback
const defaultDebounceTime = 300 * time.Millisecond

// FileChange represents a detected file change
type FileChange struct {
	Path      string
	IsNew     bool
	IsDeleted bool
}

// FileInfo stores minimal metadata about a file for change detection
type FileInfo struct {
	ModTime time.Time
}

// DirectoryWatcher watches a specific directory for changes
type DirectoryWatcher struct {
	// Directory to watch
	dir string

	// Map of file paths to their metadata
	files   map[string]FileInfo
	filesMu sync.RWMutex

	// Track changed files during a check cycle
	changedFiles []FileChange

	// Enhanced callback that receives changes (optional)
	enhancedCallback func([]FileChange) error

	// Configuration
	callback     func() error
	debounceTime time.Duration
	recursive    bool

	// Debounce timer
	debounceTimer *time.Timer
	debouncing    bool
	debounceMu    sync.Mutex

	// Error tracking
	consecutiveErrors int
	lastError         error
}

// DirectoryWatcherConfig contains configuration for a directory watcher
type DirectoryWatcherConfig struct {
	Dir              string                   // Directory to watch
	Callback         func() error             // Callback function to call when changes are detected
	DebounceTime     time.Duration            // Debounce time (0 means use default)
	Recursive        bool                     // Recursive watching (watch subdirectories)
	EnhancedCallback func([]FileChange) error // Enhanced callback that receives file changes
}

// scanDirectory builds the initial file list
func (w *DirectoryWatcher) scanDirectory() error {
	w.filesMu.Lock()
	defer w.filesMu.Unlock()

	w.files = make(map[string]FileInfo)

	return filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil // Skip files with errors
		}

		if !w.recursive && info.IsDir() && path != w.dir {
			return filepath.SkipDir
		}

		w.files[path] = FileInfo{
			ModTime: info.ModTime(),
		}

		return nil
	})
}

// checkForChanges detects if any files have been added, modified, or deleted
func (w *DirectoryWatcher) checkForChanges() (bool, error) {
	w.filesMu.RLock()
	prevFileCount := len(w.files)
	w.filesMu.RUnlock()

	newFiles := make(map[string]FileInfo)
	changed := false
	w.changedFiles = nil // Reset changed files list

	err := filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error {
		// Skip errors
		if err != nil {
			return nil
		}

		if !w.recursive && info.IsDir() && path != w.dir {
			return filepath.SkipDir
		}

		currentInfo := FileInfo{
			ModTime: info.ModTime(),
		}
		newFiles[path] = currentInfo

		w.filesMu.RLock()
		prevInfo, exists := w.files[path]
		w.filesMu.RUnlock()

		if !exists {
			changed = true
			w.changedFiles = append(w.changedFiles, FileChange{
				Path:  path,
				IsNew: true,
			})
			w.logDebug("File added: %s", path)
		} else if currentInfo.ModTime != prevInfo.ModTime {
			changed = true
			w.changedFiles = append(w.changedFiles, FileChange{
				Path: path,
			})
			w.logDebug("File changed: %s", path)
		}

		return nil
	})

	// Only check for deleted files if needed
	if err == nil && (!changed && len(newFiles) != prevFileCount) {
		w.filesMu.RLock()
		for path := range w.files {
			if _, exists := newFiles[path]; !exists {
				changed = true
				w.changedFiles = append(w.changedFiles, FileChange{
					Path:      path,
					IsDeleted: true,
				})
				w.logDebug("File deleted: %s", path)
				break // We already know changes happened
			}
		}
		w.filesMu.RUnlock()
	}

	if changed {
		w.filesMu.Lock()
		w.files = newFiles
		w.filesMu.Unlock()
	}

	return changed, err
}

// notifyChange triggers the callback with debouncing
func (w *DirectoryWatcher) notifyChange() {
	w.debounceMu.Lock()
	defer w.debounceMu.Unlock()

	if w.debouncing {
		if w.debounceTimer != nil {
			w.debounceTimer.Stop()
		}
	} else {
		w.debouncing = true
	}

	// Make a copy of changed files to avoid race conditions
	changedFilesCopy := make([]FileChange, len(w.changedFiles))
	copy(changedFilesCopy, w.changedFiles)

	w.debounceTimer = time.AfterFunc(w.debounceTime, func() {
		var err error
		if w.enhancedCallback != nil {
			err = w.enhancedCallback(changedFilesCopy)
		} else if w.callback != nil {
			err = w.callback()
		}

		if err != nil {
			w.logError("Callback error: %v", err)
		}

		w.debounceMu.Lock()
		w.debouncing = false
		w.debounceMu.Unlock()
	})
}

// logDebug logs a debug message with the watcher's directory prefix
func (w *DirectoryWatcher) logDebug(format string, args ...any) {
	logger.Debug("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
}

// logError logs an error message with the watcher's directory prefix
func (w *DirectoryWatcher) logError(format string, args ...any) {
	logger.Error("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
}

// GetDir gets the DirectoryWatcher's current directory
func (w *DirectoryWatcher) GetDir() string {
	return w.dir
}
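Note that `notifyChange` prefers `EnhancedCallback` over `Callback` when both are set. A minimal sketch of the two configuration shapes (directory names and callback bodies are placeholders):

```go
package main

import "Moonshark/watchers"

// exampleConfigs is illustrative only; directory names are placeholders.
func exampleConfigs() (coarse, fine watchers.DirectoryWatcherConfig) {
	// Coarse-grained: run one refresh on any change in the tree.
	coarse = watchers.DirectoryWatcherConfig{
		Dir:       "routes",
		Recursive: true,
		Callback:  func() error { return nil }, // e.g. rebuild a route table
	}

	// Fine-grained: receive the individual FileChange records instead.
	fine = watchers.DirectoryWatcherConfig{
		Dir:       "libs",
		Recursive: true,
		EnhancedCallback: func(changes []watchers.FileChange) error {
			for _, c := range changes {
				if !c.IsDeleted {
					_ = c.Path // e.g. invalidate a cached module by path
				}
			}
			return nil
		},
	}
	return coarse, fine
}
```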
154
watchers/manager.go
Normal file
154
watchers/manager.go
Normal file
@ -0,0 +1,154 @@
package watchers

import (
	"errors"
	"sync"
	"time"

	"Moonshark/utils/logger"
)

// DefaultPollInterval is the time between directory checks
const DefaultPollInterval = 1 * time.Second

// Common errors
var (
	ErrDirectoryNotFound = errors.New("directory not found")
	ErrAlreadyWatching   = errors.New("already watching directory")
)

// WatcherManager coordinates file watching across multiple directories
type WatcherManager struct {
	watchers map[string]*DirectoryWatcher
	mu       sync.RWMutex

	done     chan struct{}
	ticker   *time.Ticker
	interval time.Duration

	wg sync.WaitGroup
}

// NewWatcherManager creates a new watcher manager with a specified poll interval
func NewWatcherManager(pollInterval time.Duration) *WatcherManager {
	if pollInterval <= 0 {
		pollInterval = DefaultPollInterval
	}

	manager := &WatcherManager{
		watchers: make(map[string]*DirectoryWatcher),
		done:     make(chan struct{}),
		interval: pollInterval,
	}

	manager.ticker = time.NewTicker(pollInterval)
	manager.wg.Add(1)
	go manager.pollLoop()

	return manager
}

// Close stops all watchers and the manager
func (m *WatcherManager) Close() error {
	close(m.done)
	if m.ticker != nil {
		m.ticker.Stop()
	}
	m.wg.Wait()
	return nil
}

// WatchDirectory adds a new directory to watch and returns the watcher
func (m *WatcherManager) WatchDirectory(config DirectoryWatcherConfig) (*DirectoryWatcher, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.watchers[config.Dir]; exists {
		return nil, ErrAlreadyWatching
	}

	if config.DebounceTime == 0 {
		config.DebounceTime = defaultDebounceTime
	}

	watcher := &DirectoryWatcher{
		dir:              config.Dir,
		files:            make(map[string]FileInfo),
		callback:         config.Callback,
		enhancedCallback: config.EnhancedCallback,
		debounceTime:     config.DebounceTime,
		recursive:        config.Recursive,
	}

	// Perform initial scan
	if err := watcher.scanDirectory(); err != nil {
		return nil, err
	}

	m.watchers[config.Dir] = watcher
	logger.Debug("WatcherManager added watcher for %s", config.Dir)

	return watcher, nil
}

// UnwatchDirectory removes a directory from being watched
func (m *WatcherManager) UnwatchDirectory(dir string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.watchers[dir]; !exists {
		return ErrDirectoryNotFound
	}

	delete(m.watchers, dir)
	logger.Debug("WatcherManager removed watcher for %s", dir)
	return nil
}

// pollLoop is the main polling loop that checks all watched directories
func (m *WatcherManager) pollLoop() {
	defer m.wg.Done()

	for {
		select {
		case <-m.ticker.C:
			m.checkAllDirectories()
		case <-m.done:
			return
		}
	}
}

// checkAllDirectories polls all registered directories for changes
func (m *WatcherManager) checkAllDirectories() {
	m.mu.RLock()
	watchers := make([]*DirectoryWatcher, 0, len(m.watchers))
	for _, w := range m.watchers {
		watchers = append(watchers, w)
	}
	m.mu.RUnlock()

	changesDetected := 0

	for _, watcher := range watchers {
		if watcher.consecutiveErrors > 3 {
			if watcher.consecutiveErrors == 4 {
				logger.Error("Temporarily skipping directory %s due to errors: %v",
					watcher.dir, watcher.lastError)
				watcher.consecutiveErrors++
			}
			continue
		}

		changed, err := watcher.checkForChanges()
		if err != nil {
			logger.Error("Error checking directory %s: %v", watcher.dir, err)
			continue
		}

		if changed {
			changesDetected++
			watcher.notifyChange()
		}
	}
}
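Finally, a small sketch of driving the manager directly, without the package-level helpers in `watchers/api.go`; the directory, interval, and callback are placeholders:

```go
package main

import (
	"log"
	"time"

	"Moonshark/watchers"
)

func main() {
	m := watchers.NewWatcherManager(500 * time.Millisecond)
	defer m.Close()

	_, err := m.WatchDirectory(watchers.DirectoryWatcherConfig{
		Dir:       "public", // placeholder directory
		Recursive: true,
		Callback: func() error {
			log.Println("static files changed")
			return nil
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	time.Sleep(10 * time.Second) // let the poll loop run for a bit
}
```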