mirror of
https://github.com/luapower/mysql.git
synced 2025-01-04 07:10:25 +01:00
unimportant
This commit is contained in:
commit
89969e2809
948
mysql_client.lua
Normal file
948
mysql_client.lua
Normal file
@ -0,0 +1,948 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
local tcp = require'sock'.tcp
|
||||
local sha1 = require'sha1'
|
||||
local bit = require'bit'
|
||||
|
||||
local sub = string.sub
|
||||
local strbyte = string.byte
|
||||
local strchar = string.char
|
||||
local strfind = string.find
|
||||
local format = string.format
|
||||
local strrep = string.rep
|
||||
local null = function() end
|
||||
local band = bit.band
|
||||
local bxor = bit.bxor
|
||||
local bor = bit.bor
|
||||
local lshift = bit.lshift
|
||||
local rshift = bit.rshift
|
||||
local tohex = bit.tohex
|
||||
local concat = table.concat
|
||||
local unpack = unpack
|
||||
local setmetatable = setmetatable
|
||||
local error = error
|
||||
local tonumber = tonumber
|
||||
|
||||
local ok, new_tab = pcall(require, 'table.new')
|
||||
if not ok then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
local _M = { _VERSION = '0.21' }
|
||||
|
||||
|
||||
-- constants
|
||||
|
||||
local STATE_CONNECTED = 1
|
||||
local STATE_COMMAND_SENT = 2
|
||||
|
||||
local COM_QUIT = 0x01
|
||||
local COM_QUERY = 0x03
|
||||
local CLIENT_SSL = 0x0800
|
||||
|
||||
local SERVER_MORE_RESULTS_EXISTS = 8
|
||||
|
||||
-- 16MB - 1, the default max allowed packet size used by libmysqlclient
|
||||
local FULL_PACKET_SIZE = 16777215
|
||||
|
||||
-- the following charset map is generated from the following mysql query:
|
||||
-- SELECT CHARACTER_SET_NAME, ID
|
||||
-- FROM information_schema.collations
|
||||
-- WHERE IS_DEFAULT = 'Yes' ORDER BY id;
|
||||
local CHARSET_MAP = {
|
||||
_default = 0,
|
||||
big5 = 1,
|
||||
dec8 = 3,
|
||||
cp850 = 4,
|
||||
hp8 = 6,
|
||||
koi8r = 7,
|
||||
latin1 = 8,
|
||||
latin2 = 9,
|
||||
swe7 = 10,
|
||||
ascii = 11,
|
||||
ujis = 12,
|
||||
sjis = 13,
|
||||
hebrew = 16,
|
||||
tis620 = 18,
|
||||
euckr = 19,
|
||||
koi8u = 22,
|
||||
gb2312 = 24,
|
||||
greek = 25,
|
||||
cp1250 = 26,
|
||||
gbk = 28,
|
||||
latin5 = 30,
|
||||
armscii8 = 32,
|
||||
utf8 = 33,
|
||||
ucs2 = 35,
|
||||
cp866 = 36,
|
||||
keybcs2 = 37,
|
||||
macce = 38,
|
||||
macroman = 39,
|
||||
cp852 = 40,
|
||||
latin7 = 41,
|
||||
utf8mb4 = 45,
|
||||
cp1251 = 51,
|
||||
utf16 = 54,
|
||||
utf16le = 56,
|
||||
cp1256 = 57,
|
||||
cp1257 = 59,
|
||||
utf32 = 60,
|
||||
binary = 63,
|
||||
geostd8 = 92,
|
||||
cp932 = 95,
|
||||
eucjpms = 97,
|
||||
gb18030 = 248
|
||||
}
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
-- mysql field value type converters
|
||||
local converters = new_tab(0, 9)
|
||||
|
||||
for i = 0x01, 0x05 do
|
||||
-- tiny, short, long, float, double
|
||||
converters[i] = tonumber
|
||||
end
|
||||
converters[0x00] = tonumber -- decimal
|
||||
-- converters[0x08] = tonumber -- long long
|
||||
converters[0x09] = tonumber -- int24
|
||||
converters[0x0d] = tonumber -- year
|
||||
converters[0xf6] = tonumber -- newdecimal
|
||||
|
||||
|
||||
local function _get_byte2(data, i)
|
||||
local a, b = strbyte(data, i, i + 1)
|
||||
return bor(a, lshift(b, 8)), i + 2
|
||||
end
|
||||
|
||||
|
||||
local function _get_byte3(data, i)
|
||||
local a, b, c = strbyte(data, i, i + 2)
|
||||
return bor(a, lshift(b, 8), lshift(c, 16)), i + 3
|
||||
end
|
||||
|
||||
|
||||
local function _get_byte4(data, i)
|
||||
local a, b, c, d = strbyte(data, i, i + 3)
|
||||
return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4
|
||||
end
|
||||
|
||||
|
||||
local function _get_byte8(data, i)
|
||||
local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7)
|
||||
|
||||
-- XXX workaround for the lack of 64-bit support in bitop:
|
||||
local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24))
|
||||
local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24))
|
||||
return lo + hi * 4294967296, i + 8
|
||||
|
||||
-- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32),
|
||||
-- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8
|
||||
end
|
||||
|
||||
|
||||
local function _set_byte2(n)
|
||||
return strchar(band(n, 0xff), band(rshift(n, 8), 0xff))
|
||||
end
|
||||
|
||||
|
||||
local function _set_byte3(n)
|
||||
return strchar(band(n, 0xff),
|
||||
band(rshift(n, 8), 0xff),
|
||||
band(rshift(n, 16), 0xff))
|
||||
end
|
||||
|
||||
|
||||
local function _set_byte4(n)
|
||||
return strchar(band(n, 0xff),
|
||||
band(rshift(n, 8), 0xff),
|
||||
band(rshift(n, 16), 0xff),
|
||||
band(rshift(n, 24), 0xff))
|
||||
end
|
||||
|
||||
|
||||
local function _from_cstring(data, i)
|
||||
local last = strfind(data, "\0", i, true)
|
||||
if not last then
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
return sub(data, i, last - 1), last + 1
|
||||
end
|
||||
|
||||
|
||||
local function _to_cstring(data)
|
||||
return data .. "\0"
|
||||
end
|
||||
|
||||
|
||||
local function _to_binary_coded_string(data)
|
||||
return strchar(#data) .. data
|
||||
end
|
||||
|
||||
|
||||
local function _dump(data)
|
||||
local len = #data
|
||||
local bytes = new_tab(len, 0)
|
||||
for i = 1, len do
|
||||
bytes[i] = format("%x", strbyte(data, i))
|
||||
end
|
||||
return concat(bytes, " ")
|
||||
end
|
||||
|
||||
|
||||
local function _dumphex(data)
|
||||
local len = #data
|
||||
local bytes = new_tab(len, 0)
|
||||
for i = 1, len do
|
||||
bytes[i] = tohex(strbyte(data, i), 2)
|
||||
end
|
||||
return concat(bytes, " ")
|
||||
end
|
||||
|
||||
|
||||
local function _compute_token(password, scramble)
|
||||
if password == "" then
|
||||
return ""
|
||||
end
|
||||
|
||||
local stage1 = sha1(password)
|
||||
local stage2 = sha1(stage1)
|
||||
local stage3 = sha1(scramble .. stage2)
|
||||
local n = #stage1
|
||||
local bytes = new_tab(n, 0)
|
||||
for i = 1, n do
|
||||
bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i)))
|
||||
end
|
||||
|
||||
return concat(bytes)
|
||||
end
|
||||
|
||||
|
||||
local function _send_packet(self, req, size)
|
||||
local sock = self.sock
|
||||
|
||||
self.packet_no = self.packet_no + 1
|
||||
|
||||
-- print("packet no: ", self.packet_no)
|
||||
|
||||
local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req
|
||||
|
||||
-- print("sending packet: ", _dump(packet))
|
||||
|
||||
-- print("sending packet... of size " .. #packet)
|
||||
|
||||
return sock:send(packet)
|
||||
end
|
||||
|
||||
|
||||
local function _recv_packet(self)
|
||||
local sock = self.sock
|
||||
|
||||
local data, err = sock:receive(4) -- packet header
|
||||
if not data then
|
||||
return nil, nil, "failed to receive packet header: " .. err
|
||||
end
|
||||
|
||||
--print("packet header: ", _dump(data))
|
||||
|
||||
local len, pos = _get_byte3(data, 1)
|
||||
|
||||
--print("packet length: ", len)
|
||||
|
||||
if len == 0 then
|
||||
return nil, nil, "empty packet"
|
||||
end
|
||||
|
||||
if len > self._max_packet_size then
|
||||
return nil, nil, "packet size too big: " .. len
|
||||
end
|
||||
|
||||
local num = strbyte(data, pos)
|
||||
|
||||
--print("recv packet: packet no: ", num)
|
||||
|
||||
self.packet_no = num
|
||||
|
||||
data, err = sock:receive(len)
|
||||
|
||||
--print("receive returned")
|
||||
|
||||
if not data then
|
||||
return nil, nil, "failed to read packet content: " .. err
|
||||
end
|
||||
|
||||
--print("packet content: ", _dump(data))
|
||||
--print("packet content (ascii): ", data)
|
||||
|
||||
local field_count = strbyte(data, 1)
|
||||
|
||||
local typ
|
||||
if field_count == 0x00 then
|
||||
typ = "OK"
|
||||
elseif field_count == 0xff then
|
||||
typ = "ERR"
|
||||
elseif field_count == 0xfe then
|
||||
typ = "EOF"
|
||||
else
|
||||
typ = "DATA"
|
||||
end
|
||||
|
||||
return data, typ
|
||||
end
|
||||
|
||||
|
||||
local function _from_length_coded_bin(data, pos)
|
||||
local first = strbyte(data, pos)
|
||||
|
||||
--print("LCB: first: ", first)
|
||||
|
||||
if not first then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
if first >= 0 and first <= 250 then
|
||||
return first, pos + 1
|
||||
end
|
||||
|
||||
if first == 251 then
|
||||
return null, pos + 1
|
||||
end
|
||||
|
||||
if first == 252 then
|
||||
pos = pos + 1
|
||||
return _get_byte2(data, pos)
|
||||
end
|
||||
|
||||
if first == 253 then
|
||||
pos = pos + 1
|
||||
return _get_byte3(data, pos)
|
||||
end
|
||||
|
||||
if first == 254 then
|
||||
pos = pos + 1
|
||||
return _get_byte8(data, pos)
|
||||
end
|
||||
|
||||
return nil, pos + 1
|
||||
end
|
||||
|
||||
|
||||
local function _from_length_coded_str(data, pos)
|
||||
local len
|
||||
len, pos = _from_length_coded_bin(data, pos)
|
||||
if not len or len == null then
|
||||
return null, pos
|
||||
end
|
||||
|
||||
return sub(data, pos, pos + len - 1), pos + len
|
||||
end
|
||||
|
||||
|
||||
local function _parse_ok_packet(packet)
|
||||
local res = new_tab(0, 5)
|
||||
local pos
|
||||
|
||||
res.affected_rows, pos = _from_length_coded_bin(packet, 2)
|
||||
|
||||
--print("affected rows: ", res.affected_rows, ", pos:", pos)
|
||||
|
||||
res.insert_id, pos = _from_length_coded_bin(packet, pos)
|
||||
|
||||
--print("insert id: ", res.insert_id, ", pos:", pos)
|
||||
|
||||
res.server_status, pos = _get_byte2(packet, pos)
|
||||
|
||||
--print("server status: ", res.server_status, ", pos:", pos)
|
||||
|
||||
res.warning_count, pos = _get_byte2(packet, pos)
|
||||
|
||||
--print("warning count: ", res.warning_count, ", pos: ", pos)
|
||||
|
||||
local message = _from_length_coded_str(packet, pos)
|
||||
if message and message ~= null then
|
||||
res.message = message
|
||||
end
|
||||
|
||||
--print("message: ", res.message, ", pos:", pos)
|
||||
|
||||
return res
|
||||
end
|
||||
|
||||
|
||||
local function _parse_eof_packet(packet)
|
||||
local pos = 2
|
||||
|
||||
local warning_count, pos = _get_byte2(packet, pos)
|
||||
local status_flags = _get_byte2(packet, pos)
|
||||
|
||||
return warning_count, status_flags
|
||||
end
|
||||
|
||||
|
||||
local function _parse_err_packet(packet)
|
||||
local errno, pos = _get_byte2(packet, 2)
|
||||
local marker = sub(packet, pos, pos)
|
||||
local sqlstate
|
||||
if marker == '#' then
|
||||
-- with sqlstate
|
||||
pos = pos + 1
|
||||
sqlstate = sub(packet, pos, pos + 5 - 1)
|
||||
pos = pos + 5
|
||||
end
|
||||
|
||||
local message = sub(packet, pos)
|
||||
return errno, message, sqlstate
|
||||
end
|
||||
|
||||
|
||||
local function _parse_result_set_header_packet(packet)
|
||||
local field_count, pos = _from_length_coded_bin(packet, 1)
|
||||
|
||||
local extra
|
||||
extra = _from_length_coded_bin(packet, pos)
|
||||
|
||||
return field_count, extra
|
||||
end
|
||||
|
||||
local NOT_NULL_FLAG = 1
|
||||
local PRI_KEY_FLAG = 2
|
||||
local UNIQUE_KEY_FLAG = 4
|
||||
local UNSIGNED_FLAG = 32
|
||||
local AUTO_INCREMENT_FLAG = 512
|
||||
|
||||
local function _parse_field_packet(data)
|
||||
local col = new_tab(0, 2)
|
||||
local pos
|
||||
col.catalog, pos = _from_length_coded_str(data, 1)
|
||||
col.db, pos = _from_length_coded_str(data, pos)
|
||||
col.table, pos = _from_length_coded_str(data, pos)
|
||||
col.orig_table, pos = _from_length_coded_str(data, pos)
|
||||
col.name, pos = _from_length_coded_str(data, pos)
|
||||
col.orig_name, pos = _from_length_coded_str(data, pos)
|
||||
pos = pos + 1 -- ignore the filler
|
||||
col.charsetnr, pos = _get_byte2(data, pos)
|
||||
col.length, pos = _get_byte4(data, pos)
|
||||
col.type = strbyte(data, pos)
|
||||
pos = pos + 1
|
||||
col.flags, pos = _get_byte2(data, pos)
|
||||
col.allow_null = band(col.flags, NOT_NULL_FLAG) == 0
|
||||
col.pri_key = band(col.flags, PRI_KEY_FLAG) ~= 0
|
||||
col.unique_key = band(col.flags, UNIQUE_KEY_FLAG) ~= 0
|
||||
col.unsigned = band(col.flags, UNSIGNED_FLAG) ~= 0
|
||||
col.auto_increment = band(col.flags, AUTO_INCREMENT_FLAG) ~= 0
|
||||
col.decimals = strbyte(data, pos)
|
||||
pos = pos + 1
|
||||
local default = sub(data, pos + 2)
|
||||
if default and default ~= "" then
|
||||
col.default = default
|
||||
end
|
||||
return col
|
||||
end
|
||||
|
||||
|
||||
local function _parse_row_data_packet(data, cols, compact)
|
||||
local pos = 1
|
||||
local ncols = #cols
|
||||
local row
|
||||
if compact then
|
||||
row = new_tab(ncols, 0)
|
||||
else
|
||||
row = new_tab(0, ncols)
|
||||
end
|
||||
for i = 1, ncols do
|
||||
local value
|
||||
value, pos = _from_length_coded_str(data, pos)
|
||||
local col = cols[i]
|
||||
local typ = col.type
|
||||
local name = col.name
|
||||
|
||||
--print("row field value: ", value, ", type: ", typ)
|
||||
|
||||
if value ~= null then
|
||||
local conv = converters[typ]
|
||||
if conv then
|
||||
value = conv(value)
|
||||
end
|
||||
end
|
||||
|
||||
if compact then
|
||||
row[i] = value
|
||||
elseif value ~= null then
|
||||
row[name] = value
|
||||
end
|
||||
end
|
||||
|
||||
return row
|
||||
end
|
||||
|
||||
|
||||
local function _recv_field_packet(self)
|
||||
local packet, typ, err = _recv_packet(self)
|
||||
if not packet then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if typ == "ERR" then
|
||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||
return nil, msg, errno, sqlstate
|
||||
end
|
||||
|
||||
if typ ~= 'DATA' then
|
||||
return nil, "bad field packet type: " .. typ
|
||||
end
|
||||
|
||||
-- typ == 'DATA'
|
||||
|
||||
return _parse_field_packet(packet)
|
||||
end
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local sock, err = tcp()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
return setmetatable({ sock = sock }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, timeout)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:settimeout(timeout)
|
||||
end
|
||||
|
||||
|
||||
function _M.connect(self, opts)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local max_packet_size = opts.max_packet_size
|
||||
if not max_packet_size then
|
||||
max_packet_size = 1024 * 1024 -- default 1 MB
|
||||
end
|
||||
self._max_packet_size = max_packet_size
|
||||
|
||||
local ok, err
|
||||
|
||||
self.compact = opts.compact_arrays
|
||||
|
||||
local database = opts.database or ""
|
||||
local user = opts.user or ""
|
||||
|
||||
local charset = CHARSET_MAP[opts.charset or "_default"]
|
||||
if not charset then
|
||||
return nil, "charset '" .. opts.charset .. "' is not supported"
|
||||
end
|
||||
|
||||
local pool = opts.pool
|
||||
|
||||
local host = opts.host
|
||||
if host then
|
||||
local port = opts.port or 3306
|
||||
if not pool then
|
||||
pool = user .. ":" .. database .. ":" .. host .. ":" .. port
|
||||
end
|
||||
|
||||
ok, err = sock:connect(host, port, { pool = pool })
|
||||
|
||||
else
|
||||
local path = opts.path
|
||||
if not path then
|
||||
return nil, 'neither "host" nor "path" options are specified'
|
||||
end
|
||||
|
||||
if not pool then
|
||||
pool = user .. ":" .. database .. ":" .. path
|
||||
end
|
||||
|
||||
ok, err = sock:connect("unix:" .. path, { pool = pool })
|
||||
end
|
||||
|
||||
if not ok then
|
||||
return nil, 'failed to connect: ' .. err
|
||||
end
|
||||
|
||||
local reused = sock:getreusedtimes()
|
||||
|
||||
if reused and reused > 0 then
|
||||
self.state = STATE_CONNECTED
|
||||
return 1
|
||||
end
|
||||
|
||||
local packet, typ, err = _recv_packet(self)
|
||||
if not packet then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if typ == "ERR" then
|
||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||
return nil, msg, errno, sqlstate
|
||||
end
|
||||
|
||||
self.protocol_ver = strbyte(packet)
|
||||
|
||||
--print("protocol version: ", self.protocol_ver)
|
||||
|
||||
local server_ver, pos = _from_cstring(packet, 2)
|
||||
if not server_ver then
|
||||
return nil, "bad handshake initialization packet: bad server version"
|
||||
end
|
||||
|
||||
--print("server version: ", server_ver)
|
||||
|
||||
self._server_ver = server_ver
|
||||
|
||||
local thread_id, pos = _get_byte4(packet, pos)
|
||||
|
||||
--print("thread id: ", thread_id)
|
||||
|
||||
local scramble = sub(packet, pos, pos + 8 - 1)
|
||||
if not scramble then
|
||||
return nil, "1st part of scramble not found"
|
||||
end
|
||||
|
||||
pos = pos + 9 -- skip filler
|
||||
|
||||
-- two lower bytes
|
||||
local capabilities -- server capabilities
|
||||
capabilities, pos = _get_byte2(packet, pos)
|
||||
|
||||
-- print(format("server capabilities: %#x", capabilities))
|
||||
|
||||
self._server_lang = strbyte(packet, pos)
|
||||
pos = pos + 1
|
||||
|
||||
--print("server lang: ", self._server_lang)
|
||||
|
||||
self._server_status, pos = _get_byte2(packet, pos)
|
||||
|
||||
--print("server status: ", self._server_status)
|
||||
|
||||
local more_capabilities
|
||||
more_capabilities, pos = _get_byte2(packet, pos)
|
||||
|
||||
capabilities = bor(capabilities, lshift(more_capabilities, 16))
|
||||
|
||||
--print("server capabilities: ", capabilities)
|
||||
|
||||
-- local len = strbyte(packet, pos)
|
||||
local len = 21 - 8 - 1
|
||||
|
||||
--print("scramble len: ", len)
|
||||
|
||||
pos = pos + 1 + 10
|
||||
|
||||
local scramble_part2 = sub(packet, pos, pos + len - 1)
|
||||
if not scramble_part2 then
|
||||
return nil, "2nd part of scramble not found"
|
||||
end
|
||||
|
||||
scramble = scramble .. scramble_part2
|
||||
--print("scramble: ", _dump(scramble))
|
||||
|
||||
local client_flags = 0x3f7cf;
|
||||
|
||||
local ssl_verify = opts.ssl_verify
|
||||
local use_ssl = opts.ssl or ssl_verify
|
||||
|
||||
if use_ssl then
|
||||
if band(capabilities, CLIENT_SSL) == 0 then
|
||||
return nil, "ssl disabled on server"
|
||||
end
|
||||
|
||||
-- send a SSL Request Packet
|
||||
local req = _set_byte4(bor(client_flags, CLIENT_SSL))
|
||||
.. _set_byte4(self._max_packet_size)
|
||||
.. strchar(charset)
|
||||
.. strrep("\0", 23)
|
||||
|
||||
local packet_len = 4 + 4 + 1 + 23
|
||||
local bytes, err = _send_packet(self, req, packet_len)
|
||||
if not bytes then
|
||||
return nil, "failed to send client authentication packet: " .. err
|
||||
end
|
||||
|
||||
local ok, err = sock:sslhandshake(false, nil, ssl_verify)
|
||||
if not ok then
|
||||
return nil, "failed to do ssl handshake: " .. (err or "")
|
||||
end
|
||||
end
|
||||
|
||||
local password = opts.password or ""
|
||||
|
||||
local token = _compute_token(password, scramble)
|
||||
|
||||
--print("token: ", _dump(token))
|
||||
|
||||
local req = _set_byte4(client_flags)
|
||||
.. _set_byte4(self._max_packet_size)
|
||||
.. strchar(charset)
|
||||
.. strrep("\0", 23)
|
||||
.. _to_cstring(user)
|
||||
.. _to_binary_coded_string(token)
|
||||
.. _to_cstring(database)
|
||||
|
||||
local packet_len = 4 + 4 + 1 + 23 + #user + 1
|
||||
+ #token + 1 + #database + 1
|
||||
|
||||
-- print("packet content length: ", packet_len)
|
||||
-- print("packet content: ", _dump(concat(req, "")))
|
||||
|
||||
local bytes, err = _send_packet(self, req, packet_len)
|
||||
if not bytes then
|
||||
return nil, "failed to send client authentication packet: " .. err
|
||||
end
|
||||
|
||||
--print("packet sent ", bytes, " bytes")
|
||||
|
||||
local packet, typ, err = _recv_packet(self)
|
||||
if not packet then
|
||||
return nil, "failed to receive the result packet: " .. err
|
||||
end
|
||||
|
||||
if typ == 'ERR' then
|
||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||
return nil, msg, errno, sqlstate
|
||||
end
|
||||
|
||||
if typ == 'EOF' then
|
||||
return nil, "old pre-4.1 authentication protocol not supported"
|
||||
end
|
||||
|
||||
if typ ~= 'OK' then
|
||||
return nil, "bad packet type: " .. typ
|
||||
end
|
||||
|
||||
self.state = STATE_CONNECTED
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.set_keepalive(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
if self.state ~= STATE_CONNECTED then
|
||||
return nil, "cannot be reused in the current connection state: "
|
||||
.. (self.state or "nil")
|
||||
end
|
||||
|
||||
self.state = nil
|
||||
return sock:setkeepalive(...)
|
||||
end
|
||||
|
||||
|
||||
function _M.get_reused_times(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:getreusedtimes()
|
||||
end
|
||||
|
||||
|
||||
function _M.close(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
self.state = nil
|
||||
|
||||
local bytes, err = _send_packet(self, strchar(COM_QUIT), 1)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return sock:close()
|
||||
end
|
||||
|
||||
|
||||
function _M.server_ver(self)
|
||||
return self._server_ver
|
||||
end
|
||||
|
||||
|
||||
local function send_query(self, query)
|
||||
if self.state ~= STATE_CONNECTED then
|
||||
return nil, "cannot send query in the current context: "
|
||||
.. (self.state or "nil")
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
self.packet_no = -1
|
||||
|
||||
local cmd_packet = strchar(COM_QUERY) .. query
|
||||
local packet_len = 1 + #query
|
||||
|
||||
local bytes, err = _send_packet(self, cmd_packet, packet_len)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
self.state = STATE_COMMAND_SENT
|
||||
|
||||
--print("packet sent ", bytes, " bytes")
|
||||
|
||||
return bytes
|
||||
end
|
||||
_M.send_query = send_query
|
||||
|
||||
|
||||
local function read_result(self, est_nrows)
|
||||
if self.state ~= STATE_COMMAND_SENT then
|
||||
return nil, "cannot read result in the current context: "
|
||||
.. (self.state or "nil")
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local packet, typ, err = _recv_packet(self)
|
||||
if not packet then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if typ == "ERR" then
|
||||
self.state = STATE_CONNECTED
|
||||
|
||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||
return nil, msg, errno, sqlstate
|
||||
end
|
||||
|
||||
if typ == 'OK' then
|
||||
local res = _parse_ok_packet(packet)
|
||||
if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
|
||||
return res, "again"
|
||||
end
|
||||
|
||||
self.state = STATE_CONNECTED
|
||||
return res
|
||||
end
|
||||
|
||||
if typ ~= 'DATA' then
|
||||
self.state = STATE_CONNECTED
|
||||
|
||||
return nil, "packet type " .. typ .. " not supported"
|
||||
end
|
||||
|
||||
-- typ == 'DATA'
|
||||
|
||||
--print("read the result set header packet")
|
||||
|
||||
local field_count, extra = _parse_result_set_header_packet(packet)
|
||||
|
||||
--print("field count: ", field_count)
|
||||
|
||||
local cols = new_tab(field_count, 0)
|
||||
for i = 1, field_count do
|
||||
local col, err, errno, sqlstate = _recv_field_packet(self)
|
||||
if not col then
|
||||
return nil, err, errno, sqlstate
|
||||
end
|
||||
|
||||
cols[i] = col
|
||||
end
|
||||
|
||||
local packet, typ, err = _recv_packet(self)
|
||||
if not packet then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if typ ~= 'EOF' then
|
||||
return nil, "unexpected packet type " .. typ .. " while eof packet is "
|
||||
.. "expected"
|
||||
end
|
||||
|
||||
-- typ == 'EOF'
|
||||
|
||||
local compact = self.compact
|
||||
|
||||
local rows = new_tab(est_nrows or 4, 0)
|
||||
local i = 0
|
||||
while true do
|
||||
--print("reading a row")
|
||||
|
||||
packet, typ, err = _recv_packet(self)
|
||||
if not packet then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if typ == 'EOF' then
|
||||
local warning_count, status_flags = _parse_eof_packet(packet)
|
||||
|
||||
--print("status flags: ", status_flags)
|
||||
|
||||
if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
|
||||
return rows, "again", cols
|
||||
end
|
||||
|
||||
break
|
||||
end
|
||||
|
||||
-- if typ ~= 'DATA' then
|
||||
-- return nil, 'bad row packet type: ' .. typ
|
||||
-- end
|
||||
|
||||
-- typ == 'DATA'
|
||||
|
||||
local row = _parse_row_data_packet(packet, cols, compact)
|
||||
i = i + 1
|
||||
rows[i] = row
|
||||
end
|
||||
|
||||
self.state = STATE_CONNECTED
|
||||
|
||||
return rows, nil, cols
|
||||
end
|
||||
_M.read_result = read_result
|
||||
|
||||
|
||||
function _M.query(self, query, est_nrows)
|
||||
local bytes, err = send_query(self, query)
|
||||
if not bytes then
|
||||
return nil, "failed to send query: " .. err
|
||||
end
|
||||
|
||||
return read_result(self, est_nrows)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_compact_arrays(self, value)
|
||||
self.compact = value
|
||||
end
|
||||
|
||||
local qmap = {
|
||||
['\0' ] = '\\0',
|
||||
['\b' ] = '\\b',
|
||||
['\n' ] = '\\n',
|
||||
['\r' ] = '\\r',
|
||||
['\t' ] = '\\t',
|
||||
['\26'] = '\\Z',
|
||||
['\\' ] = '\\\\',
|
||||
['\'' ] = '\\\'',
|
||||
['\"' ] = '\\"',
|
||||
}
|
||||
function _M.quote(s)
|
||||
return s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap)
|
||||
end
|
||||
|
||||
return _M
|
Loading…
Reference in New Issue
Block a user