diff --git a/src/api/api.lua b/src/api/api.lua index 75453c0..a9f7089 100644 --- a/src/api/api.lua +++ b/src/api/api.lua @@ -1,18 +1,104 @@ --- --- Generated by EmmyLua(https://github.com/EmmyLua) --- Created by admin. ---- DateTime: 2025/9/25 14:19 +--- DateTime: 2025/9/24 15:29 --- +-- local user = require('api.system.user')--- 启动调试 +-- local mobdebug = require('src.share.initial.mobdebug'); +-- mobdebug.start(); + local function say_hello(req) ngx.say("Hello, World!") end +local function get_user(req) + ngx.say("call get_user") + local user_id = req.args.id or "unknown" + ngx.say("User ID: " .. user_id) +end + +local function get_id(req) + ngx.say("call get_id") + local args = req.get_uri_args() + -- 获取单个参数 + local id = args["id"] -- 值为 "john" + ngx.say("User ID: " .. user_id) +end + +local function test(req) + local request_method = ngx.var.request_method + local args = nil + ngx.say(request_method) + + --1、获取参数的值 获取前端提交参数 + if "GET" == request_method then + args = ngx.req.get_uri_args() + elseif "POST" == request_method then + ngx.req.read_body() + args = ngx.req.get_post_args() + end + + --2、组合url请求Get/Post请求 并获取参数 + local http = require "resty.http" + local httpc = http.new() + local url = "http://xxxxx/user/login/"..args["userid"].."/"..args["pass"] + local resStr --响应结果 + local res, err = httpc:request_uri(url, { + method = "GET", + --args = str, + body = "a=1&b=2", + headers = { + ["Content-Type"] = "application/json", + } + }) + + --3、开始重新组合参数 例子 可根据返回的JSON自己处理 + local cjson = require "cjson" + local sampleJson = [[{"age":"23","testArray":{"array":[8,9,11,14,25]},"Himi":"himigame.com"}]]; + --解析json字符串 + local data = cjson.decode(sampleJson); + --打印json字符串中的age字段 + ngx.say(data["age"]); + --打印数组中的第一个值(lua默认是从0开始计数) + ngx.say(data["testArray"]["array"][1]); + + --4、打印输出新返回值 + ngx.say(res.body) + + --获取url中a的值 + ngx.say(ngx.var.arg_a); + --获取主机名 + ngx.say(ngx.var.remote_addr); + --获取get和post参数 + local arg = ngx.req.get_uri_args() + for k,v in pairs(arg) do + ngx.say("[GET ] key:", k, " v:", v) + end + + ngx.req.read_body() -- 解析 body 参数之前一定要先读取 body + local arg = ngx.req.get_post_args() + for k,v in pairs(arg) do + ngx.say("[POST] key:", k, " v:", v) + end +end + +---local function conn() +--- local conn1 = user.conn +--- conn1:connect(...) +---end + local routes = { ["/hello"] = say_hello, + ["/user"] = get_user, + ["/userid"] = get_id, } local function handle_request() + local request_method = ngx.var.request_method + local args = nil + ngx.say(request_method) + local uri = ngx.var.request_uri ngx.say("url: " .. uri) local handler = routes[uri] diff --git a/src/api/system/user.lua b/src/api/system/user.lua new file mode 100644 index 0000000..791840a --- /dev/null +++ b/src/api/system/user.lua @@ -0,0 +1,34 @@ +--- +--- Generated by EmmyLua(https://github.com/EmmyLua) +--- Created by admin. +--- DateTime: 2025/9/25 08:19 +--- + +local db_config = require('config.database') +local pgmoon = require('share.pgmoonn') + +-- 创建一个新的连接 +local conn = pgmoon.new(db_config.postgres) + +-- 连接到数据库 +conn:connect(function(err) + if err then + print("Error connecting to database: ", err) + else + print("Connected to the PostgreSQL server.") + + -- 执行一个简单的查询 + conn:query("SELECT version()") + :on_data(function(row) + print("Database Version: ", row[1]) + end) + :on_error(function(err) + print("Query Error: ", err) + end) + :on_finish(function() + print("Query finished.") + -- 关闭连接 + conn:close() + end) + end +end) \ No newline at end of file diff --git a/src/config/config.lua b/src/config/config.lua new file mode 100644 index 0000000..25b5b2e --- /dev/null +++ b/src/config/config.lua @@ -0,0 +1,25 @@ +--- +--- Generated by EmmyLua(https://github.com/EmmyLua) +--- Created by admin. +--- DateTime: 2025/9/24 16:31 +--- + +return { + APP_ENV = "dev", -- dev/prod + + -- 配置redis数据库连接 + REDIS = { + HOST = "127.0.0.1", -- redis host + PORT = 6379, -- redis port + PASSWORD = nil -- redis password + }, + + -- 配置PostgresSQL数据库连接 + POSTGRES = { + HOST = "127.0.0.1", -- postgres host + PORT = 5432, -- postgres port + USERNAME = "postgres", + PASSWORD = "123456", -- postgres password + DATABASE = "postgres" + } +} \ No newline at end of file diff --git a/src/config/database.lua b/src/config/database.lua new file mode 100644 index 0000000..8f0b175 --- /dev/null +++ b/src/config/database.lua @@ -0,0 +1,18 @@ +local env = require('env') + +return { + redis_prefix = 'Auth:', + redis = { + host = env.REDIS.HOST, + port = env.REDIS.PORT, + password = env.REDIS.PASSWORD + }, + + postgres = { + host = env.POSTGRES.HOST, + port = env.POSTGRES.PORT, + username = env.POSTGRES.USERNAME, + password = env.POSTGRES.PASSWORD, + dbname = env.POSTGRES.DATABASE + }, +} \ No newline at end of file diff --git a/src/routes/dynamic_router.lua b/src/routes/dynamic_router.lua new file mode 100644 index 0000000..8513f6e --- /dev/null +++ b/src/routes/dynamic_router.lua @@ -0,0 +1,18 @@ +--- +--- Generated by EmmyLua(https://github.com/EmmyLua) +--- Created by admin. +--- DateTime: 2025/9/24 18:14 +--- + +local cjson = require("cjson.safe") +local routes = require("routes_cache") -- 从共享内存获取 +function handle_request() + local target = routes:get(ngx.var.uri) + if not target then + ngx.status = 404 + ngx.say("Route not found: ", ngx.var.uri) + return + end + ngx.var.target = target +end +handle_request() diff --git a/src/routes/route_updater.lua b/src/routes/route_updater.lua new file mode 100644 index 0000000..0c37356 --- /dev/null +++ b/src/routes/route_updater.lua @@ -0,0 +1,15 @@ +--- +--- Generated by EmmyLua(https://github.com/EmmyLua) +--- Created by admin. +--- DateTime: 2025/9/24 18:14 +--- + +local cjson = require("cjson") +local dict = ngx.shared.routes_cache +ngx.req.read_body() +local body = ngx.req.get_body_data() +local new_rules = cjson.decode(body) +for path, target in pairs(new_rules) do + dict:set(path, target) +end +ngx.say(cjson.encode({ status = "ok", count = #new_rules })) diff --git a/src/service/system/user.lua b/src/service/system/user.lua new file mode 100644 index 0000000..d70a06e --- /dev/null +++ b/src/service/system/user.lua @@ -0,0 +1,5 @@ +--- +--- Generated by EmmyLua(https://github.com/EmmyLua) +--- Created by admin. +--- DateTime: 2025/9/25 08:19 +--- 业务逻辑 \ No newline at end of file diff --git a/src/share/initial/loading_config.lua b/src/share/initial/loading_config.lua new file mode 100644 index 0000000..b66d7d7 --- /dev/null +++ b/src/share/initial/loading_config.lua @@ -0,0 +1,29 @@ +--- +--- Generated by EmmyLua(https://github.com/EmmyLua) +--- Created by . +--- DateTime: +--- + +--[[ + 公共包引用的范围为http块;项目用到的非配置类都可以在此处配置; + 日志级别: + ngx.STDERR + ngx.EMERG + ngx.ALERT + ngx.CRIT + ngx.ERR + ngx.WARN + ngx.NOTICE + ngx.INFO + ngx.DEBUG +]] +-- 加载cjson +cjson = require("cjson"); +-- 加载string +string = require("string"); + +--[[ + 项目内公共包配置 +]] +-- 加载resty.core +require("resty.core"); \ No newline at end of file diff --git a/src/share/initial/mobdebug.lua b/src/share/initial/mobdebug.lua new file mode 100644 index 0000000..23a4705 --- /dev/null +++ b/src/share/initial/mobdebug.lua @@ -0,0 +1,1703 @@ +-- +-- MobDebug -- Lua remote debugger +-- Copyright 2011-15 Paul Kulchenko +-- Based on RemDebug 1.0 Copyright Kepler Project 2005 +-- + +-- use loaded modules or load explicitly on those systems that require that +local require = require +local io = io or require "io" +local table = table or require "table" +local string = string or require "string" +local coroutine = coroutine or require "coroutine" +local debug = require "debug" +-- protect require "os" as it may fail on embedded systems without os module +local os = os or (function(module) + local ok, res = pcall(require, module) + return ok and res or nil +end)("os") + +local mobdebug = { + _NAME = "mobdebug", + _VERSION = "0.70", + _COPYRIGHT = "Paul Kulchenko", + _DESCRIPTION = "Mobile Remote Debugger for the Lua programming language", + port = os and os.getenv and tonumber((os.getenv("MOBDEBUG_PORT"))) or 8172, + checkcount = 200, + yieldtimeout = 0.02, -- yield timeout (s) + connecttimeout = 2, -- connect timeout (s) +} + +local HOOKMASK = "lcr" +local error = error +local getfenv = getfenv +local setfenv = setfenv +local loadstring = loadstring or load -- "load" replaced "loadstring" in Lua 5.2 +local pairs = pairs +local setmetatable = setmetatable +local tonumber = tonumber +local unpack = table.unpack or unpack +local rawget = rawget +local gsub, sub, find = string.gsub, string.sub, string.find + +-- if strict.lua is used, then need to avoid referencing some global +-- variables, as they can be undefined; +-- use rawget to avoid complaints from strict.lua at run-time. +-- it's safe to do the initialization here as all these variables +-- should get defined values (if any) before the debugging starts. +-- there is also global 'wx' variable, which is checked as part of +-- the debug loop as 'wx' can be loaded at any time during debugging. +local genv = _G or _ENV +local jit = rawget(genv, "jit") +local MOAICoroutine = rawget(genv, "MOAICoroutine") + +-- ngx_lua debugging requires a special handling as its coroutine.* +-- methods use a different mechanism that doesn't allow resume calls +-- from debug hook handlers. +-- Instead, the "original" coroutine.* methods are used. +-- `rawget` needs to be used to protect against `strict` checks, but +-- ngx_lua hides those in a metatable, so need to use that. +local metagindex = getmetatable(genv) and getmetatable(genv).__index +local ngx = type(metagindex) == "table" and metagindex.rawget and metagindex:rawget("ngx") or nil +local corocreate = ngx and coroutine._create or coroutine.create +local cororesume = ngx and coroutine._resume or coroutine.resume +local coroyield = ngx and coroutine._yield or coroutine.yield +local corostatus = ngx and coroutine._status or coroutine.status +local corowrap = coroutine.wrap + +if not setfenv then -- Lua 5.2+ + -- based on http://lua-users.org/lists/lua-l/2010-06/msg00314.html + -- this assumes f is a function + local function findenv(f) + local level = 1 + repeat + local name, value = debug.getupvalue(f, level) + if name == '_ENV' then return level, value end + level = level + 1 + until name == nil + return nil end + getfenv = function (f) return(select(2, findenv(f)) or _G) end + setfenv = function (f, t) + local level = findenv(f) + if level then debug.setupvalue(f, level, t) end + return f end +end + +-- check for OS and convert file names to lower case on windows +-- (its file system is case insensitive, but case preserving), as setting a +-- breakpoint on x:\Foo.lua will not work if the file was loaded as X:\foo.lua. +-- OSX and Windows behave the same way (case insensitive, but case preserving). +-- OSX can be configured to be case-sensitive, so check for that. This doesn't +-- handle the case of different partitions having different case-sensitivity. +local win = os and os.getenv and (os.getenv('WINDIR') or (os.getenv('OS') or ''):match('[Ww]indows')) and true or false +local mac = not win and (os and os.getenv and os.getenv('DYLD_LIBRARY_PATH') or not io.open("/proc")) and true or false +local iscasepreserving = win or (mac and io.open('/library') ~= nil) + +-- turn jit off based on Mike Pall's comment in this discussion: +-- http://www.freelists.org/post/luajit/Debug-hooks-and-JIT,2 +-- "You need to turn it off at the start if you plan to receive +-- reliable hook calls at any later point in time." +if jit and jit.off then jit.off() end + +local socket = require "socket" +local coro_debugger +local coro_debugee +local coroutines = {}; setmetatable(coroutines, {__mode = "k"}) -- "weak" keys +local events = { BREAK = 1, WATCH = 2, RESTART = 3, STACK = 4 } +local breakpoints = {} +local watches = {} +local lastsource +local lastfile +local watchescnt = 0 +local abort -- default value is nil; this is used in start/loop distinction +local seen_hook = false +local checkcount = 0 +local step_into = false +local step_over = false +local step_level = 0 +local stack_level = 0 +local server +local buf +local outputs = {} +local iobase = {print = print} +local basedir = "" +local deferror = "execution aborted at default debugee" +local debugee = function () + local a = 1 + for _ = 1, 10 do a = a + 1 end + error(deferror) +end +local function q(s) return string.gsub(s, '([%(%)%.%%%+%-%*%?%[%^%$%]])','%%%1') end + +local serpent = (function() ---- include Serpent module for serialization +local n, v = "serpent", "0.30" -- (C) 2012-17 Paul Kulchenko; MIT License +local c, d = "Paul Kulchenko", "Lua serializer and pretty printer" +local snum = {[tostring(1/0)]='1/0 --[[math.huge]]',[tostring(-1/0)]='-1/0 --[[-math.huge]]',[tostring(0/0)]='0/0'} +local badtype = {thread = true, userdata = true, cdata = true} +local getmetatable = debug and debug.getmetatable or getmetatable +local pairs = function(t) return next, t end -- avoid using __pairs in Lua 5.2+ +local keyword, globals, G = {}, {}, (_G or _ENV) +for _,k in ipairs({'and', 'break', 'do', 'else', 'elseif', 'end', 'false', + 'for', 'function', 'goto', 'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', + 'return', 'then', 'true', 'until', 'while'}) do keyword[k] = true end +for k,v in pairs(G) do globals[v] = k end -- build func to name mapping +for _,g in ipairs({'coroutine', 'debug', 'io', 'math', 'string', 'table', 'os'}) do + for k,v in pairs(type(G[g]) == 'table' and G[g] or {}) do globals[v] = g..'.'..k end end + +local function s(t, opts) + local name, indent, fatal, maxnum = opts.name, opts.indent, opts.fatal, opts.maxnum + local sparse, custom, huge = opts.sparse, opts.custom, not opts.nohuge + local space, maxl = (opts.compact and '' or ' '), (opts.maxlevel or math.huge) + local maxlen, metatostring = tonumber(opts.maxlength), opts.metatostring + local iname, comm = '_'..(name or ''), opts.comment and (tonumber(opts.comment) or math.huge) + local numformat = opts.numformat or "%.17g" + local seen, sref, syms, symn = {}, {'local '..iname..'={}'}, {}, 0 + local function gensym(val) return '_'..(tostring(tostring(val)):gsub("[^%w]",""):gsub("(%d%w+)", + -- tostring(val) is needed because __tostring may return a non-string value + function(s) if not syms[s] then symn = symn+1; syms[s] = symn end return tostring(syms[s]) end)) end + local function safestr(s) return type(s) == "number" and tostring(huge and snum[tostring(s)] or numformat:format(s)) + or type(s) ~= "string" and tostring(s) -- escape NEWLINE/010 and EOF/026 + or ("%q"):format(s):gsub("\010","n"):gsub("\026","\\026") end + local function comment(s,l) return comm and (l or 0) < comm and ' --[['..select(2, pcall(tostring, s))..']]' or '' end + local function globerr(s,l) return globals[s] and globals[s]..comment(s,l) or not fatal + and safestr(select(2, pcall(tostring, s))) or error("Can't serialize "..tostring(s)) end + local function safename(path, name) -- generates foo.bar, foo[3], or foo['b a r'] + local n = name == nil and '' or name + local plain = type(n) == "string" and n:match("^[%l%u_][%w_]*$") and not keyword[n] + local safe = plain and n or '['..safestr(n)..']' + return (path or '')..(plain and path and '.' or '')..safe, safe end + local alphanumsort = type(opts.sortkeys) == 'function' and opts.sortkeys or function(k, o, n) -- k=keys, o=originaltable, n=padding + local maxn, to = tonumber(n) or 12, {number = 'a', string = 'b'} + local function padnum(d) return ("%0"..tostring(maxn).."d"):format(tonumber(d)) end + table.sort(k, function(a,b) + -- sort numeric keys first: k[key] is not nil for numerical keys + return (k[a] ~= nil and 0 or to[type(a)] or 'z')..(tostring(a):gsub("%d+",padnum)) + < (k[b] ~= nil and 0 or to[type(b)] or 'z')..(tostring(b):gsub("%d+",padnum)) end) end + local function val2str(t, name, indent, insref, path, plainindex, level) + local ttype, level, mt = type(t), (level or 0), getmetatable(t) + local spath, sname = safename(path, name) + local tag = plainindex and + ((type(name) == "number") and '' or name..space..'='..space) or + (name ~= nil and sname..space..'='..space or '') + if seen[t] then -- already seen this element + sref[#sref+1] = spath..space..'='..space..seen[t] + return tag..'nil'..comment('ref', level) end + -- protect from those cases where __tostring may fail + if type(mt) == 'table' then + local to, tr = pcall(function() return mt.__tostring(t) end) + local so, sr = pcall(function() return mt.__serialize(t) end) + if (opts.metatostring ~= false and to or so) then -- knows how to serialize itself + seen[t] = insref or spath + t = so and sr or tr + ttype = type(t) + end -- new value falls through to be serialized + end + if ttype == "table" then + if level >= maxl then return tag..'{}'..comment('maxlvl', level) end + seen[t] = insref or spath + if next(t) == nil then return tag..'{}'..comment(t, level) end -- table empty + if maxlen and maxlen < 0 then return tag..'{}'..comment('maxlen', level) end + local maxn, o, out = math.min(#t, maxnum or #t), {}, {} + for key = 1, maxn do o[key] = key end + if not maxnum or #o < maxnum then + local n = #o -- n = n + 1; o[n] is much faster than o[#o+1] on large tables + for key in pairs(t) do if o[key] ~= key then n = n + 1; o[n] = key end end end + if maxnum and #o > maxnum then o[maxnum+1] = nil end + if opts.sortkeys and #o > maxn then alphanumsort(o, t, opts.sortkeys) end + local sparse = sparse and #o > maxn -- disable sparsness if only numeric keys (shorter output) + for n, key in ipairs(o) do + local value, ktype, plainindex = t[key], type(key), n <= maxn and not sparse + if opts.valignore and opts.valignore[value] -- skip ignored values; do nothing + or opts.keyallow and not opts.keyallow[key] + or opts.keyignore and opts.keyignore[key] + or opts.valtypeignore and opts.valtypeignore[type(value)] -- skipping ignored value types + or sparse and value == nil then -- skipping nils; do nothing + elseif ktype == 'table' or ktype == 'function' or badtype[ktype] then + if not seen[key] and not globals[key] then + sref[#sref+1] = 'placeholder' + local sname = safename(iname, gensym(key)) -- iname is table for local variables + sref[#sref] = val2str(key,sname,indent,sname,iname,true) end + sref[#sref+1] = 'placeholder' + local path = seen[t]..'['..tostring(seen[key] or globals[key] or gensym(key))..']' + sref[#sref] = path..space..'='..space..tostring(seen[value] or val2str(value,nil,indent,path)) + else + out[#out+1] = val2str(value,key,indent,insref,seen[t],plainindex,level+1) + if maxlen then + maxlen = maxlen - #out[#out] + if maxlen < 0 then break end + end + end + end + local prefix = string.rep(indent or '', level) + local head = indent and '{\n'..prefix..indent or '{' + local body = table.concat(out, ','..(indent and '\n'..prefix..indent or space)) + local tail = indent and "\n"..prefix..'}' or '}' + return (custom and custom(tag,head,body,tail,level) or tag..head..body..tail)..comment(t, level) + elseif badtype[ttype] then + seen[t] = insref or spath + return tag..globerr(t, level) + elseif ttype == 'function' then + seen[t] = insref or spath + if opts.nocode then return tag.."function() --[[..skipped..]] end"..comment(t, level) end + local ok, res = pcall(string.dump, t) + local func = ok and "((loadstring or load)("..safestr(res)..",'@serialized'))"..comment(t, level) + return tag..(func or globerr(t, level)) + else return tag..safestr(t) end -- handle all other types + end + local sepr = indent and "\n" or ";"..space + local body = val2str(t, name, indent) -- this call also populates sref + local tail = #sref>1 and table.concat(sref, sepr)..sepr or '' + local warn = opts.comment and #sref>1 and space.."--[[incomplete output with shared/self-references skipped]]" or '' + return not name and body..warn or "do local "..body..sepr..tail.."return "..name..sepr.."end" +end + +local function deserialize(data, opts) + local env = (opts and opts.safe == false) and G + or setmetatable({}, { + __index = function(t,k) return t end, + __call = function(t,...) error("cannot call functions") end + }) + local f, res = (loadstring or load)('return '..data, nil, nil, env) + if not f then f, res = (loadstring or load)(data, nil, nil, env) end + if not f then return f, res end + if setfenv then setfenv(f, env) end + return pcall(f) +end + +local function merge(a, b) if b then for k,v in pairs(b) do a[k] = v end end; return a; end +return { _NAME = n, _COPYRIGHT = c, _DESCRIPTION = d, _VERSION = v, serialize = s, + load = deserialize, + dump = function(a, opts) return s(a, merge({name = '_', compact = true, sparse = true}, opts)) end, + line = function(a, opts) return s(a, merge({sortkeys = true, comment = true}, opts)) end, + block = function(a, opts) return s(a, merge({indent = ' ', sortkeys = true, comment = true}, opts)) end } +end)() ---- end of Serpent module + +mobdebug.line = serpent.line +mobdebug.dump = serpent.dump +mobdebug.linemap = nil +mobdebug.loadstring = loadstring + +local function removebasedir(path, basedir) + if iscasepreserving then + -- check if the lowercased path matches the basedir + -- if so, return substring of the original path (to not lowercase it) + return path:lower():find('^'..q(basedir:lower())) + and path:sub(#basedir+1) or path + else + return string.gsub(path, '^'..q(basedir), '') + end +end + +local function stack(start) + local function vars(f) + local func = debug.getinfo(f, "f").func + local i = 1 + local locals = {} + -- get locals + while true do + local name, value = debug.getlocal(f, i) + if not name then break end + if string.sub(name, 1, 1) ~= '(' then + locals[name] = {value, select(2,pcall(tostring,value))} + end + i = i + 1 + end + -- get varargs (these use negative indices) + i = 1 + while true do + local name, value = debug.getlocal(f, -i) + -- `not name` should be enough, but LuaJIT 2.0.0 incorrectly reports `(*temporary)` names here + if not name or name ~= "(*vararg)" then break end + locals[name:gsub("%)$"," "..i..")")] = {value, select(2,pcall(tostring,value))} + i = i + 1 + end + -- get upvalues + i = 1 + local ups = {} + while func do -- check for func as it may be nil for tail calls + local name, value = debug.getupvalue(func, i) + if not name then break end + ups[name] = {value, select(2,pcall(tostring,value))} + i = i + 1 + end + return locals, ups + end + + local stack = {} + local linemap = mobdebug.linemap + for i = (start or 0), 100 do + local source = debug.getinfo(i, "Snl") + if not source then break end + + local src = source.source + if src:find("@") == 1 then + src = src:sub(2):gsub("\\", "/") + if src:find("%./") == 1 then src = src:sub(3) end + end + + table.insert(stack, { -- remove basedir from source + {source.name, removebasedir(src, basedir), + linemap and linemap(source.linedefined, source.source) or source.linedefined, + linemap and linemap(source.currentline, source.source) or source.currentline, + source.what, source.namewhat, source.short_src}, + vars(i+1)}) + if source.what == 'main' then break end + end + return stack +end + +local function set_breakpoint(file, line) + if file == '-' and lastfile then file = lastfile + elseif iscasepreserving then file = string.lower(file) end + if not breakpoints[line] then breakpoints[line] = {} end + breakpoints[line][file] = true +end + +local function remove_breakpoint(file, line) + if file == '-' and lastfile then file = lastfile + elseif file == '*' and line == 0 then breakpoints = {} + elseif iscasepreserving then file = string.lower(file) end + if breakpoints[line] then breakpoints[line][file] = nil end +end + +local function has_breakpoint(file, line) + return breakpoints[line] + and breakpoints[line][iscasepreserving and string.lower(file) or file] +end + +local function restore_vars(vars) + if type(vars) ~= 'table' then return end + + -- locals need to be processed in the reverse order, starting from + -- the inner block out, to make sure that the localized variables + -- are correctly updated with only the closest variable with + -- the same name being changed + -- first loop find how many local variables there is, while + -- the second loop processes them from i to 1 + local i = 1 + while true do + local name = debug.getlocal(3, i) + if not name then break end + i = i + 1 + end + i = i - 1 + local written_vars = {} + while i > 0 do + local name = debug.getlocal(3, i) + if not written_vars[name] then + if string.sub(name, 1, 1) ~= '(' then + debug.setlocal(3, i, rawget(vars, name)) + end + written_vars[name] = true + end + i = i - 1 + end + + i = 1 + local func = debug.getinfo(3, "f").func + while true do + local name = debug.getupvalue(func, i) + if not name then break end + if not written_vars[name] then + if string.sub(name, 1, 1) ~= '(' then + debug.setupvalue(func, i, rawget(vars, name)) + end + written_vars[name] = true + end + i = i + 1 + end +end + +local function capture_vars(level, thread) + level = (level or 0)+2 -- add two levels for this and debug calls + local func = (thread and debug.getinfo(thread, level, "f") or debug.getinfo(level, "f") or {}).func + if not func then return {} end + + local vars = {['...'] = {}} + local i = 1 + while true do + local name, value = debug.getupvalue(func, i) + if not name then break end + if string.sub(name, 1, 1) ~= '(' then vars[name] = value end + i = i + 1 + end + i = 1 + while true do + local name, value + if thread then + name, value = debug.getlocal(thread, level, i) + else + name, value = debug.getlocal(level, i) + end + if not name then break end + if string.sub(name, 1, 1) ~= '(' then vars[name] = value end + i = i + 1 + end + -- get varargs (these use negative indices) + i = 1 + while true do + local name, value + if thread then + name, value = debug.getlocal(thread, level, -i) + else + name, value = debug.getlocal(level, -i) + end + -- `not name` should be enough, but LuaJIT 2.0.0 incorrectly reports `(*temporary)` names here + if not name or name ~= "(*vararg)" then break end + vars['...'][i] = value + i = i + 1 + end + -- returned 'vars' table plays a dual role: (1) it captures local values + -- and upvalues to be restored later (in case they are modified in "eval"), + -- and (2) it provides an environment for evaluated chunks. + -- getfenv(func) is needed to provide proper environment for functions, + -- including access to globals, but this causes vars[name] to fail in + -- restore_vars on local variables or upvalues with `nil` values when + -- 'strict' is in effect. To avoid this `rawget` is used in restore_vars. + setmetatable(vars, { __index = getfenv(func), __newindex = getfenv(func) }) + return vars +end + +local function stack_depth(start_depth) + for i = start_depth, 0, -1 do + if debug.getinfo(i, "l") then return i+1 end + end + return start_depth +end + +local function is_safe(stack_level) + -- the stack grows up: 0 is getinfo, 1 is is_safe, 2 is debug_hook, 3 is user function + if stack_level == 3 then return true end + for i = 3, stack_level do + -- return if it is not safe to abort + local info = debug.getinfo(i, "S") + if not info then return true end + if info.what == "C" then return false end + end + return true +end + +local function in_debugger() + local this = debug.getinfo(1, "S").source + -- only need to check few frames as mobdebug frames should be close + for i = 3, 7 do + local info = debug.getinfo(i, "S") + if not info then return false end + if info.source == this then return true end + end + return false +end + +local function is_pending(peer) + -- if there is something already in the buffer, skip check + if not buf and checkcount >= mobdebug.checkcount then + peer:settimeout(0) -- non-blocking + buf = peer:receive(1) + peer:settimeout() -- back to blocking + checkcount = 0 + end + return buf +end + +local function readnext(peer, num) + peer:settimeout(0) -- non-blocking + local res, err, partial = peer:receive(num) + peer:settimeout() -- back to blocking + return res or partial or '', err +end + +local function handle_breakpoint(peer) + -- check if the buffer has the beginning of SETB/DELB command; + -- this is to avoid reading the entire line for commands that + -- don't need to be handled here. + if not buf or not (buf:sub(1,1) == 'S' or buf:sub(1,1) == 'D') then return end + + -- check second character to avoid reading STEP or other S* and D* commands + if #buf == 1 then buf = buf .. readnext(peer, 1) end + if buf:sub(2,2) ~= 'E' then return end + + -- need to read few more characters + buf = buf .. readnext(peer, 5-#buf) + if buf ~= 'SETB ' and buf ~= 'DELB ' then return end + + local res, _, partial = peer:receive() -- get the rest of the line; blocking + if not res then + if partial then buf = buf .. partial end + return + end + + local _, _, cmd, file, line = (buf..res):find("^([A-Z]+)%s+(.-)%s+(%d+)%s*$") + if cmd == 'SETB' then set_breakpoint(file, tonumber(line)) + elseif cmd == 'DELB' then remove_breakpoint(file, tonumber(line)) + else + -- this looks like a breakpoint command, but something went wrong; + -- return here to let the "normal" processing to handle, + -- although this is likely to not go well. + return + end + + buf = nil +end + +local function normalize_path(file) + local n + repeat + file, n = file:gsub("/+%.?/+","/") -- remove all `//` and `/./` references + until n == 0 + -- collapse all up-dir references: this will clobber UNC prefix (\\?\) + -- and disk on Windows when there are too many up-dir references: `D:\foo\..\..\bar`; + -- handle the case of multiple up-dir references: `foo/bar/baz/../../../more`; + -- only remove one at a time as otherwise `../../` could be removed; + repeat + file, n = file:gsub("[^/]+/%.%./", "", 1) + until n == 0 + -- there may still be a leading up-dir reference left (as `/../` or `../`); remove it + return (file:gsub("^(/?)%.%./", "%1")) +end + +local function debug_hook(event, line) + -- (1) LuaJIT needs special treatment. Because debug_hook is set for + -- *all* coroutines, and not just the one being debugged as in regular Lua + -- (http://lua-users.org/lists/lua-l/2011-06/msg00513.html), + -- need to avoid debugging mobdebug's own code as LuaJIT doesn't + -- always correctly generate call/return hook events (there are more + -- calls than returns, which breaks stack depth calculation and + -- 'step' and 'step over' commands stop working; possibly because + -- 'tail return' events are not generated by LuaJIT). + -- the next line checks if the debugger is run under LuaJIT and if + -- one of debugger methods is present in the stack, it simply returns. + if jit then + -- when luajit is compiled with LUAJIT_ENABLE_LUA52COMPAT, + -- coroutine.running() returns non-nil for the main thread. + local coro, main = coroutine.running() + if not coro or main then coro = 'main' end + local disabled = coroutines[coro] == false + or coroutines[coro] == nil and coro ~= (coro_debugee or 'main') + if coro_debugee and disabled or not coro_debugee and (disabled or in_debugger()) + then return end + end + + -- (2) check if abort has been requested and it's safe to abort + if abort and is_safe(stack_level) then error(abort) end + + -- (3) also check if this debug hook has not been visited for any reason. + -- this check is needed to avoid stepping in too early + -- (for example, when coroutine.resume() is executed inside start()). + if not seen_hook and in_debugger() then return end + + if event == "call" then + stack_level = stack_level + 1 + elseif event == "return" or event == "tail return" then + stack_level = stack_level - 1 + elseif event == "line" then + if mobdebug.linemap then + local ok, mappedline = pcall(mobdebug.linemap, line, debug.getinfo(2, "S").source) + if ok then line = mappedline end + if not line then return end + end + + -- may need to fall through because of the following: + -- (1) step_into + -- (2) step_over and stack_level <= step_level (need stack_level) + -- (3) breakpoint; check for line first as it's known; then for file + -- (4) socket call (only do every Xth check) + -- (5) at least one watch is registered + if not ( + step_into or step_over or breakpoints[line] or watchescnt > 0 + or is_pending(server) + ) then checkcount = checkcount + 1; return end + + checkcount = mobdebug.checkcount -- force check on the next command + + -- this is needed to check if the stack got shorter or longer. + -- unfortunately counting call/return calls is not reliable. + -- the discrepancy may happen when "pcall(load, '')" call is made + -- or when "error()" is called in a function. + -- in either case there are more "call" than "return" events reported. + -- this validation is done for every "line" event, but should be "cheap" + -- as it checks for the stack to get shorter (or longer by one call). + -- start from one level higher just in case we need to grow the stack. + -- this may happen after coroutine.resume call to a function that doesn't + -- have any other instructions to execute. it triggers three returns: + -- "return, tail return, return", which needs to be accounted for. + stack_level = stack_depth(stack_level+1) + + local caller = debug.getinfo(2, "S") + + -- grab the filename and fix it if needed + local file = lastfile + if (lastsource ~= caller.source) then + file, lastsource = caller.source, caller.source + -- technically, users can supply names that may not use '@', + -- for example when they call loadstring('...', 'filename.lua'). + -- Unfortunately, there is no reliable/quick way to figure out + -- what is the filename and what is the source code. + -- If the name doesn't start with `@`, assume it's a file name if it's all on one line. + if find(file, "^@") or not find(file, "[\r\n]") then + file = gsub(gsub(file, "^@", ""), "\\", "/") + -- normalize paths that may include up-dir or same-dir references + -- if the path starts from the up-dir or reference, + -- prepend `basedir` to generate absolute path to keep breakpoints working. + -- ignore qualified relative path (`D:../`) and UNC paths (`\\?\`) + if find(file, "^%.%./") then file = basedir..file end + if find(file, "/%.%.?/") then file = normalize_path(file) end + -- need this conversion to be applied to relative and absolute + -- file names as you may write "require 'Foo'" to + -- load "foo.lua" (on a case insensitive file system) and breakpoints + -- set on foo.lua will not work if not converted to the same case. + if iscasepreserving then file = string.lower(file) end + if find(file, "^%./") then file = sub(file, 3) + else file = gsub(file, "^"..q(basedir), "") end + -- some file systems allow newlines in file names; remove these. + file = gsub(file, "\n", ' ') + else + file = mobdebug.line(file) + end + + -- set to true if we got here; this only needs to be done once per + -- session, so do it here to at least avoid setting it for every line. + seen_hook = true + lastfile = file + end + + if is_pending(server) then handle_breakpoint(server) end + + local vars, status, res + if (watchescnt > 0) then + vars = capture_vars(1) + for index, value in pairs(watches) do + setfenv(value, vars) + local ok, fired = pcall(value) + if ok and fired then + status, res = cororesume(coro_debugger, events.WATCH, vars, file, line, index) + break -- any one watch is enough; don't check multiple times + end + end + end + + -- need to get into the "regular" debug handler, but only if there was + -- no watch that was fired. If there was a watch, handle its result. + local getin = (status == nil) and + (step_into + -- when coroutine.running() return `nil` (main thread in Lua 5.1), + -- step_over will equal 'main', so need to check for that explicitly. + or (step_over and step_over == (coroutine.running() or 'main') and stack_level <= step_level) + or has_breakpoint(file, line) + or is_pending(server)) + + if getin then + vars = vars or capture_vars(1) + step_into = false + step_over = false + status, res = cororesume(coro_debugger, events.BREAK, vars, file, line) + end + + -- handle 'stack' command that provides stack() information to the debugger + while status and res == 'stack' do + -- resume with the stack trace and variables + if vars then restore_vars(vars) end -- restore vars so they are reflected in stack values + status, res = cororesume(coro_debugger, events.STACK, stack(3), file, line) + end + + -- need to recheck once more as resume after 'stack' command may + -- return something else (for example, 'exit'), which needs to be handled + if status and res and res ~= 'stack' then + if not abort and res == "exit" then mobdebug.onexit(1, true); return end + if not abort and res == "done" then mobdebug.done(); return end + abort = res + -- only abort if safe; if not, there is another (earlier) check inside + -- debug_hook, which will abort execution at the first safe opportunity + if is_safe(stack_level) then error(abort) end + elseif not status and res then + error(res, 2) -- report any other (internal) errors back to the application + end + + if vars then restore_vars(vars) end + + -- last command requested Step Over/Out; store the current thread + if step_over == true then step_over = coroutine.running() or 'main' end + end +end + +local function stringify_results(params, status, ...) + if not status then return status, ... end -- on error report as it + + params = params or {} + if params.nocode == nil then params.nocode = true end + if params.comment == nil then params.comment = 1 end + + local t = {...} + for i,v in pairs(t) do -- stringify each of the returned values + local ok, res = pcall(mobdebug.line, v, params) + t[i] = ok and res or ("%q"):format(res):gsub("\010","n"):gsub("\026","\\026") + end + -- stringify table with all returned values + -- this is done to allow each returned value to be used (serialized or not) + -- intependently and to preserve "original" comments + return pcall(mobdebug.dump, t, {sparse = false}) +end + +local function isrunning() + return coro_debugger and (corostatus(coro_debugger) == 'suspended' or corostatus(coro_debugger) == 'running') +end + +-- this is a function that removes all hooks and closes the socket to +-- report back to the controller that the debugging is done. +-- the script that called `done` can still continue. +local function done() + if not (isrunning() and server) then return end + + if not jit then + for co, debugged in pairs(coroutines) do + if debugged then debug.sethook(co) end + end + end + + debug.sethook() + server:close() + + coro_debugger = nil -- to make sure isrunning() returns `false` + seen_hook = nil -- to make sure that the next start() call works + abort = nil -- to make sure that callback calls use proper "abort" value +end + +local function debugger_loop(sev, svars, sfile, sline) + local command + local app, osname + local eval_env = svars or {} + local function emptyWatch () return false end + local loaded = {} + for k in pairs(package.loaded) do loaded[k] = true end + + while true do + local line, err + local wx = rawget(genv, "wx") -- use rawread to make strict.lua happy + if (wx or mobdebug.yield) and server.settimeout then server:settimeout(mobdebug.yieldtimeout) end + while true do + line, err = server:receive() + if not line and err == "timeout" then + -- yield for wx GUI applications if possible to avoid "busyness" + app = app or (wx and wx.wxGetApp and wx.wxGetApp()) + if app then + local win = app:GetTopWindow() + local inloop = app:IsMainLoopRunning() + osname = osname or wx.wxPlatformInfo.Get():GetOperatingSystemFamilyName() + if win and not inloop then + -- process messages in a regular way + -- and exit as soon as the event loop is idle + if osname == 'Unix' then wx.wxTimer(app):Start(10, true) end + local exitLoop = function() + win:Disconnect(wx.wxID_ANY, wx.wxID_ANY, wx.wxEVT_IDLE) + win:Disconnect(wx.wxID_ANY, wx.wxID_ANY, wx.wxEVT_TIMER) + app:ExitMainLoop() + end + win:Connect(wx.wxEVT_IDLE, exitLoop) + win:Connect(wx.wxEVT_TIMER, exitLoop) + app:MainLoop() + end + elseif mobdebug.yield then mobdebug.yield() + end + elseif not line and err == "closed" then + error("Debugger connection closed", 0) + else + -- if there is something in the pending buffer, prepend it to the line + if buf then line = buf .. line; buf = nil end + break + end + end + if server.settimeout then server:settimeout() end -- back to blocking + command = string.sub(line, string.find(line, "^[A-Z]+")) + if command == "SETB" then + local _, _, _, file, line = string.find(line, "^([A-Z]+)%s+(.-)%s+(%d+)%s*$") + if file and line then + set_breakpoint(file, tonumber(line)) + server:send("200 OK\n") + else + server:send("400 Bad Request\n") + end + elseif command == "DELB" then + local _, _, _, file, line = string.find(line, "^([A-Z]+)%s+(.-)%s+(%d+)%s*$") + if file and line then + remove_breakpoint(file, tonumber(line)) + server:send("200 OK\n") + else + server:send("400 Bad Request\n") + end + elseif command == "EXEC" then + -- extract any optional parameters + local params = string.match(line, "--%s*(%b{})%s*$") + local _, _, chunk = string.find(line, "^[A-Z]+%s+(.+)$") + if chunk then + local func, res = mobdebug.loadstring(chunk) + local status + if func then + local pfunc = params and loadstring("return "..params) -- use internal function + params = pfunc and pfunc() + params = (type(params) == "table" and params or {}) + local stack = tonumber(params.stack) + -- if the requested stack frame is not the current one, then use a new capture + -- with a specific stack frame: `capture_vars(0, coro_debugee)` + local env = stack and coro_debugee and capture_vars(stack-1, coro_debugee) or eval_env + setfenv(func, env) + status, res = stringify_results(params, pcall(func, unpack(env['...'] or {}))) + end + if status then + if mobdebug.onscratch then mobdebug.onscratch(res) end + server:send("200 OK " .. tostring(#res) .. "\n") + server:send(res) + else + -- fix error if not set (for example, when loadstring is not present) + if not res then res = "Unknown error" end + server:send("401 Error in Expression " .. tostring(#res) .. "\n") + server:send(res) + end + else + server:send("400 Bad Request\n") + end + elseif command == "LOAD" then + local _, _, size, name = string.find(line, "^[A-Z]+%s+(%d+)%s+(%S.-)%s*$") + size = tonumber(size) + + if abort == nil then -- no LOAD/RELOAD allowed inside start() + if size > 0 then server:receive(size) end + if sfile and sline then + server:send("201 Started " .. sfile .. " " .. tostring(sline) .. "\n") + else + server:send("200 OK 0\n") + end + else + -- reset environment to allow required modules to load again + -- remove those packages that weren't loaded when debugger started + for k in pairs(package.loaded) do + if not loaded[k] then package.loaded[k] = nil end + end + + if size == 0 and name == '-' then -- RELOAD the current script being debugged + server:send("200 OK 0\n") + coroyield("load") + else + -- receiving 0 bytes blocks (at least in luasocket 2.0.2), so skip reading + local chunk = size == 0 and "" or server:receive(size) + if chunk then -- LOAD a new script for debugging + local func, res = mobdebug.loadstring(chunk, "@"..name) + if func then + server:send("200 OK 0\n") + debugee = func + coroyield("load") + else + server:send("401 Error in Expression " .. tostring(#res) .. "\n") + server:send(res) + end + else + server:send("400 Bad Request\n") + end + end + end + elseif command == "SETW" then + local _, _, exp = string.find(line, "^[A-Z]+%s+(.+)%s*$") + if exp then + local func, res = mobdebug.loadstring("return(" .. exp .. ")") + if func then + watchescnt = watchescnt + 1 + local newidx = #watches + 1 + watches[newidx] = func + server:send("200 OK " .. tostring(newidx) .. "\n") + else + server:send("401 Error in Expression " .. tostring(#res) .. "\n") + server:send(res) + end + else + server:send("400 Bad Request\n") + end + elseif command == "DELW" then + local _, _, index = string.find(line, "^[A-Z]+%s+(%d+)%s*$") + index = tonumber(index) + if index > 0 and index <= #watches then + watchescnt = watchescnt - (watches[index] ~= emptyWatch and 1 or 0) + watches[index] = emptyWatch + server:send("200 OK\n") + else + server:send("400 Bad Request\n") + end + elseif command == "RUN" then + server:send("200 OK\n") + + local ev, vars, file, line, idx_watch = coroyield() + eval_env = vars + if ev == events.BREAK then + server:send("202 Paused " .. file .. " " .. tostring(line) .. "\n") + elseif ev == events.WATCH then + server:send("203 Paused " .. file .. " " .. tostring(line) .. " " .. tostring(idx_watch) .. "\n") + elseif ev == events.RESTART then + -- nothing to do + else + server:send("401 Error in Execution " .. tostring(#file) .. "\n") + server:send(file) + end + elseif command == "STEP" then + server:send("200 OK\n") + step_into = true + + local ev, vars, file, line, idx_watch = coroyield() + eval_env = vars + if ev == events.BREAK then + server:send("202 Paused " .. file .. " " .. tostring(line) .. "\n") + elseif ev == events.WATCH then + server:send("203 Paused " .. file .. " " .. tostring(line) .. " " .. tostring(idx_watch) .. "\n") + elseif ev == events.RESTART then + -- nothing to do + else + server:send("401 Error in Execution " .. tostring(#file) .. "\n") + server:send(file) + end + elseif command == "OVER" or command == "OUT" then + server:send("200 OK\n") + step_over = true + + -- OVER and OUT are very similar except for + -- the stack level value at which to stop + if command == "OUT" then step_level = stack_level - 1 + else step_level = stack_level end + + local ev, vars, file, line, idx_watch = coroyield() + eval_env = vars + if ev == events.BREAK then + server:send("202 Paused " .. file .. " " .. tostring(line) .. "\n") + elseif ev == events.WATCH then + server:send("203 Paused " .. file .. " " .. tostring(line) .. " " .. tostring(idx_watch) .. "\n") + elseif ev == events.RESTART then + -- nothing to do + else + server:send("401 Error in Execution " .. tostring(#file) .. "\n") + server:send(file) + end + elseif command == "BASEDIR" then + local _, _, dir = string.find(line, "^[A-Z]+%s+(.+)%s*$") + if dir then + basedir = iscasepreserving and string.lower(dir) or dir + -- reset cached source as it may change with basedir + lastsource = nil + server:send("200 OK\n") + else + server:send("400 Bad Request\n") + end + elseif command == "SUSPEND" then + -- do nothing; it already fulfilled its role + elseif command == "DONE" then + coroyield("done") + return -- done with all the debugging + elseif command == "STACK" then + -- first check if we can execute the stack command + -- as it requires yielding back to debug_hook it cannot be executed + -- if we have not seen the hook yet as happens after start(). + -- in this case we simply return an empty result + local vars, ev = {} + if seen_hook then + ev, vars = coroyield("stack") + end + if ev and ev ~= events.STACK then + server:send("401 Error in Execution " .. tostring(#vars) .. "\n") + server:send(vars) + else + local params = string.match(line, "--%s*(%b{})%s*$") + local pfunc = params and loadstring("return "..params) -- use internal function + params = pfunc and pfunc() + params = (type(params) == "table" and params or {}) + if params.nocode == nil then params.nocode = true end + if params.sparse == nil then params.sparse = false end + -- take into account additional levels for the stack frames and data management + if tonumber(params.maxlevel) then params.maxlevel = tonumber(params.maxlevel)+4 end + + local ok, res = pcall(mobdebug.dump, vars, params) + if ok then + server:send("200 OK " .. tostring(res) .. "\n") + else + server:send("401 Error in Execution " .. tostring(#res) .. "\n") + server:send(res) + end + end + elseif command == "OUTPUT" then + local _, _, stream, mode = string.find(line, "^[A-Z]+%s+(%w+)%s+([dcr])%s*$") + if stream and mode and stream == "stdout" then + -- assign "print" in the global environment + local default = mode == 'd' + genv.print = default and iobase.print or corowrap(function() + -- wrapping into coroutine.wrap protects this function from + -- being stepped through in the debugger. + -- don't use vararg (...) as it adds a reference for its values, + -- which may affect how they are garbage collected + while true do + local tbl = {coroutine.yield()} + if mode == 'c' then iobase.print(unpack(tbl)) end + for n = 1, #tbl do + tbl[n] = select(2, pcall(mobdebug.line, tbl[n], {nocode = true, comment = false})) end + local file = table.concat(tbl, "\t").."\n" + server:send("204 Output " .. stream .. " " .. tostring(#file) .. "\n" .. file) + end + end) + if not default then genv.print() end -- "fake" print to start printing loop + server:send("200 OK\n") + else + server:send("400 Bad Request\n") + end + elseif command == "EXIT" then + server:send("200 OK\n") + coroyield("exit") + else + server:send("400 Bad Request\n") + end + end +end + +local function output(stream, data) + if server then return server:send("204 Output "..stream.." "..tostring(#data).."\n"..data) end +end + +local function connect(controller_host, controller_port) + local sock, err = socket.tcp() + if not sock then return nil, err end + + if sock.settimeout then sock:settimeout(mobdebug.connecttimeout) end + local res, err = sock:connect(controller_host, tostring(controller_port)) + if sock.settimeout then sock:settimeout() end + + if not res then return nil, err end + return sock +end + +local lasthost, lastport + +-- Starts a debug session by connecting to a controller +local function start(controller_host, controller_port) + -- only one debugging session can be run (as there is only one debug hook) + if isrunning() then return end + + lasthost = controller_host or lasthost + lastport = controller_port or lastport + + controller_host = lasthost or "localhost" + controller_port = lastport or mobdebug.port + + local err + server, err = mobdebug.connect(controller_host, controller_port) + if server then + -- correct stack depth which already has some calls on it + -- so it doesn't go into negative when those calls return + -- as this breaks subsequence checks in stack_depth(). + -- start from 16th frame, which is sufficiently large for this check. + stack_level = stack_depth(16) + + -- provide our own traceback function to report errors remotely + -- but only under Lua 5.1/LuaJIT as it's not called under Lua 5.2+ + -- (http://lua-users.org/lists/lua-l/2016-05/msg00297.html) + local function f() return function()end end + if f() ~= f() then -- Lua 5.1 or LuaJIT + local dtraceback = debug.traceback + debug.traceback = function (...) + if select('#', ...) >= 1 then + local thr, err, lvl = ... + if type(thr) ~= 'thread' then err, lvl = thr, err end + local trace = dtraceback(err, (lvl or 1)+1) + if genv.print == iobase.print then -- no remote redirect + return trace + else + genv.print(trace) -- report the error remotely + return -- don't report locally to avoid double reporting + end + end + -- direct call to debug.traceback: return the original. + -- debug.traceback(nil, level) doesn't work in Lua 5.1 + -- (http://lua-users.org/lists/lua-l/2011-06/msg00574.html), so + -- simply remove first frame from the stack trace + local tb = dtraceback("", 2) -- skip debugger frames + -- if the string is returned, then remove the first new line as it's not needed + return type(tb) == "string" and tb:gsub("^\n","") or tb + end + end + coro_debugger = corocreate(debugger_loop) + debug.sethook(debug_hook, HOOKMASK) + seen_hook = nil -- reset in case the last start() call was refused + step_into = true -- start with step command + return true + else + print(("Could not connect to %s:%s: %s") + :format(controller_host, controller_port, err or "unknown error")) + end +end + +local function controller(controller_host, controller_port, scratchpad) + -- only one debugging session can be run (as there is only one debug hook) + if isrunning() then return end + + lasthost = controller_host or lasthost + lastport = controller_port or lastport + + controller_host = lasthost or "localhost" + controller_port = lastport or mobdebug.port + + local exitonerror = not scratchpad + local err + server, err = mobdebug.connect(controller_host, controller_port) + if server then + local function report(trace, err) + local msg = err .. "\n" .. trace + server:send("401 Error in Execution " .. tostring(#msg) .. "\n") + server:send(msg) + return err + end + + seen_hook = true -- allow to accept all commands + coro_debugger = corocreate(debugger_loop) + + while true do + step_into = true -- start with step command + abort = false -- reset abort flag from the previous loop + if scratchpad then checkcount = mobdebug.checkcount end -- force suspend right away + + coro_debugee = corocreate(debugee) + debug.sethook(coro_debugee, debug_hook, HOOKMASK) + local status, err = cororesume(coro_debugee, unpack(arg or {})) + + -- was there an error or is the script done? + -- 'abort' state is allowed here; ignore it + if abort then + if tostring(abort) == 'exit' then break end + else + if status then -- normal execution is done + break + elseif err and not string.find(tostring(err), deferror) then + -- report the error back + -- err is not necessarily a string, so convert to string to report + report(debug.traceback(coro_debugee), tostring(err)) + if exitonerror then break end + -- check if the debugging is done (coro_debugger is nil) + if not coro_debugger then break end + -- resume once more to clear the response the debugger wants to send + -- need to use capture_vars(0) to capture only two (default) level, + -- as even though there is controller() call, because of the tail call, + -- the caller may not exist for it; + -- This is not entirely safe as the user may see the local + -- variable from console, but they will be reset anyway. + -- This functionality is used when scratchpad is paused to + -- gain access to remote console to modify global variables. + local status, err = cororesume(coro_debugger, events.RESTART, capture_vars(0)) + if not status or status and err == "exit" then break end + end + end + end + else + print(("Could not connect to %s:%s: %s") + :format(controller_host, controller_port, err or "unknown error")) + return false + end + return true +end + +local function scratchpad(controller_host, controller_port) + return controller(controller_host, controller_port, true) +end + +local function loop(controller_host, controller_port) + return controller(controller_host, controller_port, false) +end + +local function on() + if not (isrunning() and server) then return end + + -- main is set to true under Lua5.2 for the "main" chunk. + -- Lua5.1 returns co as `nil` in that case. + local co, main = coroutine.running() + if main then co = nil end + if co then + coroutines[co] = true + debug.sethook(co, debug_hook, HOOKMASK) + else + if jit then coroutines.main = true end + debug.sethook(debug_hook, HOOKMASK) + end +end + +local function off() + if not (isrunning() and server) then return end + + -- main is set to true under Lua5.2 for the "main" chunk. + -- Lua5.1 returns co as `nil` in that case. + local co, main = coroutine.running() + if main then co = nil end + + -- don't remove coroutine hook under LuaJIT as there is only one (global) hook + if co then + coroutines[co] = false + if not jit then debug.sethook(co) end + else + if jit then coroutines.main = false end + if not jit then debug.sethook() end + end + + -- check if there is any thread that is still being debugged under LuaJIT; + -- if not, turn the debugging off + if jit then + local remove = true + for _, debugged in pairs(coroutines) do + if debugged then remove = false; break end + end + if remove then debug.sethook() end + end +end + +-- Handles server debugging commands +local function handle(params, client, options) + -- when `options.verbose` is not provided, use normal `print`; verbose output can be + -- disabled (`options.verbose == false`) or redirected (`options.verbose == function()...end`) + local verbose = not options or options.verbose ~= nil and options.verbose + local print = verbose and (type(verbose) == "function" and verbose or print) or function() end + local file, line, watch_idx + local _, _, command = string.find(params, "^([a-z]+)") + if command == "run" or command == "step" or command == "out" + or command == "over" or command == "exit" then + client:send(string.upper(command) .. "\n") + client:receive() -- this should consume the first '200 OK' response + while true do + local done = true + local breakpoint = client:receive() + if not breakpoint then + print("Program finished") + return nil, nil, false + end + local _, _, status = string.find(breakpoint, "^(%d+)") + if status == "200" then + -- don't need to do anything + elseif status == "202" then + _, _, file, line = string.find(breakpoint, "^202 Paused%s+(.-)%s+(%d+)%s*$") + if file and line then + print("Paused at file " .. file .. " line " .. line) + end + elseif status == "203" then + _, _, file, line, watch_idx = string.find(breakpoint, "^203 Paused%s+(.-)%s+(%d+)%s+(%d+)%s*$") + if file and line and watch_idx then + print("Paused at file " .. file .. " line " .. line .. " (watch expression " .. watch_idx .. ": [" .. watches[watch_idx] .. "])") + end + elseif status == "204" then + local _, _, stream, size = string.find(breakpoint, "^204 Output (%w+) (%d+)$") + if stream and size then + local size = tonumber(size) + local msg = size > 0 and client:receive(size) or "" + print(msg) + if outputs[stream] then outputs[stream](msg) end + -- this was just the output, so go back reading the response + done = false + end + elseif status == "401" then + local _, _, size = string.find(breakpoint, "^401 Error in Execution (%d+)$") + if size then + local msg = client:receive(tonumber(size)) + print("Error in remote application: " .. msg) + return nil, nil, msg + end + else + print("Unknown error") + return nil, nil, "Debugger error: unexpected response '" .. breakpoint .. "'" + end + if done then break end + end + elseif command == "done" then + client:send(string.upper(command) .. "\n") + -- no response is expected + elseif command == "setb" or command == "asetb" then + _, _, _, file, line = string.find(params, "^([a-z]+)%s+(.-)%s+(%d+)%s*$") + if file and line then + -- if this is a file name, and not a file source + if not file:find('^".*"$') then + file = string.gsub(file, "\\", "/") -- convert slash + file = removebasedir(file, basedir) + end + client:send("SETB " .. file .. " " .. line .. "\n") + if command == "asetb" or client:receive() == "200 OK" then + set_breakpoint(file, line) + else + print("Error: breakpoint not inserted") + end + else + print("Invalid command") + end + elseif command == "setw" then + local _, _, exp = string.find(params, "^[a-z]+%s+(.+)$") + if exp then + client:send("SETW " .. exp .. "\n") + local answer = client:receive() + local _, _, watch_idx = string.find(answer, "^200 OK (%d+)%s*$") + if watch_idx then + watches[watch_idx] = exp + print("Inserted watch exp no. " .. watch_idx) + else + local _, _, size = string.find(answer, "^401 Error in Expression (%d+)$") + if size then + local err = client:receive(tonumber(size)):gsub(".-:%d+:%s*","") + print("Error: watch expression not set: " .. err) + else + print("Error: watch expression not set") + end + end + else + print("Invalid command") + end + elseif command == "delb" or command == "adelb" then + _, _, _, file, line = string.find(params, "^([a-z]+)%s+(.-)%s+(%d+)%s*$") + if file and line then + -- if this is a file name, and not a file source + if not file:find('^".*"$') then + file = string.gsub(file, "\\", "/") -- convert slash + file = removebasedir(file, basedir) + end + client:send("DELB " .. file .. " " .. line .. "\n") + if command == "adelb" or client:receive() == "200 OK" then + remove_breakpoint(file, line) + else + print("Error: breakpoint not removed") + end + else + print("Invalid command") + end + elseif command == "delallb" then + local file, line = "*", 0 + client:send("DELB " .. file .. " " .. tostring(line) .. "\n") + if client:receive() == "200 OK" then + remove_breakpoint(file, line) + else + print("Error: all breakpoints not removed") + end + elseif command == "delw" then + local _, _, index = string.find(params, "^[a-z]+%s+(%d+)%s*$") + if index then + client:send("DELW " .. index .. "\n") + if client:receive() == "200 OK" then + watches[index] = nil + else + print("Error: watch expression not removed") + end + else + print("Invalid command") + end + elseif command == "delallw" then + for index, exp in pairs(watches) do + client:send("DELW " .. index .. "\n") + if client:receive() == "200 OK" then + watches[index] = nil + else + print("Error: watch expression at index " .. index .. " [" .. exp .. "] not removed") + end + end + elseif command == "eval" or command == "exec" + or command == "load" or command == "loadstring" + or command == "reload" then + local _, _, exp = string.find(params, "^[a-z]+%s+(.+)$") + if exp or (command == "reload") then + if command == "eval" or command == "exec" then + exp = (exp:gsub("%-%-%[(=*)%[.-%]%1%]", "") -- remove comments + :gsub("%-%-.-\n", " ") -- remove line comments + :gsub("\n", " ")) -- convert new lines + if command == "eval" then exp = "return " .. exp end + client:send("EXEC " .. exp .. "\n") + elseif command == "reload" then + client:send("LOAD 0 -\n") + elseif command == "loadstring" then + local _, _, _, file, lines = string.find(exp, "^([\"'])(.-)%1%s+(.+)") + if not file then + _, _, file, lines = string.find(exp, "^(%S+)%s+(.+)") + end + client:send("LOAD " .. tostring(#lines) .. " " .. file .. "\n") + client:send(lines) + else + local file = io.open(exp, "r") + if not file and pcall(require, "winapi") then + -- if file is not open and winapi is there, try with a short path; + -- this may be needed for unicode paths on windows + winapi.set_encoding(winapi.CP_UTF8) + local shortp = winapi.short_path(exp) + file = shortp and io.open(shortp, "r") + end + if not file then return nil, nil, "Cannot open file " .. exp end + -- read the file and remove the shebang line as it causes a compilation error + local lines = file:read("*all"):gsub("^#!.-\n", "\n") + file:close() + + local file = string.gsub(exp, "\\", "/") -- convert slash + file = removebasedir(file, basedir) + client:send("LOAD " .. tostring(#lines) .. " " .. file .. "\n") + if #lines > 0 then client:send(lines) end + end + while true do + local params, err = client:receive() + if not params then + return nil, nil, "Debugger connection " .. (err or "error") + end + local done = true + local _, _, status, len = string.find(params, "^(%d+).-%s+(%d+)%s*$") + if status == "200" then + len = tonumber(len) + if len > 0 then + local status, res + local str = client:receive(len) + -- handle serialized table with results + local func, err = loadstring(str) + if func then + status, res = pcall(func) + if not status then err = res + elseif type(res) ~= "table" then + err = "received "..type(res).." instead of expected 'table'" + end + end + if err then + print("Error in processing results: " .. err) + return nil, nil, "Error in processing results: " .. err + end + print(unpack(res)) + return res[1], res + end + elseif status == "201" then + _, _, file, line = string.find(params, "^201 Started%s+(.-)%s+(%d+)%s*$") + elseif status == "202" or params == "200 OK" then + -- do nothing; this only happens when RE/LOAD command gets the response + -- that was for the original command that was aborted + elseif status == "204" then + local _, _, stream, size = string.find(params, "^204 Output (%w+) (%d+)$") + if stream and size then + local size = tonumber(size) + local msg = size > 0 and client:receive(size) or "" + print(msg) + if outputs[stream] then outputs[stream](msg) end + -- this was just the output, so go back reading the response + done = false + end + elseif status == "401" then + len = tonumber(len) + local res = client:receive(len) + print("Error in expression: " .. res) + return nil, nil, res + else + print("Unknown error") + return nil, nil, "Debugger error: unexpected response after EXEC/LOAD '" .. params .. "'" + end + if done then break end + end + else + print("Invalid command") + end + elseif command == "listb" then + for l, v in pairs(breakpoints) do + for f in pairs(v) do + print(f .. ": " .. l) + end + end + elseif command == "listw" then + for i, v in pairs(watches) do + print("Watch exp. " .. i .. ": " .. v) + end + elseif command == "suspend" then + client:send("SUSPEND\n") + elseif command == "stack" then + local opts = string.match(params, "^[a-z]+%s+(.+)$") + client:send("STACK" .. (opts and " "..opts or "") .."\n") + local resp = client:receive() + local _, _, status, res = string.find(resp, "^(%d+)%s+%w+%s+(.+)%s*$") + if status == "200" then + local func, err = loadstring(res) + if func == nil then + print("Error in stack information: " .. err) + return nil, nil, err + end + local ok, stack = pcall(func) + if not ok then + print("Error in stack information: " .. stack) + return nil, nil, stack + end + for _,frame in ipairs(stack) do + print(mobdebug.line(frame[1], {comment = false})) + end + return stack + elseif status == "401" then + local _, _, len = string.find(resp, "%s+(%d+)%s*$") + len = tonumber(len) + local res = len > 0 and client:receive(len) or "Invalid stack information." + print("Error in expression: " .. res) + return nil, nil, res + else + print("Unknown error") + return nil, nil, "Debugger error: unexpected response after STACK" + end + elseif command == "output" then + local _, _, stream, mode = string.find(params, "^[a-z]+%s+(%w+)%s+([dcr])%s*$") + if stream and mode then + client:send("OUTPUT "..stream.." "..mode.."\n") + local resp, err = client:receive() + if not resp then + print("Unknown error: "..err) + return nil, nil, "Debugger connection error: "..err + end + local _, _, status = string.find(resp, "^(%d+)%s+%w+%s*$") + if status == "200" then + print("Stream "..stream.." redirected") + outputs[stream] = type(options) == 'table' and options.handler or nil + -- the client knows when she is doing, so install the handler + elseif type(options) == 'table' and options.handler then + outputs[stream] = options.handler + else + print("Unknown error") + return nil, nil, "Debugger error: can't redirect "..stream + end + else + print("Invalid command") + end + elseif command == "basedir" then + local _, _, dir = string.find(params, "^[a-z]+%s+(.+)$") + if dir then + dir = string.gsub(dir, "\\", "/") -- convert slash + if not string.find(dir, "/$") then dir = dir .. "/" end + + local remdir = dir:match("\t(.+)") + if remdir then dir = dir:gsub("/?\t.+", "/") end + basedir = dir + + client:send("BASEDIR "..(remdir or dir).."\n") + local resp, err = client:receive() + if not resp then + print("Unknown error: "..err) + return nil, nil, "Debugger connection error: "..err + end + local _, _, status = string.find(resp, "^(%d+)%s+%w+%s*$") + if status == "200" then + print("New base directory is " .. basedir) + else + print("Unknown error") + return nil, nil, "Debugger error: unexpected response after BASEDIR" + end + else + print(basedir) + end + elseif command == "help" then + print("setb -- sets a breakpoint") + print("delb -- removes a breakpoint") + print("delallb -- removes all breakpoints") + print("setw -- adds a new watch expression") + print("delw -- removes the watch expression at index") + print("delallw -- removes all watch expressions") + print("run -- runs until next breakpoint") + print("step -- runs until next line, stepping into function calls") + print("over -- runs until next line, stepping over function calls") + print("out -- runs until line after returning from current function") + print("listb -- lists breakpoints") + print("listw -- lists watch expressions") + print("eval -- evaluates expression on the current context and returns its value") + print("exec -- executes statement on the current context") + print("load -- loads a local file for debugging") + print("reload -- restarts the current debugging session") + print("stack -- reports stack trace") + print("output stdout -- capture and redirect io stream (default|copy|redirect)") + print("basedir [] -- sets the base path of the remote application, or shows the current one") + print("done -- stops the debugger and continues application execution") + print("exit -- exits debugger and the application") + else + local _, _, spaces = string.find(params, "^(%s*)$") + if spaces then + return nil, nil, "Empty command" + else + print("Invalid command") + return nil, nil, "Invalid command" + end + end + return file, line +end + +-- Starts debugging server +local function listen(host, port) + host = host or "*" + port = port or mobdebug.port + + local socket = require "socket" + + print("Lua Remote Debugger") + print("Run the program you wish to debug") + + local server = socket.bind(host, port) + local client = server:accept() + + client:send("STEP\n") + client:receive() + + local breakpoint = client:receive() + local _, _, file, line = string.find(breakpoint, "^202 Paused%s+(.-)%s+(%d+)%s*$") + if file and line then + print("Paused at file " .. file ) + print("Type 'help' for commands") + else + local _, _, size = string.find(breakpoint, "^401 Error in Execution (%d+)%s*$") + if size then + print("Error in remote application: ") + print(client:receive(size)) + end + end + + while true do + io.write("> ") + local file, line, err = handle(io.read("*line"), client) + if not file and err == false then break end -- completed debugging + end + + client:close() +end + +local cocreate +local function coro() + if cocreate then return end -- only set once + cocreate = cocreate or coroutine.create + coroutine.create = function(f, ...) + return cocreate(function(...) + mobdebug.on() + return f(...) + end, ...) + end +end + +local moconew +local function moai() + if moconew then return end -- only set once + moconew = moconew or (MOAICoroutine and MOAICoroutine.new) + if not moconew then return end + MOAICoroutine.new = function(...) + local thread = moconew(...) + -- need to support both thread.run and getmetatable(thread).run, which + -- was used in earlier MOAI versions + local mt = thread.run and thread or getmetatable(thread) + local patched = mt.run + mt.run = function(self, f, ...) + return patched(self, function(...) + mobdebug.on() + return f(...) + end, ...) + end + return thread + end +end + +-- make public functions available +mobdebug.setbreakpoint = set_breakpoint +mobdebug.removebreakpoint = remove_breakpoint +mobdebug.listen = listen +mobdebug.loop = loop +mobdebug.scratchpad = scratchpad +mobdebug.handle = handle +mobdebug.connect = connect +mobdebug.start = start +mobdebug.on = on +mobdebug.off = off +mobdebug.moai = moai +mobdebug.coro = coro +mobdebug.done = done +mobdebug.pause = function() step_into = true end +mobdebug.yield = nil -- callback +mobdebug.output = output +mobdebug.onexit = os and os.exit or done +mobdebug.onscratch = nil -- callback +mobdebug.basedir = function(b) if b then basedir = b end return basedir end + +return mobdebug diff --git a/src/share/lib/cjson.so b/src/share/lib/cjson.so new file mode 100644 index 0000000..af89ca2 Binary files /dev/null and b/src/share/lib/cjson.so differ diff --git a/src/share/pgmoon.lua b/src/share/pgmoon.lua new file mode 100644 index 0000000..a0761d5 --- /dev/null +++ b/src/share/pgmoon.lua @@ -0,0 +1 @@ +return require('pgmoon.init') diff --git a/src/share/pgmoon/arrays.lua b/src/share/pgmoon/arrays.lua new file mode 100644 index 0000000..3ca5771 --- /dev/null +++ b/src/share/pgmoon/arrays.lua @@ -0,0 +1,196 @@ +local OIDS = { + boolean = 1000, + number = 1231, + string = 1009, + array_json = 199, + array_jsonb = 3807 +} +local is_array +is_array = function(oid) + for k, v in pairs(OIDS) do + if v == oid then + return true + end + end + return false +end +local PostgresArray +do + local _class_0 + local _base_0 = { } + _base_0.__index = _base_0 + _class_0 = setmetatable({ + __init = function() end, + __base = _base_0, + __name = "PostgresArray" + }, { + __index = _base_0, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + local self = _class_0 + self.__base.pgmoon_serialize = function(v, pg) + local escaped + do + local _accum_0 = { } + local _len_0 = 1 + for _index_0 = 1, #v do + local val = v[_index_0] + if val == pg.NULL then + _accum_0[_len_0] = "NULL" + else + local _exp_0 = type(val) + if "number" == _exp_0 then + _accum_0[_len_0] = tostring(val) + elseif "string" == _exp_0 then + _accum_0[_len_0] = '"' .. val:gsub('"', [[\"]]) .. '"' + elseif "boolean" == _exp_0 then + _accum_0[_len_0] = val and "t" or "f" + elseif "table" == _exp_0 then + local _oid, _value + do + local v_mt = getmetatable(val) + if v_mt then + if v_mt.pgmoon_serialize then + _oid, _value = v_mt.pgmoon_serialize(val, pg) + end + end + end + if _oid then + if is_array(_oid) then + _accum_0[_len_0] = _value + else + _accum_0[_len_0] = '"' .. _value:gsub('"', [[\"]]) .. '"' + end + else + return nil, "table does not implement pgmoon_serialize, can't serialize" + end + end + end + _len_0 = _len_0 + 1 + end + escaped = _accum_0 + end + local type_oid = 0 + for _index_0 = 1, #v do + local _continue_0 = false + repeat + do + local val = v[_index_0] + if val == pg.NULL then + _continue_0 = true + break + end + type_oid = OIDS[type(val)] or type_oid + break + end + _continue_0 = true + until true + if not _continue_0 then + break + end + end + return type_oid, "{" .. tostring(table.concat(escaped, ",")) .. "}" + end + PostgresArray = _class_0 +end +getmetatable(PostgresArray).__call = function(self, t) + return setmetatable(t, self.__base) +end +local default_escape_literal = nil +local insert, concat +do + local _obj_0 = table + insert, concat = _obj_0.insert, _obj_0.concat +end +local encode_array +do + local append_buffer + append_buffer = function(escape_literal, buffer, values) + for _index_0 = 1, #values do + local item = values[_index_0] + if type(item) == "table" and not getmetatable(item) then + insert(buffer, "[") + append_buffer(escape_literal, buffer, item) + buffer[#buffer] = "]" + insert(buffer, ",") + else + insert(buffer, escape_literal(item)) + insert(buffer, ",") + end + end + return buffer + end + encode_array = function(tbl, escape_literal) + escape_literal = escape_literal or default_escape_literal + if not (escape_literal) then + local Postgres + Postgres = require("pgmoon").Postgres + default_escape_literal = function(v) + return Postgres.escape_literal(nil, v) + end + escape_literal = default_escape_literal + end + local buffer = append_buffer(escape_literal, { + "ARRAY[" + }, tbl) + if buffer[#buffer] == "," then + buffer[#buffer] = "]" + else + insert(buffer, "]") + end + return concat(buffer) + end +end +local convert_values +convert_values = function(array, fn, pg) + for idx, v in ipairs(array) do + if type(v) == "table" then + convert_values(v, fn) + else + if v == "NULL" then + array[idx] = pg.NULL + elseif fn then + array[idx] = fn(v) + else + array[idx] = v + end + end + end + return array +end +local decode_array +do + local P, R, S, V, Ct, C, Cs + do + local _obj_0 = require("lpeg") + P, R, S, V, Ct, C, Cs = _obj_0.P, _obj_0.R, _obj_0.S, _obj_0.V, _obj_0.Ct, _obj_0.C, _obj_0.Cs + end + local g = P({ + "array", + array = Ct(V("open") * (V("value") * (P(",") * V("value")) ^ 0) ^ -1 * V("close")), + value = V("invalid_char") + V("string") + V("array") + V("literal"), + string = P('"') * Cs((P([[\\]]) / [[\]] + P([[\"]]) / [["]] + (P(1) - P('"'))) ^ 0) * P('"'), + literal = C((P(1) - S("},")) ^ 1), + invalid_char = S(" \t\r\n") / function() + return error("got unexpected whitespace") + end, + open = P("{"), + delim = P(","), + close = P("}") + }) + decode_array = function(str, convert_fn, pg) + local out = (assert(g:match(str), "failed to parse postgresql array")) + setmetatable(out, PostgresArray.__base) + return convert_values(out, convert_fn, (pg or require("pgmoon").Postgres)) + end +end +return { + encode_array = encode_array, + decode_array = decode_array, + PostgresArray = PostgresArray +} diff --git a/src/share/pgmoon/bit.lua b/src/share/pgmoon/bit.lua new file mode 100644 index 0000000..1b99ceb --- /dev/null +++ b/src/share/pgmoon/bit.lua @@ -0,0 +1,67 @@ +local rshift, lshift, band, bxor +local load_code +load_code = function(str) + local sent = false + return pcall(load(function() + if sent then + return nil + end + sent = true + return str + end)) +end +local ok +ok, band = load_code([[ return function(a,b) + a = a & b + if a > 0x7FFFFFFF then + -- extend the sign bit + a = ~0xFFFFFFFF | a + end + return a + end +]]) +if ok then + local _ + _, bxor = load_code([[ return function(a,b) + a = a ~ b + if a > 0x7FFFFFFF then + -- extend the sign bit + a = ~0xFFFFFFFF | a + end + return a + end + ]]) + _, lshift = load_code([[ return function(x,y) + -- limit to 32-bit shifts + y = y % 32 + x = x << y + if x > 0x7FFFFFFF then + -- extend the sign bit + x = ~0xFFFFFFFF | x + end + return x + end + ]]) + _, rshift = load_code([[ return function(x,y) + y = y % 32 + -- truncate to 32-bit before applying shift + x = x & 0xFFFFFFFF + x = x >> y + if x > 0x7FFFFFFF then + x = ~0xFFFFFFFF | x + end + return x + end + ]]) +else + do + local _obj_0 = require("bit") + rshift, lshift, band, bxor = _obj_0.rshift, _obj_0.lshift, _obj_0.band, _obj_0.bxor + end +end +return { + rshift = rshift, + lshift = lshift, + band = band, + bxor = bxor +} diff --git a/src/share/pgmoon/cqueues.lua b/src/share/pgmoon/cqueues.lua new file mode 100644 index 0000000..a43ebbf --- /dev/null +++ b/src/share/pgmoon/cqueues.lua @@ -0,0 +1,75 @@ +local flatten +flatten = require("pgmoon.util").flatten +local CqueuesSocket +do + local _class_0 + local _base_0 = { + connect = function(self, host, port, opts) + local socket = require("cqueues.socket") + local errno = require("cqueues.errno") + self.sock = socket.connect({ + host = host, + port = port + }) + if self.timeout then + self.sock:settimeout(self.timeout) + end + self.sock:setmode("bn", "bn") + local success, err = self.sock:connect() + if not (success) then + return nil, errno.strerror(err) + end + return true + end, + starttls = function(self, ...) + return self.sock:starttls(...) + end, + getpeercertificate = function(self) + local ssl = assert(self.sock:checktls()) + return assert(ssl:getPeerCertificate(), "no peer certificate available") + end, + send = function(self, ...) + return self.sock:write(flatten(...)) + end, + receive = function(self, ...) + return self.sock:read(...) + end, + close = function(self) + return self.sock:close() + end, + settimeout = function(self, t) + if t then + t = t / 1000 + end + if self.sock then + return self.sock:settimeout(t) + else + self.timeout = t + end + end, + getreusedtimes = function(self) + return 0 + end, + setkeepalive = function(self) + return error("You attempted to call setkeepalive on a cqueues.socket. This method is only available for the ngx cosocket API for releasing a socket back into the connection pool") + end + } + _base_0.__index = _base_0 + _class_0 = setmetatable({ + __init = function() end, + __base = _base_0, + __name = "CqueuesSocket" + }, { + __index = _base_0, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + CqueuesSocket = _class_0 +end +return { + CqueuesSocket = CqueuesSocket +} diff --git a/src/share/pgmoon/crypto.lua b/src/share/pgmoon/crypto.lua new file mode 100644 index 0000000..7aecf11 --- /dev/null +++ b/src/share/pgmoon/crypto.lua @@ -0,0 +1,178 @@ +local md5 +if ngx then + md5 = ngx.md5 +elseif pcall(function() + return require("openssl.digest") +end) then + local openssl_digest = require("openssl.digest") + local hex_char + hex_char = function(c) + return string.format("%02x", string.byte(c)) + end + local hex + hex = function(str) + return (str:gsub(".", hex_char)) + end + md5 = function(str) + return hex(openssl_digest.new("md5"):final(str)) + end +elseif pcall(function() + return require("crypto") +end) then + local crypto = require("crypto") + md5 = function(str) + return crypto.digest("md5", str) + end +else + md5 = function() + return error("Either luaossl (recommended) or LuaCrypto is required to calculate md5") + end +end +local hmac_sha256 +if pcall(function() + return require("openssl.hmac") +end) then + hmac_sha256 = function(key, str) + local openssl_hmac = require("openssl.hmac") + local hmac = assert(openssl_hmac.new(key, "sha256")) + hmac:update(str) + return assert(hmac:final()) + end +elseif pcall(function() + return require("resty.openssl.hmac") +end) then + hmac_sha256 = function(key, str) + local openssl_hmac = require("resty.openssl.hmac") + local hmac = assert(openssl_hmac.new(key, "sha256")) + hmac:update(str) + return assert(hmac:final()) + end +else + hmac_sha256 = function() + return error("Either luaossl or resty.openssl is required to calculate hmac sha256 digest") + end +end +local digest_sha256 +if pcall(function() + return require("openssl.digest") +end) then + digest_sha256 = function(str) + local digest = assert(require("openssl.digest").new("sha256")) + digest:update(str) + return assert(digest:final()) + end +elseif pcall(function() + return require("resty.sha256") +end) then + digest_sha256 = function(str) + local digest = assert(require("resty.sha256"):new()) + digest:update(str) + return assert(digest:final()) + end +elseif pcall(function() + return require("resty.openssl.digest") +end) then + digest_sha256 = function(str) + local digest = assert(require("resty.openssl.digest").new("sha256")) + digest:update(str) + return assert(digest:final()) + end +else + digest_sha256 = function() + return error("Either luaossl or resty.openssl is required to calculate sha256 digest") + end +end +local kdf_derive_sha256 +if pcall(function() + return require("openssl.kdf") +end) then + kdf_derive_sha256 = function(str, salt, i) + local openssl_kdf = require("openssl.kdf") + local decode_base64 + decode_base64 = require("pgmoon.util").decode_base64 + salt = decode_base64(salt) + local key, err = openssl_kdf.derive({ + type = "PBKDF2", + md = "sha256", + salt = salt, + iter = i, + pass = str, + outlen = 32 + }) + if not (key) then + return nil, "failed to derive pbkdf2 key: " .. tostring(err) + end + return key + end +elseif pcall(function() + return require("resty.openssl.kdf") +end) then + kdf_derive_sha256 = function(str, salt, i) + local openssl_kdf = require("resty.openssl.kdf") + local decode_base64 + decode_base64 = require("pgmoon.util").decode_base64 + salt = decode_base64(salt) + local key, err = openssl_kdf.derive({ + type = openssl_kdf.PBKDF2, + md = "sha256", + salt = salt, + pbkdf2_iter = i, + pass = str, + outlen = 32 + }) + if not (key) then + return nil, "failed to derive pbkdf2 key: " .. tostring(err) + end + return key + end +else + kdf_derive_sha256 = function() + return error("Either luaossl or resty.openssl is required to derive pbkdf2 key") + end +end +local random_bytes +if pcall(function() + return require("openssl.rand") +end) then + random_bytes = require("openssl.rand").bytes +elseif pcall(function() + return require("resty.random") +end) then + random_bytes = require("resty.random").bytes +elseif pcall(function() + return require("resty.openssl.rand") +end) then + random_bytes = require("resty.openssl.rand").bytes +else + random_bytes = function() + return error("Either luaossl or resty.openssl is required to generate random bytes") + end +end +local x509_digest +if pcall(function() + return require("openssl.x509") +end) then + local x509 = require("openssl.x509") + x509_digest = function(pem, hash_type) + return x509.new(pem, "PEM"):digest(hash_type, "s") + end +elseif pcall(function() + return require("resty.openssl.x509") +end) then + local x509 = require("resty.openssl.x509") + x509_digest = function(pem, hash_type) + return x509.new(pem, "PEM"):digest(hash_type) + end +else + x509_digest = function() + return error("Either luaossl or resty.openssl is required to calculate x509 digest") + end +end +return { + md5 = md5, + hmac_sha256 = hmac_sha256, + digest_sha256 = digest_sha256, + kdf_derive_sha256 = kdf_derive_sha256, + random_bytes = random_bytes, + x509_digest = x509_digest +} diff --git a/src/share/pgmoon/hstore.lua b/src/share/pgmoon/hstore.lua new file mode 100644 index 0000000..b5cb4dd --- /dev/null +++ b/src/share/pgmoon/hstore.lua @@ -0,0 +1,72 @@ +local PostgresHstore +do + local _class_0 + local _base_0 = { } + _base_0.__index = _base_0 + _class_0 = setmetatable({ + __init = function() end, + __base = _base_0, + __name = "PostgresHstore" + }, { + __index = _base_0, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + PostgresHstore = _class_0 +end +getmetatable(PostgresHstore).__call = function(self, t) + return setmetatable(t, self.__base) +end +local encode_hstore +do + encode_hstore = function(tbl, escape_literal) + if not (escape_literal) then + local Postgres + Postgres = require("pgmoon").Postgres + local default_escape_literal + default_escape_literal = function(v) + return Postgres.escape_literal(nil, v) + end + escape_literal = default_escape_literal + end + local buffer = { } + for k, v in pairs(tbl) do + table.insert(buffer, '"' .. k .. '"=>"' .. v .. '"') + end + return escape_literal(table.concat(buffer, ", ")) + end +end +local decode_hstore +do + local P, R, S, V, Ct, C, Cs, Cg, Cf + do + local _obj_0 = require("lpeg") + P, R, S, V, Ct, C, Cs, Cg, Cf = _obj_0.P, _obj_0.R, _obj_0.S, _obj_0.V, _obj_0.Ct, _obj_0.C, _obj_0.Cs, _obj_0.Cg, _obj_0.Cf + end + local g = P({ + "hstore", + hstore = Cf(Ct("") * (V("pair") * (V("delim") * V("pair")) ^ 0) ^ -1, rawset) * -1, + pair = Cg(V("value") * "=>" * (V("value") + V("null"))), + value = V("invalid_char") + V("string"), + string = P('"') * Cs((P([[\\]]) / [[\]] + P([[\"]]) / [["]] + (P(1) - P('"'))) ^ 0) * P('"'), + null = C('NULL'), + invalid_char = S(" \t\r\n") / function() + return error("got unexpected whitespace") + end, + delim = P(", ") + }) + decode_hstore = function(str, convert_fn) + local out = (assert(g:match(str), "failed to parse postgresql hstore")) + setmetatable(out, PostgresHstore.__base) + return out + end +end +return { + encode_hstore = encode_hstore, + decode_hstore = decode_hstore, + PostgresHstore = PostgresHstore +} diff --git a/src/share/pgmoon/init.lua b/src/share/pgmoon/init.lua new file mode 100644 index 0000000..89ac0e2 --- /dev/null +++ b/src/share/pgmoon/init.lua @@ -0,0 +1,1118 @@ +local socket = require("pgmoon.socket") +local insert +insert = table.insert +local rshift, lshift, band, bxor +do + local _obj_0 = require("pgmoon.bit") + rshift, lshift, band, bxor = _obj_0.rshift, _obj_0.lshift, _obj_0.band, _obj_0.bxor +end +local unpack = table.unpack or unpack +local DEBUG = false +local VERSION = "1.16.0" +local _len +_len = function(thing, t) + if t == nil then + t = type(thing) + end + local _exp_0 = t + if "string" == _exp_0 then + return #thing + elseif "table" == _exp_0 then + local l = 0 + for _index_0 = 1, #thing do + local inner = thing[_index_0] + local inner_t = type(inner) + if inner_t == "string" then + l = l + #inner + else + l = l + _len(inner, inner_t) + end + end + return l + else + return error("don't know how to calculate length of " .. tostring(t)) + end +end +local _debug_msg +_debug_msg = function(str) + return require("moon").dump((function() + local _accum_0 = { } + local _len_0 = 1 + for p in str:gmatch("[^%z]+") do + _accum_0[_len_0] = p + _len_0 = _len_0 + 1 + end + return _accum_0 + end)()) +end +local flipped +flipped = function(t) + local keys + do + local _accum_0 = { } + local _len_0 = 1 + for k in pairs(t) do + _accum_0[_len_0] = k + _len_0 = _len_0 + 1 + end + keys = _accum_0 + end + for _index_0 = 1, #keys do + local key = keys[_index_0] + t[t[key]] = key + end + return t +end +local MSG_TYPE_F = flipped({ + password = "p", + query = "Q", + parse = "P", + bind = "B", + describe = "D", + execute = "E", + close = "C", + sync = "S", + terminate = "X" +}) +local MSG_TYPE_B = flipped({ + auth = "R", + parameter_status = "S", + backend_key = "K", + ready_for_query = "Z", + parse_complete = "1", + bind_complete = "2", + close_complete = "3", + row_description = "T", + data_row = "D", + command_complete = "C", + error = "E", + notice = "N", + notification = "A" +}) +local ERROR_TYPES = flipped({ + severity = "S", + code = "C", + message = "M", + position = "P", + detail = "D", + schema = "s", + table = "t", + constraint = "n" +}) +local PG_TYPES = { + [16] = "boolean", + [17] = "bytea", + [20] = "number", + [21] = "number", + [23] = "number", + [700] = "number", + [701] = "number", + [1700] = "number", + [114] = "json", + [3802] = "json", + [1000] = "array_boolean", + [1005] = "array_number", + [1007] = "array_number", + [1016] = "array_number", + [1021] = "array_number", + [1022] = "array_number", + [1041] = "array_inet", + [1231] = "array_number", + [1009] = "array_string", + [1015] = "array_string", + [1002] = "array_string", + [1014] = "array_string", + [2951] = "array_string", + [199] = "array_json", + [3807] = "array_json" +} +local NULL = "\0" +local tobool +tobool = function(str) + return str == "t" +end +local Postgres +do + local _class_0 + local _base_0 = { + convert_null = false, + NULL = { + "NULL" + }, + PG_TYPES = PG_TYPES, + default_config = { + application_name = "pgmoon", + user = "postgres", + host = "127.0.0.1", + port = "5432", + ssl = false + }, + type_serializers = { + string = function(self, v) + return 25, v + end, + boolean = function(self, v) + return 16, v and "t" or "f" + end, + number = function(self, v) + return 1700, tostring(v) + end, + table = function(self, v) + do + local v_mt = getmetatable(v) + if v_mt then + if v_mt.pgmoon_serialize then + return v_mt.pgmoon_serialize(v, self) + end + end + end + return nil, "table does not implement pgmoon_serialize, can't serialize" + end + }, + type_deserializers = { + json = function(self, val, name) + local decode_json + decode_json = require("pgmoon.json").decode_json + return decode_json(val) + end, + bytea = function(self, val, name) + return self:decode_bytea(val) + end, + array_inet = function(self, val, name) + local decode_array + decode_array = require("pgmoon.arrays").decode_array + return decode_array(val, nil, self) + end, + array_boolean = function(self, val, name) + local decode_array + decode_array = require("pgmoon.arrays").decode_array + return decode_array(val, tobool, self) + end, + array_number = function(self, val, name) + local decode_array + decode_array = require("pgmoon.arrays").decode_array + return decode_array(val, tonumber, self) + end, + array_string = function(self, val, name) + local decode_array + decode_array = require("pgmoon.arrays").decode_array + return decode_array(val, nil, self) + end, + array_json = function(self, val, name) + local decode_array + decode_array = require("pgmoon.arrays").decode_array + local decode_json + decode_json = require("pgmoon.json").decode_json + return decode_array(val, decode_json, self) + end, + hstore = function(self, val, name) + local decode_hstore + decode_hstore = require("pgmoon.hstore").decode_hstore + return decode_hstore(val) + end + }, + set_type_oid = function(self, a, b) + print("pgmoon: WARNING: set_type_oid is deprecated for set_type_deserializer") + return self:set_type_deserializer(a, b) + end, + set_type_deserializer = function(self, oid, name, deserializer) + if not (rawget(self, "PG_TYPES")) then + do + local _tbl_0 = { } + for k, v in pairs(self.PG_TYPES) do + _tbl_0[k] = v + end + self.PG_TYPES = _tbl_0 + end + end + self.PG_TYPES[assert(tonumber(oid))] = name + if deserializer then + if not (rawget(self, "type_deserializers")) then + do + local _tbl_0 = { } + for k, v in pairs(self.type_deserializers) do + _tbl_0[k] = v + end + self.type_deserializers = _tbl_0 + end + end + self.type_deserializers[name] = deserializer + end + end, + setup_hstore = function(self) + local res = unpack(self:query("SELECT oid FROM pg_type WHERE typname = 'hstore'")) + assert(res, "hstore oid not found") + return self:set_type_deserializer(tonumber(res.oid), "hstore") + end, + connect = function(self) + local connect_opts + local _exp_0 = self.sock_type + if "nginx" == _exp_0 then + connect_opts = { + pool = self.config.pool_name or tostring(self.config.host) .. ":" .. tostring(self.config.port) .. ":" .. tostring(self.config.database) .. ":" .. tostring(self.config.user), + pool_size = self.config.pool_size, + backlog = self.config.backlog + } + end + local ok, err = self.sock:connect(self.config.host, self.config.port, connect_opts) + if not (ok) then + return nil, err + end + if self.sock:getreusedtimes() == 0 then + if self.config.ssl then + local success + success, err = self:send_ssl_message() + if not (success) then + return nil, err + end + end + local success + success, err = self:send_startup_message() + if not (success) then + return nil, err + end + success, err = self:auth() + if not (success) then + return nil, err + end + success, err = self:wait_until_ready() + if not (success) then + return nil, err + end + end + return true + end, + settimeout = function(self, ...) + return self.sock:settimeout(...) + end, + disconnect = function(self) + self:send_message(MSG_TYPE_F.terminate, { }) + return self.sock:close() + end, + keepalive = function(self, ...) + return self.sock:setkeepalive(...) + end, + create_cqueues_openssl_context = function(self) + if not (self.config.ssl_verify ~= nil or self.config.cert or self.config.key or self.config.ssl_version) then + return + end + local ssl_context = require("openssl.ssl.context") + local out = ssl_context.new(self.config.ssl_version) + if self.config.ssl_verify == true then + out:setVerify(ssl_context.VERIFY_PEER) + end + if self.config.ssl_verify == false then + out:setVerify(ssl_context.VERIFY_NONE) + end + if self.config.cert then + out:setCertificate(self.config.cert) + end + if self.config.key then + out:setPrivateKey(self.config.key) + end + return out + end, + create_luasec_opts = function(self) + return { + key = self.config.key, + certificate = self.config.cert, + cafile = self.config.cafile, + protocol = self.config.ssl_version, + verify = self.config.ssl_verify and "peer" or "none" + } + end, + auth = function(self) + local t, msg = self:receive_message() + if not (t) then + return nil, msg + end + if not (MSG_TYPE_B.auth == t) then + if MSG_TYPE_B.error == t then + return nil, self:parse_error(msg) + end + error("unexpected message during auth: " .. tostring(t)) + end + local auth_type = self:decode_int(msg, 4) + local _exp_0 = auth_type + if 0 == _exp_0 then + return true + elseif 3 == _exp_0 then + return self:cleartext_auth(msg) + elseif 5 == _exp_0 then + return self:md5_auth(msg) + elseif 10 == _exp_0 then + return self:scram_sha_256_auth(msg) + else + return error("don't know how to auth: " .. tostring(auth_type)) + end + end, + cleartext_auth = function(self, msg) + assert(self.config.password, "the database is requesting a password for authentication but you did not provide a password") + self:send_message(MSG_TYPE_F.password, { + self.config.password, + NULL + }) + return self:check_auth() + end, + scram_sha_256_auth = function(self, msg) + assert(self.config.password, "the database is requesting a password for authentication but you did not provide a password") + local random_bytes, x509_digest + do + local _obj_0 = require("pgmoon.crypto") + random_bytes, x509_digest = _obj_0.random_bytes, _obj_0.x509_digest + end + local rand_bytes = assert(random_bytes(18)) + local encode_base64 + encode_base64 = require("pgmoon.util").encode_base64 + local c_nonce = encode_base64(rand_bytes) + local nonce = "r=" .. c_nonce + local saslname = "" + local username = "n=" .. saslname + local client_first_message_bare = username .. "," .. nonce + local plus = false + local bare = false + if msg:match("SCRAM%-SHA%-256%-PLUS") then + plus = true + elseif msg:match("SCRAM%-SHA%-256") then + bare = true + else + error("unsupported SCRAM mechanism name: " .. tostring(msg)) + end + local gs2_cbind_flag + local gs2_header + local cbind_input + local mechanism_name + if bare then + gs2_cbind_flag = "n" + gs2_header = gs2_cbind_flag .. ",," + cbind_input = gs2_header + mechanism_name = "SCRAM-SHA-256" .. NULL + elseif plus then + local cb_name = "tls-server-end-point" + gs2_cbind_flag = "p=" .. cb_name + gs2_header = gs2_cbind_flag .. ",," + mechanism_name = "SCRAM-SHA-256-PLUS" .. NULL + local cbind_data + do + if self.sock_type == "cqueues" then + local openssl_x509 = self.sock:getpeercertificate() + cbind_data = openssl_x509:digest("sha256", "s") + else + local pem, signature + if self.sock_type == "nginx" then + local ssl = require("resty.openssl.ssl").from_socket(self.sock) + local server_cert = ssl:get_peer_certificate() + pem, signature = server_cert:to_PEM(), server_cert:get_signature_name() + else + local server_cert = self.sock:getpeercertificate() + pem, signature = server_cert:pem(), server_cert:getsignaturename() + end + signature = signature:lower() + local _, with_sig + _, _, with_sig = signature:find("%-with%-(.*)") + if with_sig then + signature = with_sig + end + if signature:match("^md5") or signature:match("^sha1") or signature:match("sha1$") then + signature = "sha256" + end + cbind_data = assert(x509_digest(pem, signature)) + end + end + cbind_input = gs2_header .. cbind_data + end + local client_first_message = gs2_header .. client_first_message_bare + self:send_message(MSG_TYPE_F.password, { + mechanism_name, + self:encode_int(#client_first_message), + client_first_message + }) + local t + t, msg = self:receive_message() + if not (t) then + return nil, msg + end + local server_first_message = msg:sub(5) + local int32 = self:decode_int(msg, 4) + if int32 == nil or int32 ~= 11 then + return nil, "server_first_message error: " .. msg + end + local channel_binding = "c=" .. encode_base64(cbind_input) + nonce = server_first_message:match("([^,]+)") + if not (nonce) then + return nil, "malformed server message (nonce)" + end + local client_final_message_without_proof = channel_binding .. "," .. nonce + local xor + xor = function(a, b) + local result + do + local _accum_0 = { } + local _len_0 = 1 + for i = 1, #a do + local x = a:byte(i) + local y = b:byte(i) + if not (x and y) then + return nil + end + local _value_0 = string.char(bxor(x, y)) + _accum_0[_len_0] = _value_0 + _len_0 = _len_0 + 1 + end + result = _accum_0 + end + return table.concat(result) + end + local salt = server_first_message:match(",s=([^,]+)") + if not (salt) then + return nil, "malformed server message (salt)" + end + local i = server_first_message:match(",i=(.+)") + if not (i) then + return nil, "malformed server message (iteraton count)" + end + if tonumber(i) < 4096 then + return nil, "the iteration-count sent by the server is less than 4096" + end + local kdf_derive_sha256, hmac_sha256, digest_sha256 + do + local _obj_0 = require("pgmoon.crypto") + kdf_derive_sha256, hmac_sha256, digest_sha256 = _obj_0.kdf_derive_sha256, _obj_0.hmac_sha256, _obj_0.digest_sha256 + end + local salted_password, err = kdf_derive_sha256(self.config.password, salt, tonumber(i)) + if not (salted_password) then + return nil, err + end + local client_key + client_key, err = hmac_sha256(salted_password, "Client Key") + if not (client_key) then + return nil, err + end + local stored_key + stored_key, err = digest_sha256(client_key) + if not (stored_key) then + return nil, err + end + local auth_message = tostring(client_first_message_bare) .. "," .. tostring(server_first_message) .. "," .. tostring(client_final_message_without_proof) + local client_signature + client_signature, err = hmac_sha256(stored_key, auth_message) + if not (client_signature) then + return nil, err + end + local proof = xor(client_key, client_signature) + if not (proof) then + return nil, "failed to generate the client proof" + end + local client_final_message = tostring(client_final_message_without_proof) .. ",p=" .. tostring(encode_base64(proof)) + self:send_message(MSG_TYPE_F.password, { + client_final_message + }) + t, msg = self:receive_message() + if not (t) then + return nil, msg + end + local server_key + server_key, err = hmac_sha256(salted_password, "Server Key") + if not (server_key) then + return nil, err + end + local server_signature + server_signature, err = hmac_sha256(server_key, auth_message) + if not (server_signature) then + return nil, err + end + server_signature = encode_base64(server_signature) + local sent_server_signature = msg:match("v=([^,]+)") + if server_signature ~= sent_server_signature then + return nil, "authentication exchange unsuccessful" + end + return self:check_auth() + end, + md5_auth = function(self, msg) + local md5 + md5 = require("pgmoon.crypto").md5 + local salt = msg:sub(5, 8) + assert(self.config.password, "missing password, required for connect") + self:send_message(MSG_TYPE_F.password, { + "md5", + md5(md5(self.config.password .. self.config.user) .. salt), + NULL + }) + return self:check_auth() + end, + check_auth = function(self) + local t, msg = self:receive_message() + if not (t) then + return nil, msg + end + local _exp_0 = t + if MSG_TYPE_B.error == _exp_0 then + return nil, self:parse_error(msg) + elseif MSG_TYPE_B.auth == _exp_0 then + return true + else + return error("unknown response from auth") + end + end, + query = function(self, q, ...) + if select("#", ...) > 0 then + return self:extended_query(q, ...) + else + return self:simple_query(q) + end + end, + simple_query = function(self, q) + if q:find(NULL) then + return nil, "invalid null byte in query" + end + self:send_message(MSG_TYPE_F.query, { + q, + NULL + }) + return self:receive_query_result() + end, + extended_query = function(self, q, ...) + if q:find(NULL) then + return nil, "invalid null byte in query" + end + local num_params = select("#", ...) + local parse_data = { + NULL, + q, + NULL, + self:encode_int(num_params, 2) + } + local bind_data = { + NULL, + NULL, + self:encode_int(0, 2), + self:encode_int(num_params, 2) + } + for idx = 1, num_params do + local v = select(idx, ...) + if v == self.NULL or v == nil then + insert(parse_data, self:encode_int(0)) + insert(bind_data, self:encode_int(-1)) + else + local v_type = type(v) + local type_oid, value_bytes + do + local fn = self.type_serializers[v_type] + if fn then + local _oid, _value_or_err, _third = fn(self, v) + if _oid == nil then + local full_error = "pgmoon: param " .. tostring(idx) .. ": " .. tostring(_value_or_err or "failed to serialize type: " .. tostring(v_type)) + return nil, full_error + end + if _third ~= nil then + return nil, "pgmoon: param " .. tostring(idx) .. ": please do not return a third value from serializer function, we may use this value in the future for binary formats" + end + type_oid, value_bytes = _oid, _value_or_err + else + type_oid, value_bytes = 0, tostring(v) + end + end + insert(parse_data, self:encode_int(type_oid)) + insert(bind_data, self:encode_int(#value_bytes)) + insert(bind_data, value_bytes) + end + end + insert(bind_data, self:encode_int(0, 2)) + self:send_messages({ + { + MSG_TYPE_F.parse, + parse_data + }, + { + MSG_TYPE_F.bind, + bind_data + }, + { + MSG_TYPE_F.describe, + { + "P", + NULL + } + }, + { + MSG_TYPE_F.execute, + { + NULL, + self:encode_int(0) + } + }, + { + MSG_TYPE_F.close, + { + "P", + NULL + } + }, + { + MSG_TYPE_F.sync, + { } + } + }) + return self:receive_query_result() + end, + receive_query_result = function(self) + local row_desc, data_rows, command_complete, err_msg + local result, notifications, notices + local num_queries = 0 + while true do + local t, msg = self:receive_message() + if not (t) then + return nil, msg + end + local _exp_0 = t + if MSG_TYPE_B.data_row == _exp_0 then + if not (data_rows) then + data_rows = { } + end + insert(data_rows, msg) + elseif MSG_TYPE_B.row_description == _exp_0 then + row_desc = msg + elseif MSG_TYPE_B.error == _exp_0 then + err_msg = msg + elseif MSG_TYPE_B.notice == _exp_0 then + if not (notices) then + notices = { } + end + insert(notices, (self:parse_error(msg))) + elseif MSG_TYPE_B.command_complete == _exp_0 then + command_complete = msg + local next_result = self:format_query_result(row_desc, data_rows, command_complete) + num_queries = num_queries + 1 + if num_queries == 1 then + result = next_result + elseif num_queries == 2 then + result = { + result, + next_result + } + else + insert(result, next_result) + end + row_desc, data_rows, command_complete = nil + elseif MSG_TYPE_B.ready_for_query == _exp_0 then + break + elseif MSG_TYPE_B.notification == _exp_0 then + if not (notifications) then + notifications = { } + end + insert(notifications, self:parse_notification(msg)) + elseif MSG_TYPE_B.parse_complete == _exp_0 or MSG_TYPE_B.bind_complete == _exp_0 or MSG_TYPE_B.close_complete == _exp_0 then + local _ = nil + else + if DEBUG then + print("Unhandled message in query result: " .. tostring(t)) + end + end + end + if err_msg then + return nil, self:parse_error(err_msg), result, num_queries, notifications, notices + end + return result, num_queries, notifications, notices + end, + wait_for_notification = function(self) + while true do + local t, msg = self:receive_message() + if not (t) then + return nil, msg + end + local _exp_0 = t + if MSG_TYPE_B.notification == _exp_0 then + return self:parse_notification(msg) + end + end + end, + format_query_result = function(self, row_desc, data_rows, command_complete) + local command, affected_rows + if command_complete then + command = command_complete:match("^%w+") + affected_rows = tonumber(command_complete:match("(%d+)%z$")) + end + if row_desc then + if not (data_rows) then + return { } + end + local fields = self:parse_row_desc(row_desc) + local num_rows = #data_rows + for i = 1, num_rows do + data_rows[i] = self:parse_data_row(data_rows[i], fields) + end + if affected_rows and command ~= "SELECT" then + data_rows.affected_rows = affected_rows + end + return data_rows + end + if affected_rows then + return { + affected_rows = affected_rows + } + else + return true + end + end, + parse_error = function(self, err_msg) + local severity, message, detail, position + local error_data = { } + local offset = 1 + while offset <= #err_msg do + local t = err_msg:sub(offset, offset) + local str = err_msg:match("[^%z]+", offset + 1) + if not (str) then + break + end + offset = offset + (2 + #str) + do + local field = ERROR_TYPES[t] + if field then + error_data[field] = str + end + end + local _exp_0 = t + if ERROR_TYPES.severity == _exp_0 then + severity = str + elseif ERROR_TYPES.message == _exp_0 then + message = str + elseif ERROR_TYPES.position == _exp_0 then + position = str + elseif ERROR_TYPES.detail == _exp_0 then + detail = str + end + end + local msg = tostring(severity) .. ": " .. tostring(message) + if position then + msg = tostring(msg) .. " (" .. tostring(position) .. ")" + end + if detail then + msg = tostring(msg) .. "\n" .. tostring(detail) + end + return msg, error_data + end, + parse_row_desc = function(self, row_desc) + local num_fields = self:decode_int(row_desc:sub(1, 2)) + local offset = 3 + local fields + do + local _accum_0 = { } + local _len_0 = 1 + for i = 1, num_fields do + local name = row_desc:match("[^%z]+", offset) + offset = offset + #name + 1 + local data_type = self:decode_int(row_desc:sub(offset + 6, offset + 6 + 3)) + data_type = self.PG_TYPES[data_type] or "string" + local format = self:decode_int(row_desc:sub(offset + 16, offset + 16 + 1)) + assert(0 == format, "don't know how to handle format") + offset = offset + 18 + local _value_0 = { + name, + data_type + } + _accum_0[_len_0] = _value_0 + _len_0 = _len_0 + 1 + end + fields = _accum_0 + end + return fields + end, + parse_data_row = function(self, data_row, fields) + local num_fields = self:decode_int(data_row:sub(1, 2)) + local out = { } + local offset = 3 + for i = 1, num_fields do + local _continue_0 = false + repeat + local field = fields[i] + if not (field) then + _continue_0 = true + break + end + local field_name, field_type + field_name, field_type = field[1], field[2] + local len = self:decode_int(data_row:sub(offset, offset + 3)) + offset = offset + 4 + if len < 0 then + if self.convert_null then + out[field_name] = self.NULL + end + _continue_0 = true + break + end + local value = data_row:sub(offset, offset + len - 1) + offset = offset + len + local _exp_0 = field_type + if "number" == _exp_0 then + value = tonumber(value) + elseif "boolean" == _exp_0 then + value = value == "t" + elseif "string" == _exp_0 then + local _ = nil + else + do + local fn = self.type_deserializers[field_type] + if fn then + value = fn(self, value, field_type) + end + end + end + out[field_name] = value + _continue_0 = true + until true + if not _continue_0 then + break + end + end + return out + end, + parse_notification = function(self, msg) + local pid = self:decode_int(msg:sub(1, 4)) + local offset = 4 + local channel, payload = msg:match("^([^%z]+)%z([^%z]*)%z$", offset + 1) + if not (channel) then + error("parse_notification: failed to parse notification") + end + return { + operation = "notification", + pid = pid, + channel = channel, + payload = payload + } + end, + wait_until_ready = function(self) + while true do + local t, msg = self:receive_message() + if not (t) then + return nil, msg + end + if MSG_TYPE_B.error == t then + return nil, self:parse_error(msg) + end + if MSG_TYPE_B.ready_for_query == t then + break + end + end + return true + end, + receive_message = function(self) + local prefix, err = self.sock:receive(5) + if not (prefix) then + return nil, "receive_message: failed to get type: " .. tostring(err) + end + local t = prefix:sub(1, 1) + local len = prefix:sub(2) + len = self:decode_int(len) + len = len - 4 + local msg = self.sock:receive(len) + return t, msg + end, + send_startup_message = function(self) + assert(self.config.user, "missing user for connect") + assert(self.config.database, "missing database for connect") + local data = { + self:encode_int(196608), + "user", + NULL, + self.config.user, + NULL, + "database", + NULL, + self.config.database, + NULL, + "application_name", + NULL, + self.config.application_name, + NULL, + NULL + } + return self.sock:send({ + self:encode_int(_len(data) + 4), + data + }) + end, + send_ssl_message = function(self) + local success, err = self.sock:send({ + self:encode_int(8), + self:encode_int(80877103) + }) + if not (success) then + return nil, err + end + local t + t, err = self.sock:receive(1) + if not (t) then + return nil, err + end + if t == MSG_TYPE_B.parameter_status then + local _exp_0 = self.sock_type + if "nginx" == _exp_0 then + return self.sock:sslhandshake(false, nil, self.config.ssl_verify) + elseif "luasocket" == _exp_0 then + return self.sock:sslhandshake(self.config.luasec_opts or self:create_luasec_opts()) + elseif "cqueues" == _exp_0 then + return self.sock:starttls(self.config.cqueues_openssl_context or self:create_cqueues_openssl_context()) + else + return error("don't know how to do ssl handshake for socket type: " .. tostring(self.sock_type)) + end + elseif t == MSG_TYPE_B.error or self.config.ssl_required then + return nil, "the server does not support SSL connections" + else + return true + end + end, + send_messages = function(self, messages) + local data + do + local _accum_0 = { } + local _len_0 = 1 + for _index_0 = 1, #messages do + local _des_0 = messages[_index_0] + local message_type, message_data + message_type, message_data = _des_0[1], _des_0[2] + local len = _len(message_data) + len = len + 4 + local _value_0 = { + message_type, + self:encode_int(len), + message_data + } + _accum_0[_len_0] = _value_0 + _len_0 = _len_0 + 1 + end + data = _accum_0 + end + return self.sock:send(data) + end, + send_message = function(self, t, data, len) + if len == nil then + len = _len(data) + end + len = len + 4 + return self.sock:send({ + t, + self:encode_int(len), + data + }) + end, + decode_int = function(self, str, bytes) + if bytes == nil then + bytes = #str + end + local _exp_0 = str + if "\0\0" == _exp_0 or "\0\0\0\0" == _exp_0 then + return 0 + end + local _exp_1 = bytes + if 4 == _exp_1 then + local d, c, b, a = str:byte(1, 4) + return a + lshift(b, 8) + lshift(c, 16) + lshift(d, 24) + elseif 2 == _exp_1 then + local b, a = str:byte(1, 2) + return a + lshift(b, 8) + else + return error("don't know how to decode " .. tostring(bytes) .. " byte(s)") + end + end, + encode_int = function(self, n, bytes) + if bytes == nil then + bytes = 4 + end + if n == 0 then + if bytes == 2 then + return "\0\0" + end + if bytes == 4 then + return "\0\0\0\0" + end + end + local _exp_0 = bytes + if 4 == _exp_0 then + local a = band(n, 0xff) + local b = band(rshift(n, 8), 0xff) + local c = band(rshift(n, 16), 0xff) + local d = band(rshift(n, 24), 0xff) + return string.char(d, c, b, a) + elseif 2 == _exp_0 then + local a = band(n, 0xff) + local b = band(rshift(n, 8), 0xff) + return string.char(b, a) + else + return error("don't know how to encode " .. tostring(bytes) .. " byte(s)") + end + end, + decode_bytea = function(self, str) + if str:sub(1, 2) == '\\x' then + return str:sub(3):gsub('..', function(hex) + return string.char(tonumber(hex, 16)) + end) + else + return str:gsub('\\(%d%d%d)', function(oct) + return string.char(tonumber(oct, 8)) + end) + end + end, + encode_bytea = function(self, str) + return string.format("E'\\\\x%s'", str:gsub('.', function(byte) + return string.format('%02x', string.byte(byte)) + end)) + end, + escape_identifier = function(self, ident) + return '"' .. (tostring(ident):gsub('"', '""')) .. '"' + end, + escape_literal = function(self, val) + if val == (self and self.NULL or Postgres.NULL) then + return "NULL" + end + local _exp_0 = type(val) + if "number" == _exp_0 then + return tostring(val) + elseif "string" == _exp_0 then + return "'" .. tostring((val:gsub("'", "''"))) .. "'" + elseif "boolean" == _exp_0 then + return val and "TRUE" or "FALSE" + end + return error("don't know how to escape value: " .. tostring(val)) + end, + __tostring = function(self) + return "" + end + } + _base_0.__index = _base_0 + _class_0 = setmetatable({ + __init = function(self, _config) + if _config == nil then + _config = { } + end + self._config = _config + self.config = setmetatable({ }, { + __index = function(t, key) + local value = self._config[key] + if value == nil then + return self.default_config[key] + else + return value + end + end + }) + self.convert_null = self.config.convert_null + self.sock, self.sock_type = socket.new(self.config.socket_type) + end, + __base = _base_0, + __name = "Postgres" + }, { + __index = _base_0, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + Postgres = _class_0 +end +return { + Postgres = Postgres, + new = Postgres, + VERSION = VERSION +} diff --git a/src/share/pgmoon/json.lua b/src/share/pgmoon/json.lua new file mode 100644 index 0000000..c7872cb --- /dev/null +++ b/src/share/pgmoon/json.lua @@ -0,0 +1,25 @@ +local default_escape_literal = nil +local encode_json +encode_json = function(tbl, escape_literal) + escape_literal = escape_literal or default_escape_literal + local json = require("cjson") + if not (escape_literal) then + local Postgres + Postgres = require("pgmoon").Postgres + default_escape_literal = function(v) + return Postgres.escape_literal(nil, v) + end + escape_literal = default_escape_literal + end + local enc = json.encode(tbl) + return escape_literal(enc) +end +local decode_json +decode_json = function(str) + local json = require("cjson") + return json.decode(str) +end +return { + encode_json = encode_json, + decode_json = decode_json +} diff --git a/src/share/pgmoon/socket.lua b/src/share/pgmoon/socket.lua new file mode 100644 index 0000000..8db4fc1 --- /dev/null +++ b/src/share/pgmoon/socket.lua @@ -0,0 +1,109 @@ +local create_luasocket +do + local flatten + flatten = require("pgmoon.util").flatten + local proxy_mt = { + __index = function(self, key) + local sock = self.sock + local original = sock[key] + if type(original) == "function" then + local fn + fn = function(_, ...) + return original(sock, ...) + end + self[key] = fn + return fn + else + return original + end + end + } + local method_overrides + method_overrides = { + send = function(self, ...) + return self.sock:send(flatten(...)) + end, + settimeout = function(self, t) + if t then + t = t / 1000 + end + return self.sock:settimeout(t) + end, + setkeepalive = function(self) + return error("You attempted to call setkeepalive on a LuaSocket socket. This method is only available for the ngx cosocket API for releasing a socket back into the connection pool") + end, + getreusedtimes = function(self, t) + return 0 + end, + sslhandshake = function(self, opts) + if opts == nil then + opts = { } + end + local ssl = require("ssl") + local params = { + mode = "client", + protocol = "any", + verify = "none", + options = { + "all", + "no_sslv2", + "no_sslv3", + "no_tlsv1" + } + } + for k, v in pairs(opts) do + params[k] = v + end + local sec_sock, err = ssl.wrap(self.sock, params) + if not (sec_sock) then + return false, err + end + local success + success, err = sec_sock:dohandshake() + if not (success) then + return false, err + end + for k, v in pairs(self) do + if not method_overrides[k] and type(v) == "function" then + self[k] = nil + end + end + self.sock = sec_sock + return true + end + } + create_luasocket = function(...) + local socket = require("socket") + local proxy = { + sock = socket.tcp(...) + } + for k, v in pairs(method_overrides) do + proxy[k] = v + end + return setmetatable(proxy, proxy_mt) + end +end +return { + create_luasocket = create_luasocket, + new = function(socket_type) + if socket_type == nil then + if ngx and ngx.get_phase() ~= "init" then + socket_type = "nginx" + else + socket_type = "luasocket" + end + end + local socket + local _exp_0 = socket_type + if "nginx" == _exp_0 then + socket = ngx.socket.tcp() + elseif "luasocket" == _exp_0 then + socket = create_luasocket() + elseif "cqueues" == _exp_0 then + socket = require("pgmoon.cqueues").CqueuesSocket() + else + socket = error("got unknown or unset socket type: " .. tostring(socket_type)) + end + return socket, socket_type + end +} diff --git a/src/share/pgmoon/util.lua b/src/share/pgmoon/util.lua new file mode 100644 index 0000000..9752404 --- /dev/null +++ b/src/share/pgmoon/util.lua @@ -0,0 +1,46 @@ +local flatten +do + local __flatten + __flatten = function(t, buffer) + local _exp_0 = type(t) + if "string" == _exp_0 then + buffer[#buffer + 1] = t + elseif "number" == _exp_0 then + buffer[#buffer + 1] = tostring(t) + elseif "table" == _exp_0 then + for _index_0 = 1, #t do + local thing = t[_index_0] + __flatten(thing, buffer) + end + end + end + flatten = function(t) + local buffer = { } + __flatten(t, buffer) + return table.concat(buffer) + end +end +local encode_base64, decode_base64 +if ngx then + do + local _obj_0 = ngx + encode_base64, decode_base64 = _obj_0.encode_base64, _obj_0.decode_base64 + end +else + local b64, unb64 + do + local _obj_0 = require("mime") + b64, unb64 = _obj_0.b64, _obj_0.unb64 + end + encode_base64 = function(...) + return (b64(...)) + end + decode_base64 = function(...) + return (unb64(...)) + end +end +return { + flatten = flatten, + encode_base64 = encode_base64, + decode_base64 = decode_base64 +} diff --git a/src/share/redis.lua b/src/share/redis.lua new file mode 100644 index 0000000..7ced7fb --- /dev/null +++ b/src/share/redis.lua @@ -0,0 +1,62 @@ +local redis = require("loadlib.resty_redis") +local db_config = require('config.database') + +local _M = setmetatable({}, {__index=function(self, key) + local red = redis:new() + local ok,err = red:connect(db_config.redis.host, db_config.redis.port) + if not ok then + ngx.log(ngx.ERR, err) + end + if key == 'red' then + return red + end +end}) + +function _M:set(key, value, time) + local ok, err = self.red:set(key, value) + if not ok then + return false, "redis failed to set data: " .. err + end + if time then + ok,err = self.red:expire(key, time) -- default expire time is seconds + if not ok then + return false,err + end + end + return true +end + +function _M:get(key) + local value = self.red:get(key) + if value == ngx.null then + return nil + else + return value + end +end + +function _M:del(key) + return self.red:del(key) +end + +function _M:expire(key, time) + local ok,err = self.red:expire(key, time) -- default time is seconds + if not ok then + return false,err + end + return true +end + +function _M:incr(key) + local ok,err = self.red:incr(key) + if not ok then + return false, err + end + return true +end + +function _M:ttl(key) + return self.red:ttl(key) +end + +return _M \ No newline at end of file diff --git a/src/share/resty_redis.lua b/src/share/resty_redis.lua new file mode 100644 index 0000000..f4b2dfd --- /dev/null +++ b/src/share/resty_redis.lua @@ -0,0 +1,410 @@ +-- Copyright (C) + +local sub = string.sub +local byte = string.byte +local tcp = ngx.socket.tcp +local null = ngx.null +local type = type +local pairs = pairs +local unpack = unpack +local setmetatable = setmetatable +local tonumber = tonumber +local tostring = tostring +local rawget = rawget +--local error = error + +local ok, new_tab = pcall(require, "table.new") +if not ok or type(new_tab) ~= "function" then + new_tab = function (narr, nrec) return {} end +end + +local _M = new_tab(0, 54) + +_M._VERSION = '0.26' + +local common_cmds = { + "get", "set", "mget", "mset", + "del", "incr", "decr", -- Strings + "llen", "lindex", "lpop", "lpush", + "lrange", "linsert", -- Lists + "hexists", "hget", "hset", "hmget", + --[[ "hmset", ]] "hdel", -- Hashes + "smembers", "sismember", "sadd", "srem", + "sdiff", "sinter", "sunion", -- Sets + "zrange", "zrangebyscore", "zrank", "zadd", + "zrem", "zincrby", -- Sorted Sets + "auth", "eval", "expire", "script", + "sort" -- Others +} + +local sub_commands = { + "subscribe", "psubscribe" +} + +local unsub_commands = { + "unsubscribe", "punsubscribe" +} + +local mt = { __index = _M } + +function _M.new(self) + local sock, err = tcp() + if not sock then + return nil, err + end + return setmetatable({ _sock = sock, _subscribed = false }, mt) +end + +function _M.set_timeout(self, timeout) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + return sock:settimeout(timeout) +end + +function _M.connect(self, ...) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + self._subscribed = false + + return sock:connect(...) +end + +function _M.set_keepalive(self, ...) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + if rawget(self, "_subscribed") then + return nil, "subscribed state" + end + + return sock:setkeepalive(...) +end + +function _M.get_reused_times(self) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + +local function close(self) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + return sock:close() +end +_M.close = close + +local function _read_reply(self, sock) + local line, err = sock:receive() + if not line then + if err == "timeout" and not rawget(self, "_subscribed") then + sock:close() + end + return nil, err + end + + local prefix = byte(line) + + if prefix == 36 then -- char '$' + -- print("bulk reply") + + local size = tonumber(sub(line, 2)) + if size < 0 then + return null + end + + local data, err = sock:receive(size) + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + + local dummy, err = sock:receive(2) -- ignore CRLF + if not dummy then + return nil, err + end + + return data + + elseif prefix == 43 then -- char '+' + -- print("status reply") + + return sub(line, 2) + + elseif prefix == 42 then -- char '*' + local n = tonumber(sub(line, 2)) + + -- print("multi-bulk reply: ", n) + if n < 0 then + return null + end + + local vals = new_tab(n, 0) + local nvals = 0 + for i = 1, n do + local res, err = _read_reply(self, sock) + if res then + nvals = nvals + 1 + vals[nvals] = res + + elseif res == nil then + return nil, err + + else + -- be a valid redis error value + nvals = nvals + 1 + vals[nvals] = {false, err} + end + end + + return vals + + elseif prefix == 58 then -- char ':' + -- print("integer reply") + return tonumber(sub(line, 2)) + + elseif prefix == 45 then -- char '-' + -- print("error reply: ", n) + + return false, sub(line, 2) + + else + -- when `line` is an empty string, `prefix` will be equal to nil. + return nil, "unknown prefix: \"" .. tostring(prefix) .. "\"" + end +end + +local function _gen_req(args) + local nargs = #args + + local req = new_tab(nargs * 5 + 1, 0) + req[1] = "*" .. nargs .. "\r\n" + local nbits = 2 + + for i = 1, nargs do + local arg = args[i] + if type(arg) ~= "string" then + arg = tostring(arg) + end + + req[nbits] = "$" + req[nbits + 1] = #arg + req[nbits + 2] = "\r\n" + req[nbits + 3] = arg + req[nbits + 4] = "\r\n" + + nbits = nbits + 5 + end + + -- it is much faster to do string concatenation on the C land + -- in real world (large number of strings in the Lua VM) + return req +end + +local function _do_cmd(self, ...) + local args = {...} + + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + local req = _gen_req(args) + + local reqs = rawget(self, "_reqs") + if reqs then + reqs[#reqs + 1] = req + return + end + + -- print("request: ", table.concat(req)) + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + return _read_reply(self, sock) +end + +local function _check_subscribed(self, res) + if type(res) == "table" + and (res[1] == "unsubscribe" or res[1] == "punsubscribe") + and res[3] == 0 + then + self._subscribed = false + end +end + +function _M.read_reply(self) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + if not rawget(self, "_subscribed") then + return nil, "not subscribed" + end + + local res, err = _read_reply(self, sock) + _check_subscribed(self, res) + + return res, err +end + +for i = 1, #common_cmds do + local cmd = common_cmds[i] + + _M[cmd] = + function (self, ...) + return _do_cmd(self, cmd, ...) + end +end + +for i = 1, #sub_commands do + local cmd = sub_commands[i] + + _M[cmd] = + function (self, ...) + self._subscribed = true + return _do_cmd(self, cmd, ...) + end +end + +for i = 1, #unsub_commands do + local cmd = unsub_commands[i] + + _M[cmd] = + function (self, ...) + local res, err = _do_cmd(self, cmd, ...) + _check_subscribed(self, res) + return res, err + end +end + +function _M.hmset(self, hashname, ...) + if select('#', ...) == 1 then + local t = select(1, ...) + + local n = 0 + for k, v in pairs(t) do + n = n + 2 + end + + local array = new_tab(n, 0) + + local i = 0 + for k, v in pairs(t) do + array[i + 1] = k + array[i + 2] = v + i = i + 2 + end + -- print("key", hashname) + return _do_cmd(self, "hmset", hashname, unpack(array)) + end + + -- backwards compatibility + return _do_cmd(self, "hmset", hashname, ...) +end + +function _M.init_pipeline(self, n) + self._reqs = new_tab(n or 4, 0) +end + +function _M.cancel_pipeline(self) + self._reqs = nil +end + +function _M.commit_pipeline(self) + local reqs = rawget(self, "_reqs") + if not reqs then + return nil, "no pipeline" + end + + self._reqs = nil + + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + local bytes, err = sock:send(reqs) + if not bytes then + return nil, err + end + + local nvals = 0 + local nreqs = #reqs + local vals = new_tab(nreqs, 0) + for i = 1, nreqs do + local res, err = _read_reply(self, sock) + if res then + nvals = nvals + 1 + vals[nvals] = res + + elseif res == nil then + if err == "timeout" then + close(self) + end + return nil, err + + else + -- be a valid redis error value + nvals = nvals + 1 + vals[nvals] = {false, err} + end + end + + return vals +end + +function _M.array_to_hash(self, t) + local n = #t + -- print("n = ", n) + local h = new_tab(0, n / 2) + for i = 1, n, 2 do + h[t[i]] = t[i + 1] + end + return h +end + +-- this method is deperate since we already do lazy method generation. +function _M.add_commands(...) + local cmds = {...} + for i = 1, #cmds do + local cmd = cmds[i] + _M[cmd] = + function (self, ...) + return _do_cmd(self, cmd, ...) + end + end +end + +setmetatable(_M, {__index = function(self, cmd) + local method = + function (self, ...) + return _do_cmd(self, cmd, ...) + end + + -- cache the lazily generated method in our + -- module table + _M[cmd] = method + return method +end}) + +return _M \ No newline at end of file diff --git a/src/test.lua b/src/test.lua new file mode 100644 index 0000000..151bfdf --- /dev/null +++ b/src/test.lua @@ -0,0 +1,22 @@ +require("mobdebug").start("127.0.0.1") + +--用于接收前端数据的对象 +local args = nil +--获取前端的请求方式 并获取传递的参数 +local request_method = ngx.var.request_method +--判断是get请求还是post请求并分别拿出相应的数据 +if "GET" == request_method then + args = ngx.req.get_uri_args() +elseif "POST" == request_method then + ngx.req.read_body() + args = ngx.req.get_post_args() + --兼容请求使用post请求,但是传参以get方式传造成的无法获取到数据的bug + if (args == nil or args.data == null) then + args = ngx.req.get_uri_args() + end +end + +--获取前端传递的name值 +local name = args.name +--响应前端 +ngx.say("hello:"..name) \ No newline at end of file