diff --git a/mysql_client.lua b/mysql_client.lua index 19e540a..f841aad 100644 --- a/mysql_client.lua +++ b/mysql_client.lua @@ -1,8 +1,8 @@ -- MySQL client protocol in Lua. -- Written by Yichun Zhang (agentzh). BSD license. +-- Modified by Cosmin Apreutesei. Pulbic domain. -local tcp = require'sock'.tcp local sha1 = require'sha1'.sha1 local bit = require'bit' @@ -26,12 +26,9 @@ local error = error local tonumber = tonumber local ok, new_tab = pcall(require, 'table.new') -if not ok then - new_tab = function (narr, nrec) return {} end -end - -local _M = { _VERSION = '0.21' } +new_tab = ok and new_tab or function() return {} end +local mysql = {} -- constants @@ -52,59 +49,59 @@ local FULL_PACKET_SIZE = 16777215 -- FROM information_schema.collations -- WHERE IS_DEFAULT = 'Yes' ORDER BY id; local CHARSET_MAP = { - _default = 0, - big5 = 1, - dec8 = 3, - cp850 = 4, - hp8 = 6, - koi8r = 7, - latin1 = 8, - latin2 = 9, - swe7 = 10, - ascii = 11, - ujis = 12, - sjis = 13, - hebrew = 16, - tis620 = 18, - euckr = 19, - koi8u = 22, - gb2312 = 24, - greek = 25, - cp1250 = 26, - gbk = 28, - latin5 = 30, - armscii8 = 32, - utf8 = 33, - ucs2 = 35, - cp866 = 36, - keybcs2 = 37, - macce = 38, - macroman = 39, - cp852 = 40, - latin7 = 41, - utf8mb4 = 45, - cp1251 = 51, - utf16 = 54, - utf16le = 56, - cp1256 = 57, - cp1257 = 59, - utf32 = 60, - binary = 63, - geostd8 = 92, - cp932 = 95, - eucjpms = 97, - gb18030 = 248 + _default = 0, + big5 = 1, + dec8 = 3, + cp850 = 4, + hp8 = 6, + koi8r = 7, + latin1 = 8, + latin2 = 9, + swe7 = 10, + ascii = 11, + ujis = 12, + sjis = 13, + hebrew = 16, + tis620 = 18, + euckr = 19, + koi8u = 22, + gb2312 = 24, + greek = 25, + cp1250 = 26, + gbk = 28, + latin5 = 30, + armscii8 = 32, + utf8 = 33, + ucs2 = 35, + cp866 = 36, + keybcs2 = 37, + macce = 38, + macroman = 39, + cp852 = 40, + latin7 = 41, + utf8mb4 = 45, + cp1251 = 51, + utf16 = 54, + utf16le = 56, + cp1256 = 57, + cp1257 = 59, + utf32 = 60, + binary = 63, + geostd8 = 92, + cp932 = 95, + eucjpms = 97, + gb18030 = 248 } -local mt = { __index = _M } - +local conn = {} +local mt = {__index = conn} -- mysql field value type converters local converters = new_tab(0, 9) for i = 0x01, 0x05 do - -- tiny, short, long, float, double - converters[i] = tonumber + -- tiny, short, long, float, double + converters[i] = tonumber end converters[0x00] = tonumber -- decimal -- converters[0x08] = tonumber -- long long @@ -114,131 +111,135 @@ converters[0xf6] = tonumber -- newdecimal local function _get_byte2(data, i) - local a, b = strbyte(data, i, i + 1) - return bor(a, lshift(b, 8)), i + 2 + local a, b = strbyte(data, i, i + 1) + return bor(a, lshift(b, 8)), i + 2 end local function _get_byte3(data, i) - local a, b, c = strbyte(data, i, i + 2) - return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 + local a, b, c = strbyte(data, i, i + 2) + return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 end local function _get_byte4(data, i) - local a, b, c, d = strbyte(data, i, i + 3) - return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 + local a, b, c, d = strbyte(data, i, i + 3) + return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 end local function _get_byte8(data, i) - local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) + local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) - -- XXX workaround for the lack of 64-bit support in bitop: - local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)) - local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) - return lo + hi * 4294967296, i + 8 + -- XXX workaround for the lack of 64-bit support in bitop: + local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)) + local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) + return lo + hi * 4294967296, i + 8 - -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), - -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 + -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), + -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 end local function _set_byte2(n) - return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) + return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) end local function _set_byte3(n) - return strchar(band(n, 0xff), - band(rshift(n, 8), 0xff), - band(rshift(n, 16), 0xff)) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff)) end local function _set_byte4(n) - return strchar(band(n, 0xff), - band(rshift(n, 8), 0xff), - band(rshift(n, 16), 0xff), - band(rshift(n, 24), 0xff)) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff), + band(rshift(n, 24), 0xff)) end local function _from_cstring(data, i) - local last = strfind(data, '\0', i, true) - if not last then - return nil, nil - end + local last = strfind(data, '\0', i, true) + if not last then + return nil, nil + end - return sub(data, i, last - 1), last + 1 + return sub(data, i, last - 1), last + 1 end local function _to_cstring(data) - return data .. '\0' + return data .. '\0' end local function _to_binary_coded_string(data) - return strchar(#data) .. data + return strchar(#data) .. data end local function _dump(data) - local len = #data - local bytes = new_tab(len, 0) - for i = 1, len do - bytes[i] = format('%x', strbyte(data, i)) - end - return concat(bytes, ' ') + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = format('%x', strbyte(data, i)) + end + return concat(bytes, ' ') end local function _dumphex(data) - local len = #data - local bytes = new_tab(len, 0) - for i = 1, len do - bytes[i] = tohex(strbyte(data, i), 2) - end - return concat(bytes, ' ') + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = tohex(strbyte(data, i), 2) + end + return concat(bytes, ' ') end local function _compute_token(password, scramble) - if password == '' then - return '' - end + if password == '' then + return '' + end - local stage1 = sha1(password) - local stage2 = sha1(stage1) - local stage3 = sha1(scramble .. stage2) - local n = #stage1 - local bytes = new_tab(n, 0) - for i = 1, n do - bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) - end + local stage1 = sha1(password) + local stage2 = sha1(stage1) + local stage3 = sha1(scramble .. stage2) + local n = #stage1 + local bytes = new_tab(n, 0) + for i = 1, n do + bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) + end - return concat(bytes) + return concat(bytes) end local function _send_packet(self, req, size) - local sock = self.sock + local sock = self.sock - self.packet_no = self.packet_no + 1 + self.packet_no = self.packet_no + 1 - -- print('packet no: ', self.packet_no) + -- print('packet no: ', self.packet_no) - local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req + local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req - -- print('sending packet: ', _dump(packet)) + -- print('sending packet: ', _dump(packet)) - -- print('sending packet... of size ' .. #packet) + -- print('sending packet... of size ' .. #packet) - return sock:send(packet) + return sock:send(packet) end --static, auto-growing buffer allocation pattern (ctype must be vla). +local max, ceil, log = math.max, math.ceil, math.log +local function nextpow2(x) + return max(0, 2^(ceil(log(x) / log(2)))) +end local function grow_buffer(ctype) local vla = ffi.typeof(ctype) local buf, len = nil, -1 @@ -246,7 +247,7 @@ local function grow_buffer(ctype) if minlen == false then buf, len = nil, -1 elseif minlen > len then - len = glue.nextpow2(minlen) + len = nextpow2(minlen) buf = vla(len) end return buf, len @@ -254,191 +255,190 @@ local function grow_buffer(ctype) end local function _recv(self, sz) - local buf = self.buf - if not buf then - buf = grow_buffer'char[?]' - self.buf = buf - end - local buf = buf(sz) - local sock = self.sock - local offset = 0 - while sz > 0 do - local n, err = sock:recv(buf + offset, sz) - if not n then return nil, err end - sz = sz - n - offset = offset + n - end - return ffi.string(buf, offset) + local buf = self.buf + if not buf then + buf = grow_buffer'char[?]' + self.buf = buf + end + local buf = buf(sz) + local sock = self.sock + local offset = 0 + while sz > 0 do + local n, err = sock:recv(buf + offset, sz) + if not n then return nil, err end + sz = sz - n + offset = offset + n + end + return ffi.string(buf, offset) end - local function _recv_packet(self) - local sock = self.sock + local sock = self.sock - local data, err = _recv(self, 4) -- packet header - if not data then - return nil, nil, 'failed to receive packet header: ' .. err - end + local data, err = _recv(self, 4) -- packet header + if not data then + return nil, nil, 'failed to receive packet header: ' .. err + end - --print('packet header: ', _dump(data)) + --print('packet header: ', _dump(data)) - local len, pos = _get_byte3(data, 1) + local len, pos = _get_byte3(data, 1) - --print('packet length: ', len) + --print('packet length: ', len) - if len == 0 then - return nil, nil, 'empty packet' - end + if len == 0 then + return nil, nil, 'empty packet' + end - if len > self._max_packet_size then - return nil, nil, 'packet size too big: ' .. len - end + if len > self._max_packet_size then + return nil, nil, 'packet size too big: ' .. len + end - local num = strbyte(data, pos) + local num = strbyte(data, pos) - --print('recv packet: packet no: ', num) + --print('recv packet: packet no: ', num) - self.packet_no = num + self.packet_no = num - data, err = _recv(self, len) + data, err = _recv(self, len) - --print('receive returned') + --print('receive returned') - if not data then - return nil, nil, 'failed to read packet content: ' .. err - end + if not data then + return nil, nil, 'failed to read packet content: ' .. err + end - --print('packet content: ', _dump(data)) - --print('packet content (ascii): ', data) + --print('packet content: ', _dump(data)) + --print('packet content (ascii): ', data) - local field_count = strbyte(data, 1) + local field_count = strbyte(data, 1) - local typ - if field_count == 0x00 then - typ = 'OK' - elseif field_count == 0xff then - typ = 'ERR' - elseif field_count == 0xfe then - typ = 'EOF' - else - typ = 'DATA' - end + local typ + if field_count == 0x00 then + typ = 'OK' + elseif field_count == 0xff then + typ = 'ERR' + elseif field_count == 0xfe then + typ = 'EOF' + else + typ = 'DATA' + end - return data, typ + return data, typ end local function _from_length_coded_bin(data, pos) - local first = strbyte(data, pos) + local first = strbyte(data, pos) - --print('LCB: first: ', first) + --print('LCB: first: ', first) - if not first then - return nil, pos - end + if not first then + return nil, pos + end - if first >= 0 and first <= 250 then - return first, pos + 1 - end + if first >= 0 and first <= 250 then + return first, pos + 1 + end - if first == 251 then - return null, pos + 1 - end + if first == 251 then + return null, pos + 1 + end - if first == 252 then - pos = pos + 1 - return _get_byte2(data, pos) - end + if first == 252 then + pos = pos + 1 + return _get_byte2(data, pos) + end - if first == 253 then - pos = pos + 1 - return _get_byte3(data, pos) - end + if first == 253 then + pos = pos + 1 + return _get_byte3(data, pos) + end - if first == 254 then - pos = pos + 1 - return _get_byte8(data, pos) - end + if first == 254 then + pos = pos + 1 + return _get_byte8(data, pos) + end - return nil, pos + 1 + return nil, pos + 1 end local function _from_length_coded_str(data, pos) - local len - len, pos = _from_length_coded_bin(data, pos) - if not len or len == null then - return null, pos - end + local len + len, pos = _from_length_coded_bin(data, pos) + if not len or len == null then + return null, pos + end - return sub(data, pos, pos + len - 1), pos + len + return sub(data, pos, pos + len - 1), pos + len end local function _parse_ok_packet(packet) - local res = new_tab(0, 5) - local pos + local res = new_tab(0, 5) + local pos - res.affected_rows, pos = _from_length_coded_bin(packet, 2) + res.affected_rows, pos = _from_length_coded_bin(packet, 2) - --print('affected rows: ', res.affected_rows, ', pos:', pos) + --print('affected rows: ', res.affected_rows, ', pos:', pos) - res.insert_id, pos = _from_length_coded_bin(packet, pos) + res.insert_id, pos = _from_length_coded_bin(packet, pos) - --print('insert id: ', res.insert_id, ', pos:', pos) + --print('insert id: ', res.insert_id, ', pos:', pos) - res.server_status, pos = _get_byte2(packet, pos) + res.server_status, pos = _get_byte2(packet, pos) - --print('server status: ', res.server_status, ', pos:', pos) + --print('server status: ', res.server_status, ', pos:', pos) - res.warning_count, pos = _get_byte2(packet, pos) + res.warning_count, pos = _get_byte2(packet, pos) - --print('warning count: ', res.warning_count, ', pos: ', pos) + --print('warning count: ', res.warning_count, ', pos: ', pos) - local message = _from_length_coded_str(packet, pos) - if message and message ~= null then - res.message = message - end + local message = _from_length_coded_str(packet, pos) + if message and message ~= null then + res.message = message + end - --print('message: ', res.message, ', pos:', pos) + --print('message: ', res.message, ', pos:', pos) - return res + return res end local function _parse_eof_packet(packet) - local pos = 2 + local pos = 2 - local warning_count, pos = _get_byte2(packet, pos) - local status_flags = _get_byte2(packet, pos) + local warning_count, pos = _get_byte2(packet, pos) + local status_flags = _get_byte2(packet, pos) - return warning_count, status_flags + return warning_count, status_flags end local function _parse_err_packet(packet) - local errno, pos = _get_byte2(packet, 2) - local marker = sub(packet, pos, pos) - local sqlstate - if marker == '#' then - -- with sqlstate - pos = pos + 1 - sqlstate = sub(packet, pos, pos + 5 - 1) - pos = pos + 5 - end + local errno, pos = _get_byte2(packet, 2) + local marker = sub(packet, pos, pos) + local sqlstate + if marker == '#' then + -- with sqlstate + pos = pos + 1 + sqlstate = sub(packet, pos, pos + 5 - 1) + pos = pos + 5 + end - local message = sub(packet, pos) - return errno, message, sqlstate + local message = sub(packet, pos) + return errno, message, sqlstate end local function _parse_result_set_header_packet(packet) - local field_count, pos = _from_length_coded_bin(packet, 1) + local field_count, pos = _from_length_coded_bin(packet, 1) - local extra - extra = _from_length_coded_bin(packet, pos) + local extra + extra = _from_length_coded_bin(packet, pos) - return field_count, extra + return field_count, extra end local NOT_NULL_FLAG = 1 @@ -448,467 +448,441 @@ local UNSIGNED_FLAG = 32 local AUTO_INCREMENT_FLAG = 512 local function _parse_field_packet(data) - local col = new_tab(0, 2) - local pos - col.catalog, pos = _from_length_coded_str(data, 1) - col.db, pos = _from_length_coded_str(data, pos) - col.table, pos = _from_length_coded_str(data, pos) - col.orig_table, pos = _from_length_coded_str(data, pos) - col.name, pos = _from_length_coded_str(data, pos) - col.orig_name, pos = _from_length_coded_str(data, pos) - pos = pos + 1 -- ignore the filler - col.charsetnr, pos = _get_byte2(data, pos) - col.length, pos = _get_byte4(data, pos) - col.type = strbyte(data, pos) - pos = pos + 1 - col.flags, pos = _get_byte2(data, pos) - col.allow_null = band(col.flags, NOT_NULL_FLAG) == 0 - col.pri_key = band(col.flags, PRI_KEY_FLAG) ~= 0 - col.unique_key = band(col.flags, UNIQUE_KEY_FLAG) ~= 0 - col.unsigned = band(col.flags, UNSIGNED_FLAG) ~= 0 - col.auto_increment = band(col.flags, AUTO_INCREMENT_FLAG) ~= 0 - col.decimals = strbyte(data, pos) - pos = pos + 1 - local default = sub(data, pos + 2) - if default and default ~= '' then - col.default = default - end - return col + local col = new_tab(0, 2) + local pos + col.catalog, pos = _from_length_coded_str(data, 1) + col.db, pos = _from_length_coded_str(data, pos) + col.table, pos = _from_length_coded_str(data, pos) + col.orig_table, pos = _from_length_coded_str(data, pos) + col.name, pos = _from_length_coded_str(data, pos) + col.orig_name, pos = _from_length_coded_str(data, pos) + pos = pos + 1 -- ignore the filler + col.charsetnr, pos = _get_byte2(data, pos) + col.length, pos = _get_byte4(data, pos) + col.type = strbyte(data, pos) + pos = pos + 1 + col.flags, pos = _get_byte2(data, pos) + col.allow_null = band(col.flags, NOT_NULL_FLAG) == 0 + col.pri_key = band(col.flags, PRI_KEY_FLAG) ~= 0 + col.unique_key = band(col.flags, UNIQUE_KEY_FLAG) ~= 0 + col.unsigned = band(col.flags, UNSIGNED_FLAG) ~= 0 + col.auto_increment = band(col.flags, AUTO_INCREMENT_FLAG) ~= 0 + col.decimals = strbyte(data, pos) + pos = pos + 1 + local default = sub(data, pos + 2) + if default and default ~= '' then + col.default = default + end + return col end local function _parse_row_data_packet(data, cols, compact) - local pos = 1 - local ncols = #cols - local row - if compact then - row = new_tab(ncols, 0) - else - row = new_tab(0, ncols) - end - for i = 1, ncols do - local value - value, pos = _from_length_coded_str(data, pos) - local col = cols[i] - local typ = col.type - local name = col.name + local pos = 1 + local ncols = #cols + local row + if compact then + row = new_tab(ncols, 0) + else + row = new_tab(0, ncols) + end + for i = 1, ncols do + local value + value, pos = _from_length_coded_str(data, pos) + local col = cols[i] + local typ = col.type + local name = col.name - --print('row field value: ', value, ', type: ', typ) + --print('row field value: ', value, ', type: ', typ) - if value ~= null then - local conv = converters[typ] - if conv then - value = conv(value) - end - end + if value ~= null then + local conv = converters[typ] + if conv then + value = conv(value) + end + end - if compact then - row[i] = value - elseif value ~= null then - row[name] = value - end - end + if compact then + row[i] = value + elseif value ~= null then + row[name] = value + end + end - return row + return row end local function _recv_field_packet(self) - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end - if typ == 'ERR' then - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end + if typ == 'ERR' then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end - if typ ~= 'DATA' then - return nil, 'bad field packet type: ' .. typ - end + if typ ~= 'DATA' then + return nil, 'bad field packet type: ' .. typ + end - -- typ == 'DATA' + -- typ == 'DATA' - return _parse_field_packet(packet) + return _parse_field_packet(packet) end -function _M.new(self) - local sock, err = tcp() - if not sock then - return nil, err - end - return setmetatable({ sock = sock }, mt) +function mysql.new(self, opt) + local tcp = opt and opt.tcp or require'sock'.tcp + local sock, err = tcp() + if not sock then + return nil, err + end + return setmetatable({ sock = sock }, mt) end -function _M.connect(self, opts) - local sock = self.sock - if not sock then - return nil, 'not initialized' - end +function conn:connect(opts) + local sock = self.sock - local max_packet_size = opts.max_packet_size - if not max_packet_size then - max_packet_size = 1024 * 1024 -- default 1 MB - end - self._max_packet_size = max_packet_size + local max_packet_size = opts.max_packet_size + if not max_packet_size then + max_packet_size = 1024 * 1024 -- default 1 MB + end + self._max_packet_size = max_packet_size - local ok, err + local ok, err - self.compact = opts.compact_arrays + local database = opts.database or '' + local user = opts.user or '' - local database = opts.database or '' - local user = opts.user or '' + local charset = CHARSET_MAP[opts.charset or '_default'] + if not charset then + return nil, 'charset \'' .. opts.charset .. '\' is not supported' + end - local charset = CHARSET_MAP[opts.charset or '_default'] - if not charset then - return nil, 'charset \'' .. opts.charset .. '\' is not supported' - end + local host = opts.host + local port = opts.port or 3306 + ok, err = sock:connect(host, port) - local host = opts.host - local port = opts.port or 3306 - ok, err = sock:connect(host, port) + if not ok then + return nil, 'failed to connect: ' .. err + end - if not ok then - return nil, 'failed to connect: ' .. err - end + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end + if typ == 'ERR' then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end - if typ == 'ERR' then - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end + self.protocol_ver = strbyte(packet) - self.protocol_ver = strbyte(packet) + --print('protocol version: ', self.protocol_ver) - --print('protocol version: ', self.protocol_ver) + local server_ver, pos = _from_cstring(packet, 2) + if not server_ver then + return nil, 'bad handshake initialization packet: bad server version' + end - local server_ver, pos = _from_cstring(packet, 2) - if not server_ver then - return nil, 'bad handshake initialization packet: bad server version' - end + --print('server version: ', server_ver) - --print('server version: ', server_ver) + self._server_ver = server_ver - self._server_ver = server_ver + local thread_id, pos = _get_byte4(packet, pos) - local thread_id, pos = _get_byte4(packet, pos) + --print('thread id: ', thread_id) - --print('thread id: ', thread_id) + local scramble = sub(packet, pos, pos + 8 - 1) + if not scramble then + return nil, '1st part of scramble not found' + end - local scramble = sub(packet, pos, pos + 8 - 1) - if not scramble then - return nil, '1st part of scramble not found' - end + pos = pos + 9 -- skip filler - pos = pos + 9 -- skip filler + -- two lower bytes + local capabilities -- server capabilities + capabilities, pos = _get_byte2(packet, pos) - -- two lower bytes - local capabilities -- server capabilities - capabilities, pos = _get_byte2(packet, pos) + -- print(format('server capabilities: %#x', capabilities)) - -- print(format('server capabilities: %#x', capabilities)) + self._server_lang = strbyte(packet, pos) + pos = pos + 1 - self._server_lang = strbyte(packet, pos) - pos = pos + 1 + --print('server lang: ', self._server_lang) - --print('server lang: ', self._server_lang) + self._server_status, pos = _get_byte2(packet, pos) - self._server_status, pos = _get_byte2(packet, pos) + --print('server status: ', self._server_status) - --print('server status: ', self._server_status) + local more_capabilities + more_capabilities, pos = _get_byte2(packet, pos) - local more_capabilities - more_capabilities, pos = _get_byte2(packet, pos) + capabilities = bor(capabilities, lshift(more_capabilities, 16)) - capabilities = bor(capabilities, lshift(more_capabilities, 16)) + --print('server capabilities: ', capabilities) - --print('server capabilities: ', capabilities) + -- local len = strbyte(packet, pos) + local len = 21 - 8 - 1 - -- local len = strbyte(packet, pos) - local len = 21 - 8 - 1 + --print('scramble len: ', len) - --print('scramble len: ', len) + pos = pos + 1 + 10 - pos = pos + 1 + 10 + local scramble_part2 = sub(packet, pos, pos + len - 1) + if not scramble_part2 then + return nil, '2nd part of scramble not found' + end - local scramble_part2 = sub(packet, pos, pos + len - 1) - if not scramble_part2 then - return nil, '2nd part of scramble not found' - end + scramble = scramble .. scramble_part2 + --print('scramble: ', _dump(scramble)) - scramble = scramble .. scramble_part2 - --print('scramble: ', _dump(scramble)) + local client_flags = 0x3f7cf; - local client_flags = 0x3f7cf; + local ssl_verify = opts.ssl_verify + local use_ssl = opts.ssl or ssl_verify - local ssl_verify = opts.ssl_verify - local use_ssl = opts.ssl or ssl_verify + if use_ssl then + if band(capabilities, CLIENT_SSL) == 0 then + return nil, 'ssl disabled on server' + end - if use_ssl then - if band(capabilities, CLIENT_SSL) == 0 then - return nil, 'ssl disabled on server' - end + -- send a SSL Request Packet + local req = _set_byte4(bor(client_flags, CLIENT_SSL)) + .. _set_byte4(self._max_packet_size) + .. strchar(charset) + .. strrep('\0', 23) - -- send a SSL Request Packet - local req = _set_byte4(bor(client_flags, CLIENT_SSL)) - .. _set_byte4(self._max_packet_size) - .. strchar(charset) - .. strrep('\0', 23) + local packet_len = 4 + 4 + 1 + 23 + local bytes, err = _send_packet(self, req, packet_len) + if not bytes then + return nil, 'failed to send client authentication packet: ' .. err + end - local packet_len = 4 + 4 + 1 + 23 - local bytes, err = _send_packet(self, req, packet_len) - if not bytes then - return nil, 'failed to send client authentication packet: ' .. err - end + local ok, err = sock:sslhandshake(false, nil, ssl_verify) + if not ok then + return nil, 'failed to do ssl handshake: ' .. (err or '') + end + end - local ok, err = sock:sslhandshake(false, nil, ssl_verify) - if not ok then - return nil, 'failed to do ssl handshake: ' .. (err or '') - end - end + local password = opts.password or '' - local password = opts.password or '' + local token = _compute_token(password, scramble) - local token = _compute_token(password, scramble) + --print('token: ', _dump(token)) - --print('token: ', _dump(token)) + local req = _set_byte4(client_flags) + .. _set_byte4(self._max_packet_size) + .. strchar(charset) + .. strrep('\0', 23) + .. _to_cstring(user) + .. _to_binary_coded_string(token) + .. _to_cstring(database) - local req = _set_byte4(client_flags) - .. _set_byte4(self._max_packet_size) - .. strchar(charset) - .. strrep('\0', 23) - .. _to_cstring(user) - .. _to_binary_coded_string(token) - .. _to_cstring(database) + local packet_len = 4 + 4 + 1 + 23 + #user + 1 + + #token + 1 + #database + 1 - local packet_len = 4 + 4 + 1 + 23 + #user + 1 - + #token + 1 + #database + 1 + -- print('packet content length: ', packet_len) + -- print('packet content: ', _dump(concat(req, ''))) - -- print('packet content length: ', packet_len) - -- print('packet content: ', _dump(concat(req, ''))) + local bytes, err = _send_packet(self, req, packet_len) + if not bytes then + return nil, 'failed to send client authentication packet: ' .. err + end - local bytes, err = _send_packet(self, req, packet_len) - if not bytes then - return nil, 'failed to send client authentication packet: ' .. err - end + --print('packet sent ', bytes, ' bytes') - --print('packet sent ', bytes, ' bytes') + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, 'failed to receive the result packet: ' .. err + end - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, 'failed to receive the result packet: ' .. err - end + if typ == 'ERR' then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end - if typ == 'ERR' then - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end + if typ == 'EOF' then + return nil, 'old pre-4.1 authentication protocol not supported' + end - if typ == 'EOF' then - return nil, 'old pre-4.1 authentication protocol not supported' - end + if typ ~= 'OK' then + return nil, 'bad packet type: ' .. typ + end - if typ ~= 'OK' then - return nil, 'bad packet type: ' .. typ - end + self.state = STATE_CONNECTED - self.state = STATE_CONNECTED - - return 1 + return 1 end -function _M.close(self) - local sock = self.sock - if not sock then - return nil, 'not initialized' - end +function conn:close() + local sock = assert(self.sock) - self.state = nil + self.state = nil - local bytes, err = _send_packet(self, strchar(COM_QUIT), 1) - if not bytes then - return nil, err - end + local bytes, err = _send_packet(self, strchar(COM_QUIT), 1) + if not bytes then + return nil, err + end - return sock:close() + return sock:close() end -function _M.server_ver(self) - return self._server_ver +function conn:server_ver() + return self._server_ver end -local function send_query(self, query) - if self.state ~= STATE_CONNECTED then - return nil, 'cannot send query in the current context: ' - .. (self.state or 'nil') - end +function conn:send_query(query) + assert(self.state == STATE_CONNECTED) + local sock = assert(self.sock) - local sock = self.sock - if not sock then - return nil, 'not initialized' - end + self.packet_no = -1 - self.packet_no = -1 + local cmd_packet = strchar(COM_QUERY) .. query + local packet_len = 1 + #query - local cmd_packet = strchar(COM_QUERY) .. query - local packet_len = 1 + #query + local bytes, err = _send_packet(self, cmd_packet, packet_len) + if not bytes then + return nil, err + end - local bytes, err = _send_packet(self, cmd_packet, packet_len) - if not bytes then - return nil, err - end + self.state = STATE_COMMAND_SENT - self.state = STATE_COMMAND_SENT + --print('packet sent ', bytes, ' bytes') - --print('packet sent ', bytes, ' bytes') - - return bytes -end -_M.send_query = send_query - -local function read_result(self, est_nrows) - if self.state ~= STATE_COMMAND_SENT then - return nil, 'cannot read result in the current context: ' - .. (self.state or 'nil') - end - - local sock = self.sock - if not sock then - return nil, 'not initialized' - end - - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end - - if typ == 'ERR' then - self.state = STATE_CONNECTED - - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end - - if typ == 'OK' then - local res = _parse_ok_packet(packet) - if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then - return res, 'again' - end - - self.state = STATE_CONNECTED - return res - end - - if typ ~= 'DATA' then - self.state = STATE_CONNECTED - - return nil, 'packet type ' .. typ .. ' not supported' - end - - -- typ == 'DATA' - - --print('read the result set header packet') - - local field_count, extra = _parse_result_set_header_packet(packet) - - --print('field count: ', field_count) - - local cols = new_tab(field_count, 0) - for i = 1, field_count do - local col, err, errno, sqlstate = _recv_field_packet(self) - if not col then - return nil, err, errno, sqlstate - end - - cols[i] = col - end - - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end - - if typ ~= 'EOF' then - return nil, 'unexpected packet type ' .. typ .. ' while eof packet is ' - .. 'expected' - end - - -- typ == 'EOF' - - local compact = self.compact - - local rows = new_tab(est_nrows or 4, 0) - local i = 0 - while true do - --print('reading a row') - - packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end - - if typ == 'EOF' then - local warning_count, status_flags = _parse_eof_packet(packet) - - --print('status flags: ', status_flags) - - if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then - return rows, 'again', cols - end - - break - end - - -- if typ ~= 'DATA' then - -- return nil, 'bad row packet type: ' .. typ - -- end - - -- typ == 'DATA' - - local row = _parse_row_data_packet(packet, cols, compact) - i = i + 1 - rows[i] = row - end - - self.state = STATE_CONNECTED - - return rows, nil, cols -end -_M.read_result = read_result - -function _M.query(self, query, est_nrows) - local bytes, err = send_query(self, query) - if not bytes then - return nil, 'failed to send query: ' .. err - end - - return read_result(self, est_nrows) + return bytes end -function _M.set_compact_arrays(self, value) - self.compact = value +function conn:read_result(est_nrows, compact) + assert(self.state == STATE_COMMAND_SENT) + local sock = assert(self.sock) + + compact = compact == 'compact' + if est_nrows == 'compact' then + est_nrows = null + compact = true + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == 'ERR' then + self.state = STATE_CONNECTED + + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ == 'OK' then + local res = _parse_ok_packet(packet) + if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return res, 'again' + end + + self.state = STATE_CONNECTED + return res + end + + if typ ~= 'DATA' then + self.state = STATE_CONNECTED + + return nil, 'packet type ' .. typ .. ' not supported' + end + + -- typ == 'DATA' + + --print('read the result set header packet') + + local field_count, extra = _parse_result_set_header_packet(packet) + + --print('field count: ', field_count) + + local cols = new_tab(field_count, 0) + for i = 1, field_count do + local col, err, errno, sqlstate = _recv_field_packet(self) + if not col then + return nil, err, errno, sqlstate + end + + cols[i] = col + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ ~= 'EOF' then + return nil, 'unexpected packet type ' .. typ .. ' while eof packet is ' + .. 'expected' + end + + -- typ == 'EOF' + + local rows = new_tab(est_nrows or 4, 0) + local i = 0 + while true do + --print('reading a row') + + packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == 'EOF' then + local warning_count, status_flags = _parse_eof_packet(packet) + + --print('status flags: ', status_flags) + + if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return rows, 'again', cols + end + + break + end + + -- if typ ~= 'DATA' then + -- return nil, 'bad row packet type: ' .. typ + -- end + + -- typ == 'DATA' + + local row = _parse_row_data_packet(packet, cols, compact) + i = i + 1 + rows[i] = row + end + + self.state = STATE_CONNECTED + + return rows, nil, cols +end + +function conn:query(query, est_nrows) + local bytes, err, errcode = self:send_query(query) + if not bytes then return nil, err, errcode end + return self:read_result(est_nrows) end local qmap = { - ['\0' ] = '\\0', - ['\b' ] = '\\b', - ['\n' ] = '\\n', - ['\r' ] = '\\r', - ['\t' ] = '\\t', - ['\26'] = '\\Z', - ['\\' ] = '\\\\', - ['\'' ] = '\\\'', - ['\"' ] = '\\"', + ['\0' ] = '\\0', + ['\b' ] = '\\b', + ['\n' ] = '\\n', + ['\r' ] = '\\r', + ['\t' ] = '\\t', + ['\26'] = '\\Z', + ['\\' ] = '\\\\', + ['\'' ] = '\\\'', + ['\"' ] = '\\"', } -function _M.quote(s) - return s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap) +function mysql.quote(s) + return s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap) end -return _M +return mysql diff --git a/mysql_client.md b/mysql_client.md index 877d343..af03dee 100644 --- a/mysql_client.md +++ b/mysql_client.md @@ -72,8 +72,6 @@ The `options` argument is a Lua table holding the following keys: If the server does not have SSL support (or just disabled), the error string "ssl disabled on server" will be returned. * `ssl_verify`: if `true`, then verifies the validity of the server SSL certificate (default to `false`). - * `compact_arrays`: `true` to use array-of-arrays structure for the result set, - rather than the default array-of-hashes structure. ### `db:close() -> 1 | nil,err` @@ -85,7 +83,7 @@ Sends the query to the remote MySQL server without waiting for its replies. Returns the bytes successfully sent out. Use `read_result()` to read the replies. -### `db:read_result([nrows]) -> res | nil,err,errcode,sqlstate` +### `db:read_result([nrows,]['compact']) -> res | nil,err,errcode,sqlstate` Reads in one result returned from the server. @@ -97,8 +95,17 @@ Each row holds key-value pairs for each data fields. For instance, ```lua { - { name = "Bob", age = 32, phone = ngx.null }, - { name = "Marry", age = 18, phone = "10666372"} + { name = "Bob", age = 32, phone = mysql.null }, + { name = "Marry", age = 18, phone = "10666372" } + } +``` + +If `'compact'` given, it returns an array-of-arrays instead: + +```lua + { + { "Bob", 32, null }, + { "Marry", 18, "10666372" } } ``` @@ -148,12 +155,6 @@ Returns the MySQL server version string, like `"5.1.64"`. You should only call this method after successfully connecting to a MySQL server, otherwise `nil` will be returned. -### `db:set_compact_arrays(true|false)` - -Sets whether to use the "compact-arrays" structure for the resultsets returned -by subsequent queries. See the `compact_arrays` option for the `connect` -method for more details. - ### `mysql.quote(s) -> s` Quote literal string to be used in queries.