1
0

Compare commits

..

No commits in common. "master" and "v0.5.0" have entirely different histories.

4 changed files with 107 additions and 238 deletions

View File

@ -109,7 +109,7 @@ if err != nil {
``` ```
### ToTable(index int) (any, error) ### ToTable(index int) (any, error)
Converts a Lua table to optimal Go type; arrays or `map[string]any`. Converts a Lua table to optimal Go type ([]int, []string, map[string]any, etc.).
```go ```go
table, err := L.ToTable(-1) table, err := L.ToTable(-1)
if err != nil { if err != nil {

View File

@ -116,7 +116,7 @@ L.PushValue(stuff) // Handles all Go types automatically
// Lua → Go with automatic type detection // Lua → Go with automatic type detection
L.GetGlobal("some_table") L.GetGlobal("some_table")
result, err := L.ToTable(-1) // Returns optimal Go type (typed array, or map[string]any) result, err := L.ToTable(-1) // Returns optimal Go type ([]int, map[string]string, etc.)
``` ```
### Table Builder ### Table Builder

View File

@ -22,12 +22,6 @@ import (
// GoFunction defines the signature for Go functions callable from Lua // GoFunction defines the signature for Go functions callable from Lua
type GoFunction func(*State) int type GoFunction func(*State) int
// LuaFunction represents a Lua function callable from Go
type LuaFunction struct {
state *State
ref int
}
// Static registry size reduces resizing operations // Static registry size reduces resizing operations
const initialRegistrySize = 64 const initialRegistrySize = 64
@ -41,14 +35,6 @@ var (
funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize), funcs: make(map[unsafe.Pointer]GoFunction, initialRegistrySize),
} }
// luaFunctionRegistry stores Lua function references
luaFunctionRegistry = struct {
sync.RWMutex
refs map[int]*State
}{
refs: make(map[int]*State),
}
// statePool reuses State structs to avoid allocations // statePool reuses State structs to avoid allocations
statePool = sync.Pool{ statePool = sync.Pool{
New: func() any { New: func() any {
@ -117,193 +103,14 @@ func (s *State) UnregisterGoFunction(name string) {
s.SetGlobal(name) s.SetGlobal(name)
} }
// StoreLuaFunction stores a Lua function from the stack and returns a reference // Cleanup frees all function pointers and clears the registry
func (s *State) StoreLuaFunction(index int) (*LuaFunction, error) {
if !s.IsFunction(index) {
return nil, fmt.Errorf("value at index %d is not a function", index)
}
s.PushCopy(index)
ref := int(C.luaL_ref(s.L, C.LUA_REGISTRYINDEX))
if ref == C.LUA_REFNIL {
return nil, fmt.Errorf("failed to store function reference")
}
luaFunc := &LuaFunction{
state: s,
ref: ref,
}
luaFunctionRegistry.Lock()
luaFunctionRegistry.refs[ref] = s
luaFunctionRegistry.Unlock()
return luaFunc, nil
}
// GetLuaFunction gets a global Lua function and stores it
func (s *State) GetLuaFunction(name string) (*LuaFunction, error) {
s.GetGlobal(name)
defer s.Pop(1)
if !s.IsFunction(-1) {
return nil, fmt.Errorf("global '%s' is not a function", name)
}
return s.StoreLuaFunction(-1)
}
// Call executes the Lua function with given arguments and returns results
func (lf *LuaFunction) Call(args ...any) ([]any, error) {
s := lf.state
// Push function from registry
C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
// Push arguments
for i, arg := range args {
if err := s.PushValue(arg); err != nil {
s.Pop(i + 1) // Clean up function and pushed args
return nil, fmt.Errorf("failed to push argument %d: %w", i+1, err)
}
}
// Call function
baseTop := s.GetTop() - len(args) - 1
if err := s.Call(len(args), C.LUA_MULTRET); err != nil {
return nil, err
}
// Extract results
newTop := s.GetTop()
nresults := newTop - baseTop
results := make([]any, nresults)
for i := 0; i < nresults; i++ {
val, err := s.ToValue(baseTop + i + 1)
if err != nil {
results[i] = nil
} else {
results[i] = val
}
}
s.SetTop(baseTop)
return results, nil
}
// CallSingle calls the function and returns only the first result
func (lf *LuaFunction) CallSingle(args ...any) (any, error) {
results, err := lf.Call(args...)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, nil
}
return results[0], nil
}
// CallTyped calls the function and converts the first result to the specified type
func CallTyped[T any](lf *LuaFunction, args ...any) (T, error) {
var zero T
result, err := lf.CallSingle(args...)
if err != nil {
return zero, err
}
if result == nil {
return zero, nil
}
if converted, ok := ConvertValue[T](result); ok {
return converted, nil
}
return zero, fmt.Errorf("cannot convert result to %T", zero)
}
// Release releases the Lua function reference
func (lf *LuaFunction) Release() {
if lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL {
luaFunctionRegistry.Lock()
delete(luaFunctionRegistry.refs, lf.ref)
luaFunctionRegistry.Unlock()
C.luaL_unref(lf.state.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
lf.ref = C.LUA_NOREF
}
}
// IsValid checks if the function reference is still valid
func (lf *LuaFunction) IsValid() bool {
return lf.ref != C.LUA_NOREF && lf.ref != C.LUA_REFNIL
}
// ToGoFunction converts to a standard Go function signature
func (lf *LuaFunction) ToGoFunction() func(...any) ([]any, error) {
return func(args ...any) ([]any, error) {
return lf.Call(args...)
}
}
// CreateCallback creates a reusable callback function
func (s *State) CreateCallback(luaCode string) (*LuaFunction, error) {
if err := s.LoadString(luaCode); err != nil {
return nil, fmt.Errorf("failed to load callback code: %w", err)
}
luaFunc, err := s.StoreLuaFunction(-1)
s.Pop(1) // Remove function from stack
return luaFunc, err
}
// Cleanup frees all function pointers and clears registries
func (s *State) Cleanup() { func (s *State) Cleanup() {
// Clean up Go function registry
functionRegistry.Lock() functionRegistry.Lock()
defer functionRegistry.Unlock()
// Free all allocated pointers
for ptr := range functionRegistry.funcs { for ptr := range functionRegistry.funcs {
C.free(ptr) C.free(ptr)
delete(functionRegistry.funcs, ptr) delete(functionRegistry.funcs, ptr)
} }
functionRegistry.Unlock()
// Clean up Lua function registry for this state
luaFunctionRegistry.Lock()
for ref, state := range luaFunctionRegistry.refs {
if state == s {
C.luaL_unref(s.L, C.LUA_REGISTRYINDEX, C.int(ref))
delete(luaFunctionRegistry.refs, ref)
}
}
luaFunctionRegistry.Unlock()
}
// BatchRegisterGoFunctions registers multiple Go functions at once
func (s *State) BatchRegisterGoFunctions(funcs map[string]GoFunction) error {
for name, fn := range funcs {
if err := s.RegisterGoFunction(name, fn); err != nil {
return fmt.Errorf("failed to register function '%s': %w", name, err)
}
}
return nil
}
// GetAllLuaFunctions gets multiple global Lua functions by name
func (s *State) GetAllLuaFunctions(names ...string) (map[string]*LuaFunction, error) {
funcs := make(map[string]*LuaFunction, len(names))
for _, name := range names {
if fn, err := s.GetLuaFunction(name); err == nil {
funcs[name] = fn
} else {
// Clean up any successfully created functions
for _, f := range funcs {
f.Release()
}
return nil, fmt.Errorf("failed to get function '%s': %w", name, err)
}
}
return funcs, nil
}
// PushLuaFunction pushes a stored LuaFunction reference onto the stack
func (s *State) PushLuaFunction(lf *LuaFunction) {
C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(lf.ref))
} }

View File

@ -71,6 +71,39 @@ static int sample_array_type(lua_State *L, int index, int count) {
if (all_bools) return 4; if (all_bools) return 4;
return 0; return 0;
} }
static int sample_map_type(lua_State *L, int index) {
int all_string_vals = 1;
int all_int_vals = 1;
int all_int_keys = 1;
int count = 0;
lua_pushnil(L);
while (lua_next(L, index) && count < 5) {
if (lua_type(L, -2) != LUA_TSTRING) {
all_int_keys = 0;
} else {
const char *key = lua_tostring(L, -2);
char *endptr;
strtol(key, &endptr, 10);
if (*endptr != '\0') all_int_keys = 0;
}
int val_type = lua_type(L, -1);
if (val_type != LUA_TSTRING) all_string_vals = 0;
if (val_type != LUA_TNUMBER || !is_integer(L, -1)) all_int_vals = 0;
lua_pop(L, 1);
count++;
if (!all_string_vals && !all_int_vals && !all_int_keys) break;
}
if (all_int_keys) return 4;
if (all_string_vals) return 1;
if (all_int_vals) return 2;
return 3;
}
*/ */
import "C" import "C"
import ( import (
@ -191,6 +224,7 @@ func (s *State) GetTableLength(index int) int {
return int(C.get_table_length(s.L, C.int(index))) return int(C.get_table_length(s.L, C.int(index)))
} }
// Enhanced PushValue with comprehensive type support
func (s *State) PushValue(v any) error { func (s *State) PushValue(v any) error {
switch val := v.(type) { switch val := v.(type) {
case nil: case nil:
@ -205,14 +239,6 @@ func (s *State) PushValue(v any) error {
s.PushNumber(val) s.PushNumber(val)
case string: case string:
s.PushString(val) s.PushString(val)
case GoFunction:
return s.PushGoFunction(val)
case *LuaFunction:
C.lua_rawgeti(s.L, C.LUA_REGISTRYINDEX, C.int(val.ref))
if val.ref == C.LUA_NOREF || val.ref == C.LUA_REFNIL {
s.Pop(1)
s.PushNil()
}
case []int: case []int:
return s.pushIntSlice(val) return s.pushIntSlice(val)
case []string: case []string:
@ -223,8 +249,6 @@ func (s *State) PushValue(v any) error {
return s.pushFloatSlice(val) return s.pushFloatSlice(val)
case []any: case []any:
return s.pushAnySlice(val) return s.pushAnySlice(val)
case []map[string]any:
return s.pushMapSlice(val)
case map[string]string: case map[string]string:
return s.pushStringMap(val) return s.pushStringMap(val)
case map[string]int: case map[string]int:
@ -382,7 +406,17 @@ func (s *State) ToTable(index int) (any, error) {
} }
} }
mapType := int(C.sample_map_type(s.L, C.int(absIdx)))
switch mapType {
case 1: // map[string]string
return s.extractStringMap(absIdx)
case 2: // map[string]int
return s.extractIntMap(absIdx)
case 4: // map[int]any
return s.extractIntKeyMap(absIdx)
default: // map[string]any
return s.extractAnyMap(absIdx) return s.extractAnyMap(absIdx)
}
} }
func (s *State) extractIntArray(index, length int) []int { func (s *State) extractIntArray(index, length int) []int {
@ -442,6 +476,62 @@ func (s *State) extractAnyArray(index, length int) []any {
return result return result
} }
func (s *State) extractStringMap(index int) (map[string]string, error) {
result := make(map[string]string)
s.PushNil()
for s.Next(index) {
if s.GetType(-2) == TypeString {
key := s.ToString(-2)
value := s.ToString(-1)
result[key] = value
}
s.Pop(1)
}
return result, nil
}
func (s *State) extractIntMap(index int) (map[string]int, error) {
result := make(map[string]int)
s.PushNil()
for s.Next(index) {
if s.GetType(-2) == TypeString {
key := s.ToString(-2)
value := int(s.ToNumber(-1))
result[key] = value
}
s.Pop(1)
}
return result, nil
}
func (s *State) extractIntKeyMap(index int) (map[int]any, error) {
result := make(map[int]any)
s.PushNil()
for s.Next(index) {
var key int
switch s.GetType(-2) {
case TypeString:
if k, err := strconv.Atoi(s.ToString(-2)); err == nil {
key = k
} else {
s.Pop(1)
continue
}
case TypeNumber:
key = int(s.ToNumber(-2))
default:
s.Pop(1)
continue
}
if value, err := s.ToValue(-1); err == nil {
result[key] = value
}
s.Pop(1)
}
return result, nil
}
func (s *State) extractAnyMap(index int) (map[string]any, error) { func (s *State) extractAnyMap(index int) (map[string]any, error) {
result := make(map[string]any) result := make(map[string]any)
s.PushNil() s.PushNil()
@ -720,31 +810,3 @@ func (s *State) CallGlobal(name string, args ...any) ([]any, error) {
s.SetTop(baseTop) s.SetTop(baseTop)
return results, nil return results, nil
} }
func (s *State) pushMapSlice(arr []map[string]any) error {
s.CreateTable(len(arr), 0)
for i, m := range arr {
s.PushNumber(float64(i + 1))
if err := s.PushValue(m); err != nil {
return err
}
s.SetTable(-3)
}
return nil
}
// PushLightUserData pushes a light userdata value
func (s *State) PushLightUserData(ptr any) {
C.lua_pushlightuserdata(s.L, unsafe.Pointer(&ptr))
}
// Helper method for getting table length used internally
func (s *State) getTableLength(index int) int {
length := 0
s.PushNil()
for s.Next(index - 1) {
length++
s.Pop(1)
}
return length
}