mysql/mysql_client.lua

915 lines
21 KiB
Lua
Raw Normal View History

2021-04-29 21:47:39 +02:00
2021-04-29 21:50:40 +02:00
-- MySQL client protocol in Lua.
2021-04-29 21:51:58 +02:00
-- Written by Yichun Zhang (agentzh). BSD license.
2021-04-29 21:24:55 +02:00
local ffi = require'ffi'
local bit = require'bit'
local glue = require'glue'
local tcp = require'sock'.tcp
local sha1 = require'sha1'.sha1
local sub = string.sub
local strbyte = string.byte
local strchar = string.char
local strfind = string.find
local format = string.format
local strrep = string.rep
local null = function() end -- sentinel value used for SQL NULL
local band = bit.band
local bxor = bit.bxor
local bor = bit.bor
local lshift = bit.lshift
local rshift = bit.rshift
local tohex = bit.tohex
local concat = table.concat
local unpack = unpack
local setmetatable = setmetatable
local error = error
local tonumber = tonumber
local ok, new_tab = pcall(require, 'table.new')
if not ok then
new_tab = function (narr, nrec) return {} end
end
local _M = {_VERSION = '0.21'}

-- Connection state machine values stored in self.state.
local STATE_CONNECTED    = 1
local STATE_COMMAND_SENT = 2

-- Command bytes (first payload byte of a command packet).
local COM_QUIT  = 0x01
local COM_QUERY = 0x03

-- Capability / server-status bit flags.
local CLIENT_SSL                 = 0x0800
local SERVER_MORE_RESULTS_EXISTS = 0x0008

-- 0xffffff (16MB - 1): the default max allowed packet size used by
-- libmysqlclient.
local FULL_PACKET_SIZE = 0xffffff

-- Charset name -> default collation id, as produced by:
--   SELECT CHARACTER_SET_NAME, ID FROM information_schema.collations
--   WHERE IS_DEFAULT = 'Yes' ORDER BY id;
local CHARSET_MAP = {
    _default = 0,
    big5     = 1,  dec8     = 3,  cp850    = 4,  hp8      = 6,
    koi8r    = 7,  latin1   = 8,  latin2   = 9,  swe7     = 10,
    ascii    = 11, ujis     = 12, sjis     = 13, hebrew   = 16,
    tis620   = 18, euckr    = 19, koi8u    = 22, gb2312   = 24,
    greek    = 25, cp1250   = 26, gbk      = 28, latin5   = 30,
    armscii8 = 32, utf8     = 33, ucs2     = 35, cp866    = 36,
    keybcs2  = 37, macce    = 38, macroman = 39, cp852    = 40,
    latin7   = 41, utf8mb4  = 45, cp1251   = 51, utf16    = 54,
    utf16le  = 56, cp1256   = 57, cp1257   = 59, utf32    = 60,
    binary   = 63, geostd8  = 92, cp932    = 95, eucjpms  = 97,
    gb18030  = 248,
}
-- Metatable shared by all client objects: method lookups fall through to _M.
local mt = {__index = _M}

-- Field type id -> converter applied to non-NULL text values in result rows.
-- Type 0x08 (long long) is intentionally not converted (presumably because
-- 64-bit integers may not fit a Lua number exactly) and stays a string.
local converters = new_tab(0, 9)
for _, field_type in ipairs{
    0x00,                         -- decimal
    0x01, 0x02, 0x03, 0x04, 0x05, -- tiny, short, long, float, double
    0x09,                         -- int24
    0x0d,                         -- year
    0xf6,                         -- newdecimal
} do
    converters[field_type] = tonumber
end
-- Decodes a little-endian 16-bit unsigned integer stored at data[i..i+1].
-- Returns the value and the index just past it.
local function _get_byte2(data, i)
    local lo, hi = strbyte(data, i, i + 1)
    local value = bor(lshift(hi, 8), lo)
    return value, i + 2
end
-- Decodes a little-endian 24-bit unsigned integer stored at data[i..i+2]
-- (e.g. a packet header length). Returns the value and the index past it.
local function _get_byte3(data, i)
    local b0, b1, b2 = strbyte(data, i, i + 2)
    local value = bor(lshift(b2, 16), lshift(b1, 8), b0)
    return value, i + 3
end
-- Decodes a little-endian 32-bit unsigned integer stored at data[i..i+3].
-- Returns the value and the index just past it.
-- Plain arithmetic is used instead of bit.bor/bit.lshift: bitop results are
-- signed 32-bit, which would turn values >= 2^31 (e.g. 0xffffffff) into
-- negative numbers.
local function _get_byte4(data, i)
    local a, b, c, d = strbyte(data, i, i + 3)
    return a + b * 0x100 + c * 0x10000 + d * 0x1000000, i + 4
end
-- Decodes a little-endian 64-bit unsigned integer stored at data[i..i+7].
-- Returns the value (a Lua number, so values above 2^53 lose precision) and
-- the index just past it.
-- Both 32-bit halves are built with plain arithmetic: bitop results are
-- signed 32-bit, so a bor()-built low half >= 2^31 would come out negative
-- and corrupt the sum.
local function _get_byte8(data, i)
    local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7)
    local lo = a + b * 0x100 + c * 0x10000 + d * 0x1000000
    local hi = e + f * 0x100 + g * 0x10000 + h * 0x1000000
    return lo + hi * 2^32, i + 8
end
-- Encodes n as a 2-byte little-endian string (low byte first).
local function _set_byte2(n)
    local lo = band(n, 0xff)
    local hi = band(rshift(n, 8), 0xff)
    return strchar(lo, hi)
end
-- Encodes n as a 3-byte little-endian string (used for packet lengths).
local function _set_byte3(n)
    local b0 = band(n, 0xff)
    local b1 = band(rshift(n, 8), 0xff)
    local b2 = band(rshift(n, 16), 0xff)
    return strchar(b0, b1, b2)
end
-- Encodes n as a 4-byte little-endian string (low byte first).
local function _set_byte4(n)
    local b0, b1 = band(n, 0xff), band(rshift(n, 8), 0xff)
    local b2, b3 = band(rshift(n, 16), 0xff), band(rshift(n, 24), 0xff)
    return strchar(b0, b1, b2, b3)
end
-- Reads a NUL-terminated string starting at index i of data.
-- Returns the string (without the NUL) and the index just past the NUL,
-- or nil, nil when no terminator is present.
local function _from_cstring(data, i)
    local nul_at = strfind(data, '\0', i, true)
    if nul_at then
        return sub(data, i, nul_at - 1), nul_at + 1
    end
    return nil, nil
end
-- Returns data with a trailing NUL terminator appended.
local function _to_cstring(data)
    return concat({data, '\0'})
end
-- Prefixes data with its length as a single byte, so data must be at most
-- 255 bytes long (strchar raises an error otherwise).
local function _to_binary_coded_string(data)
    local len_byte = strchar(#data)
    return len_byte .. data
end
-- Debug helper: renders every byte of data as unpadded lowercase hex,
-- separated by single spaces (e.g. '\1\255' -> '1 ff').
local function _dump(data)
    local parts = new_tab(#data, 0)
    local n = 0
    for ch in data:gmatch'.' do
        n = n + 1
        parts[n] = format('%x', strbyte(ch))
    end
    return concat(parts, ' ')
end
-- Debug helper: renders every byte of data as two-digit lowercase hex,
-- separated by single spaces (e.g. '\1\255' -> '01 ff').
local function _dumphex(data)
    local count = #data
    local out = new_tab(count, 0)
    local idx = 0
    while idx < count do
        idx = idx + 1
        out[idx] = format('%02x', strbyte(data, idx))
    end
    return concat(out, ' ')
end
-- Computes the authentication response for password against the server's
-- scramble: SHA1(password) XOR SHA1(scramble .. SHA1(SHA1(password))).
-- An empty password yields an empty token.
-- NOTE(review): assumes sha1() returns the raw binary digest; the XOR runs
-- over #SHA1(password) bytes.
local function _compute_token(password, scramble)
    if password == '' then
        return ''
    end
    local hashed = sha1(password)
    local mask = sha1(scramble .. sha1(hashed))
    local len = #hashed
    local out = new_tab(len, 0)
    for k = 1, len do
        local x = bxor(strbyte(mask, k), strbyte(hashed, k))
        out[k] = strchar(x)
    end
    return concat(out)
end
-- Frames req as one protocol packet and writes it to self.sock.
-- Header: `size` as a 3-byte little-endian length, then the low 8 bits of
-- the (pre-incremented) self.packet_no as the sequence number.
-- Returns whatever sock:send returns.
local function _send_packet(self, req, size)
    local sock = self.sock
    local seq = self.packet_no + 1
    self.packet_no = seq
    local header = _set_byte3(size) .. strchar(band(seq, 255))
    return sock:send(header .. req)
end
2021-04-30 00:27:35 +02:00
-- Static, auto-growing buffer allocation pattern (ctype must be a VLA, e.g.
-- 'char[?]'). Returns an allocator closure:
--   alloc(minlen) -> buf, len   reallocates (to glue.nextpow2(minlen)
--                               elements) only when minlen exceeds the
--                               current length, else reuses the buffer;
--   alloc(false)  -> nil, -1    drops the buffer.
local function grow_buffer(ctype)
    local vla = ffi.typeof(ctype)
    local cur_buf, cur_len = nil, -1
    return function(minlen)
        if minlen == false then
            cur_buf, cur_len = nil, -1
            return cur_buf, cur_len
        end
        if minlen > cur_len then
            cur_len = glue.nextpow2(minlen)
            cur_buf = vla(cur_len)
        end
        return cur_buf, cur_len
    end
end
-- Reads exactly sz bytes from self.sock into a per-connection reusable
-- buffer (created lazily with grow_buffer and kept in self.buf) and returns
-- them as a Lua string.
-- Returns nil, err when sock:recv fails, or nil, 'closed' when sock:recv
-- reports 0 bytes while data is still outstanding.
local function _recv(self, sz)
    local alloc = self.buf
    if not alloc then
        alloc = grow_buffer'char[?]'
        self.buf = alloc
    end
    local buf = alloc(sz)
    local sock = self.sock
    local offset = 0
    while sz > 0 do
        local n, err = sock:recv(buf + offset, sz)
        if not n then
            return nil, err
        end
        -- A zero-byte read makes no progress (presumably the peer closed
        -- the connection); without this check the loop would spin forever.
        if n == 0 then
            return nil, 'closed'
        end
        sz = sz - n
        offset = offset + n
    end
    return ffi.string(buf, offset)
end
2021-04-29 21:24:55 +02:00
-- Receives one protocol packet: a 4-byte header (3-byte little-endian
-- payload length + 1-byte sequence number) followed by the payload.
-- Stores the received sequence number in self.packet_no.
-- Returns payload, typ where typ is 'OK', 'ERR', 'EOF' or 'DATA'
-- (classified from the first payload byte), or nil, nil, errmsg on failure.
local function _recv_packet(self)
    local header, err = _recv(self, 4)
    if not header then
        return nil, nil, 'failed to receive packet header: ' .. tostring(err)
    end

    local len, pos = _get_byte3(header, 1)
    if len == 0 then
        return nil, nil, 'empty packet'
    end
    if len > self._max_packet_size then
        return nil, nil, 'packet size too big: ' .. len
    end
    self.packet_no = strbyte(header, pos)

    local data
    data, err = _recv(self, len)
    if not data then
        return nil, nil, 'failed to read packet content: ' .. tostring(err)
    end

    local first = strbyte(data, 1)
    local typ
    if first == 0x00 then
        typ = 'OK'
    elseif first == 0xff then
        typ = 'ERR'
    elseif first == 0xfe and len < 9 then
        -- 0xfe marks an EOF packet only when the payload is shorter than
        -- 9 bytes; a longer payload starting with 0xfe is a row whose first
        -- column carries an 8-byte length prefix.
        typ = 'EOF'
    else
        typ = 'DATA'
    end
    return data, typ
end
-- Decodes a length-encoded integer at data[pos]; returns value, next_pos.
-- First byte 0..250 is the value itself; 251 yields the `null` sentinel;
-- 252/253/254 prefix a 2/3/8-byte little-endian integer.
-- Returns nil, pos when pos is past the end of data, nil, pos + 1 for 255.
local function _from_length_coded_bin(data, pos)
    local first = strbyte(data, pos)
    if first == nil then
        return nil, pos
    end
    if first <= 250 then
        return first, pos + 1
    elseif first == 251 then
        return null, pos + 1
    elseif first == 252 then
        return _get_byte2(data, pos + 1)
    elseif first == 253 then
        return _get_byte3(data, pos + 1)
    elseif first == 254 then
        return _get_byte8(data, pos + 1)
    end
    return nil, pos + 1
end
-- Decodes a length-encoded string at data[pos]; returns str, next_pos.
-- A NULL or missing/invalid length yields the `null` sentinel.
local function _from_length_coded_str(data, pos)
    local len, start = _from_length_coded_bin(data, pos)
    if len == nil or len == null then
        return null, start
    end
    local stop = start + len
    return sub(data, start, stop - 1), stop
end
-- Parses an OK packet (payload starting with 0x00) into a table with
-- affected_rows, insert_id, server_status, warning_count and, when present
-- and not NULL, message.
local function _parse_ok_packet(packet)
    local pos = 2 -- skip the 0x00 header byte
    local affected_rows, insert_id, server_status, warning_count
    affected_rows, pos = _from_length_coded_bin(packet, pos)
    insert_id, pos = _from_length_coded_bin(packet, pos)
    server_status, pos = _get_byte2(packet, pos)
    warning_count, pos = _get_byte2(packet, pos)

    local res = new_tab(0, 5)
    res.affected_rows = affected_rows
    res.insert_id = insert_id
    res.server_status = server_status
    res.warning_count = warning_count

    local message = _from_length_coded_str(packet, pos)
    if message and message ~= null then
        res.message = message
    end
    return res
end
-- Parses an EOF packet (0xfe header byte); returns warning_count and
-- status_flags.
local function _parse_eof_packet(packet)
    local warnings, next_pos = _get_byte2(packet, 2)
    local status = (_get_byte2(packet, next_pos))
    return warnings, status
end
-- Parses an ERR packet (0xff header byte); returns errno, message, sqlstate.
-- sqlstate is the 5 characters after a '#' marker, or nil if there is none.
local function _parse_err_packet(packet)
    local errno, pos = _get_byte2(packet, 2)
    local sqlstate
    if sub(packet, pos, pos) == '#' then
        sqlstate = sub(packet, pos + 1, pos + 5)
        pos = pos + 6
    end
    return errno, sub(packet, pos), sqlstate
end
-- Parses a result-set header packet; returns the column count and the
-- optional length-encoded value that may follow it.
local function _parse_result_set_header_packet(packet)
    local field_count, next_pos = _from_length_coded_bin(packet, 1)
    return field_count, (_from_length_coded_bin(packet, next_pos))
end
-- Column-definition flag bits.
local NOT_NULL_FLAG       = 0x001
local PRI_KEY_FLAG        = 0x002
local UNIQUE_KEY_FLAG     = 0x004
local UNSIGNED_FLAG       = 0x020
local AUTO_INCREMENT_FLAG = 0x200

-- Parses a column-definition packet into a table: the six leading
-- length-coded strings (catalog .. orig_name), charsetnr, length, type,
-- flags plus boolean views of them, decimals, and `default` when there is
-- trailing data.
local function _parse_field_packet(data)
    local col = new_tab(0, 2)
    local pos = 1
    for _, key in ipairs{'catalog', 'db', 'table', 'orig_table',
                         'name', 'orig_name'} do
        col[key], pos = _from_length_coded_str(data, pos)
    end

    pos = pos + 1 -- skip one filler byte
    col.charsetnr, pos = _get_byte2(data, pos)
    col.length, pos = _get_byte4(data, pos)
    col.type = strbyte(data, pos)
    col.flags, pos = _get_byte2(data, pos + 1)

    local flags = col.flags
    col.allow_null     = band(flags, NOT_NULL_FLAG) == 0
    col.pri_key        = band(flags, PRI_KEY_FLAG) ~= 0
    col.unique_key     = band(flags, UNIQUE_KEY_FLAG) ~= 0
    col.unsigned       = band(flags, UNSIGNED_FLAG) ~= 0
    col.auto_increment = band(flags, AUTO_INCREMENT_FLAG) ~= 0

    col.decimals = strbyte(data, pos)
    -- Skip the decimals byte and two more bytes (presumably filler);
    -- anything left over is kept as the column default.
    local default = sub(data, pos + 3)
    if default ~= '' then
        col.default = default
    end
    return col
end
-- Parses one text-protocol row packet. Each column value is a length-coded
-- string or NULL; non-NULL values whose column type has a converter are
-- converted. With `compact`, returns an array in column order (NULL kept as
-- the `null` sentinel); otherwise a map keyed by column name, omitting NULLs.
local function _parse_row_data_packet(data, cols, compact)
    local ncols = #cols
    local row = compact and new_tab(ncols, 0) or new_tab(0, ncols)
    local pos = 1
    for i = 1, ncols do
        local col = cols[i]
        local value
        value, pos = _from_length_coded_str(data, pos)
        if value == null then
            if compact then
                row[i] = value
            end
        else
            local conv = converters[col.type]
            if conv then
                value = conv(value)
            end
            if compact then
                row[i] = value
            else
                row[col.name] = value
            end
        end
    end
    return row
end
-- Receives and parses one column-definition packet.
-- Returns the column table, nil, msg, errno, sqlstate for an ERR packet,
-- or nil, errmsg on receive errors / unexpected packet types.
local function _recv_field_packet(self)
    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end
    if typ == 'DATA' then
        return _parse_field_packet(packet)
    end
    if typ == 'ERR' then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end
    return nil, 'bad field packet type: ' .. typ
end
-- Creates a new, unconnected client object around a fresh TCP socket.
-- Returns the object, or nil, err if the socket could not be created.
-- The `self` argument is not used.
function _M.new(self)
    local sock, err = tcp()
    if sock then
        return setmetatable({sock = sock}, mt)
    end
    return nil, err
end
-- Connects to the server and performs the handshake and authentication.
-- opts: host, port (default 3306), user, password, database,
-- charset (key of CHARSET_MAP, default '_default'),
-- max_packet_size (default 1 MB), compact_arrays, ssl, ssl_verify.
-- Returns 1 on success, or nil, errmsg[, errno, sqlstate] on failure.
function _M.connect(self, opts)
    local sock = self.sock
    if not sock then
        return nil, 'not initialized'
    end

    self._max_packet_size = opts.max_packet_size or 1024 * 1024 -- 1 MB
    self.compact = opts.compact_arrays

    local database = opts.database or ''
    local user = opts.user or ''

    local charset = CHARSET_MAP[opts.charset or '_default']
    if not charset then
        return nil, 'charset \'' .. tostring(opts.charset)
            .. '\' is not supported'
    end

    local ok, err = sock:connect(opts.host, opts.port or 3306)
    if not ok then
        return nil, 'failed to connect: ' .. tostring(err)
    end

    -- Initial handshake packet from the server.
    local packet, typ
    packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end
    if typ == 'ERR' then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    self.protocol_ver = strbyte(packet)

    local server_ver, pos = _from_cstring(packet, 2)
    if not server_ver then
        return nil, 'bad handshake initialization packet: bad server version'
    end
    self._server_ver = server_ver

    pos = pos + 4 -- skip the 4-byte thread id

    -- First 8 bytes of the scramble, followed by a 1-byte filler.
    -- (sub() never returns nil, so the length is what must be checked.)
    local scramble = sub(packet, pos, pos + 7)
    if #scramble < 8 then
        return nil, '1st part of scramble not found'
    end
    pos = pos + 9

    local capabilities -- lower 16 bits of the server capabilities
    capabilities, pos = _get_byte2(packet, pos)
    self._server_lang = strbyte(packet, pos)
    pos = pos + 1
    self._server_status, pos = _get_byte2(packet, pos)
    local more_capabilities -- upper 16 bits
    more_capabilities, pos = _get_byte2(packet, pos)
    capabilities = bor(capabilities, lshift(more_capabilities, 16))

    -- Skip the 1-byte scramble length and 10 reserved bytes, then read the
    -- remaining 12 scramble bytes (20 bytes in total).
    pos = pos + 1 + 10
    local scramble_part2 = sub(packet, pos, pos + 11)
    if #scramble_part2 < 12 then
        return nil, '2nd part of scramble not found'
    end
    scramble = scramble .. scramble_part2

    local client_flags = 0x3f7cf
    local ssl_verify = opts.ssl_verify
    if opts.ssl or ssl_verify then
        if band(capabilities, CLIENT_SSL) == 0 then
            return nil, 'ssl disabled on server'
        end
        -- CLIENT_SSL must be advertised both in the SSL request and in the
        -- authentication packet sent over the encrypted channel below.
        client_flags = bor(client_flags, CLIENT_SSL)
        -- SSL request: the auth packet header without the credentials.
        local req = _set_byte4(client_flags)
            .. _set_byte4(self._max_packet_size)
            .. strchar(charset)
            .. strrep('\0', 23)
        local bytes
        bytes, err = _send_packet(self, req, #req)
        if not bytes then
            return nil, 'failed to send client authentication packet: '
                .. tostring(err)
        end
        ok, err = sock:sslhandshake(false, nil, ssl_verify)
        if not ok then
            return nil, 'failed to do ssl handshake: ' .. (err or '')
        end
    end

    local token = _compute_token(opts.password or '', scramble)
    local req = _set_byte4(client_flags)
        .. _set_byte4(self._max_packet_size)
        .. strchar(charset)
        .. strrep('\0', 23)
        .. _to_cstring(user)
        .. _to_binary_coded_string(token)
        .. _to_cstring(database)
    local bytes
    bytes, err = _send_packet(self, req, #req)
    if not bytes then
        return nil, 'failed to send client authentication packet: '
            .. tostring(err)
    end

    packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, 'failed to receive the result packet: ' .. tostring(err)
    end
    if typ == 'ERR' then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end
    if typ == 'EOF' then
        return nil, 'old pre-4.1 authentication protocol not supported'
    end
    if typ ~= 'OK' then
        return nil, 'bad packet type: ' .. typ
    end

    self.state = STATE_CONNECTED
    return 1
end
-- Sends COM_QUIT and closes the socket.
-- The socket is closed even when sending COM_QUIT fails; in that case
-- nil, err from the send is returned. Otherwise returns sock:close()'s
-- results.
function _M.close(self)
    local sock = self.sock
    if not sock then
        return nil, 'not initialized'
    end
    self.state = nil
    -- COM_QUIT starts a new command, so the sequence restarts at packet 0
    -- (as in send_query). This also avoids arithmetic on a nil packet_no
    -- when close() is called before connect().
    self.packet_no = -1
    local bytes, err = _send_packet(self, strchar(COM_QUIT), 1)
    if not bytes then
        sock:close()
        return nil, err
    end
    return sock:close()
end
-- Returns the server version string parsed during connect() (nil before).
function _M.server_ver(self)
    local version = self._server_ver
    return version
end
-- Sends query as a COM_QUERY command without reading the result (use
-- read_result for that). Only allowed in the STATE_CONNECTED state.
-- Returns the send result (bytes) or nil, err.
local function send_query(self, query)
    local state = self.state
    if state ~= STATE_CONNECTED then
        return nil, 'cannot send query in the current context: '
            .. (state or 'nil')
    end
    local sock = self.sock
    if not sock then
        return nil, 'not initialized'
    end
    -- A new command restarts the packet sequence at 0.
    self.packet_no = -1
    local bytes, err = _send_packet(self, strchar(COM_QUERY) .. query,
                                    1 + #query)
    if bytes then
        self.state = STATE_COMMAND_SENT
        return bytes
    end
    return nil, err
end
_M.send_query = send_query
-- Reads the result of a query previously sent with send_query.
-- Returns one of:
--   res[, 'again']           for an OK packet (see _parse_ok_packet);
--   rows, nil|'again', cols  for a result set;
--   nil, msg, errno, sqlstate when the server reports an error;
--   nil, errmsg              on socket/protocol errors.
-- 'again' means the server flagged SERVER_MORE_RESULTS_EXISTS; the state is
-- left as STATE_COMMAND_SENT so read_result can be called again.
-- est_nrows optionally pre-sizes the rows array (default 4).
local function read_result(self, est_nrows)
    if self.state ~= STATE_COMMAND_SENT then
        return nil, 'cannot read result in the current context: '
            .. (self.state or 'nil')
    end
    local sock = self.sock
    if not sock then
        return nil, 'not initialized'
    end

    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end

    if typ == 'ERR' then
        self.state = STATE_CONNECTED
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    if typ == 'OK' then
        local res = _parse_ok_packet(packet)
        if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
            return res, 'again'
        end
        self.state = STATE_CONNECTED
        return res
    end

    if typ ~= 'DATA' then
        self.state = STATE_CONNECTED
        return nil, 'packet type ' .. typ .. ' not supported'
    end

    -- Result set: header (column count), column definitions, EOF,
    -- row packets, EOF.
    local field_count = _parse_result_set_header_packet(packet)
    local cols = new_tab(field_count, 0)
    for i = 1, field_count do
        local col, ferr, errno, sqlstate = _recv_field_packet(self)
        if not col then
            return nil, ferr, errno, sqlstate
        end
        cols[i] = col
    end

    packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end
    if typ ~= 'EOF' then
        return nil, 'unexpected packet type ' .. typ .. ' while eof packet is '
            .. 'expected'
    end

    local compact = self.compact
    local rows = new_tab(est_nrows or 4, 0)
    local nrows = 0
    while true do
        packet, typ, err = _recv_packet(self)
        if not packet then
            return nil, err
        end
        if typ == 'EOF' then
            local _, status_flags = _parse_eof_packet(packet)
            if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
                return rows, 'again', cols
            end
            break
        end
        if typ == 'ERR' then
            -- The server reported an error in the middle of the result set.
            -- A row cannot start with 0xff (it is not a valid length-coded
            -- prefix), so this is never row data.
            self.state = STATE_CONNECTED
            local errno, msg, sqlstate = _parse_err_packet(packet)
            return nil, msg, errno, sqlstate
        end
        -- Everything else is a row. 'OK' must be accepted here: a row whose
        -- first column is an empty string starts with a 0x00 byte.
        nrows = nrows + 1
        rows[nrows] = _parse_row_data_packet(packet, cols, compact)
    end
    self.state = STATE_CONNECTED
    return rows, nil, cols
end
_M.read_result = read_result
-- Convenience wrapper: send_query followed by read_result. Returns whatever
-- read_result returns, or nil, errmsg when sending fails.
function _M.query(self, query, est_nrows)
    local sent, err = send_query(self, query)
    if sent then
        return read_result(self, est_nrows)
    end
    return nil, 'failed to send query: ' .. err
end
-- Chooses the row format for later results: arrays in column order when
-- value is truthy, name-keyed tables otherwise.
function _M.set_compact_arrays(self, value)
    self['compact'] = value
end
-- Backslash escape sequences for characters that are special inside a
-- quoted MySQL string literal.
local qmap = {
    ['\0' ] = '\\0',
    ['\b' ] = '\\b',
    ['\n' ] = '\\n',
    ['\r' ] = '\\r',
    ['\t' ] = '\\t',
    ['\26'] = '\\Z',
    ['\\' ] = '\\\\',
    ['\'' ] = '\\\'',
    ['\"' ] = '\\"',
}

-- Escapes s for embedding inside a quoted MySQL string literal (the caller
-- adds the surrounding quotes). Returns only the escaped string.
function _M.quote(s)
    -- The parentheses drop gsub's second result (the substitution count).
    -- Without them, e.g. table.insert(t, mysql.quote(s)) would raise an error.
    return (s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap))
end
-- Module export: the table of public client functions.
return _M