commit 89969e280953870e1a04d8a9985b32cd5c256006 Author: Cosmin Apreutesei Date: Thu Apr 29 22:24:55 2021 +0300 unimportant diff --git a/mysql_client.lua b/mysql_client.lua new file mode 100644 index 0000000..0f63969 --- /dev/null +++ b/mysql_client.lua @@ -0,0 +1,948 @@ +-- Copyright (C) Yichun Zhang (agentzh) + +local tcp = require'sock'.tcp +local sha1 = require'sha1' +local bit = require'bit' + +local sub = string.sub +local strbyte = string.byte +local strchar = string.char +local strfind = string.find +local format = string.format +local strrep = string.rep +local null = function() end +local band = bit.band +local bxor = bit.bxor +local bor = bit.bor +local lshift = bit.lshift +local rshift = bit.rshift +local tohex = bit.tohex +local concat = table.concat +local unpack = unpack +local setmetatable = setmetatable +local error = error +local tonumber = tonumber + +local ok, new_tab = pcall(require, 'table.new') +if not ok then + new_tab = function (narr, nrec) return {} end +end + +local _M = { _VERSION = '0.21' } + + +-- constants + +local STATE_CONNECTED = 1 +local STATE_COMMAND_SENT = 2 + +local COM_QUIT = 0x01 +local COM_QUERY = 0x03 +local CLIENT_SSL = 0x0800 + +local SERVER_MORE_RESULTS_EXISTS = 8 + +-- 16MB - 1, the default max allowed packet size used by libmysqlclient +local FULL_PACKET_SIZE = 16777215 + +-- the following charset map is generated from the following mysql query: +-- SELECT CHARACTER_SET_NAME, ID +-- FROM information_schema.collations +-- WHERE IS_DEFAULT = 'Yes' ORDER BY id; +local CHARSET_MAP = { + _default = 0, + big5 = 1, + dec8 = 3, + cp850 = 4, + hp8 = 6, + koi8r = 7, + latin1 = 8, + latin2 = 9, + swe7 = 10, + ascii = 11, + ujis = 12, + sjis = 13, + hebrew = 16, + tis620 = 18, + euckr = 19, + koi8u = 22, + gb2312 = 24, + greek = 25, + cp1250 = 26, + gbk = 28, + latin5 = 30, + armscii8 = 32, + utf8 = 33, + ucs2 = 35, + cp866 = 36, + keybcs2 = 37, + macce = 38, + macroman = 39, + cp852 = 40, + latin7 = 41, + utf8mb4 = 45, + cp1251 = 51, + utf16 = 54, + utf16le = 56, + cp1256 = 57, + cp1257 = 59, + utf32 = 60, + binary = 63, + geostd8 = 92, + cp932 = 95, + eucjpms = 97, + gb18030 = 248 +} + +local mt = { __index = _M } + + +-- mysql field value type converters +local converters = new_tab(0, 9) + +for i = 0x01, 0x05 do + -- tiny, short, long, float, double + converters[i] = tonumber +end +converters[0x00] = tonumber -- decimal +-- converters[0x08] = tonumber -- long long +converters[0x09] = tonumber -- int24 +converters[0x0d] = tonumber -- year +converters[0xf6] = tonumber -- newdecimal + + +local function _get_byte2(data, i) + local a, b = strbyte(data, i, i + 1) + return bor(a, lshift(b, 8)), i + 2 +end + + +local function _get_byte3(data, i) + local a, b, c = strbyte(data, i, i + 2) + return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 +end + + +local function _get_byte4(data, i) + local a, b, c, d = strbyte(data, i, i + 3) + return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 +end + + +local function _get_byte8(data, i) + local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) + + -- XXX workaround for the lack of 64-bit support in bitop: + local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)) + local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) + return lo + hi * 4294967296, i + 8 + + -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), + -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 +end + + +local function _set_byte2(n) + return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) +end + + +local function _set_byte3(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff)) +end + + +local function _set_byte4(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff), + band(rshift(n, 24), 0xff)) +end + + +local function _from_cstring(data, i) + local last = strfind(data, "\0", i, true) + if not last then + return nil, nil + end + + return sub(data, i, last - 1), last + 1 +end + + +local function _to_cstring(data) + return data .. "\0" +end + + +local function _to_binary_coded_string(data) + return strchar(#data) .. data +end + + +local function _dump(data) + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = format("%x", strbyte(data, i)) + end + return concat(bytes, " ") +end + + +local function _dumphex(data) + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = tohex(strbyte(data, i), 2) + end + return concat(bytes, " ") +end + + +local function _compute_token(password, scramble) + if password == "" then + return "" + end + + local stage1 = sha1(password) + local stage2 = sha1(stage1) + local stage3 = sha1(scramble .. stage2) + local n = #stage1 + local bytes = new_tab(n, 0) + for i = 1, n do + bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) + end + + return concat(bytes) +end + + +local function _send_packet(self, req, size) + local sock = self.sock + + self.packet_no = self.packet_no + 1 + + -- print("packet no: ", self.packet_no) + + local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req + + -- print("sending packet: ", _dump(packet)) + + -- print("sending packet... of size " .. #packet) + + return sock:send(packet) +end + + +local function _recv_packet(self) + local sock = self.sock + + local data, err = sock:receive(4) -- packet header + if not data then + return nil, nil, "failed to receive packet header: " .. err + end + + --print("packet header: ", _dump(data)) + + local len, pos = _get_byte3(data, 1) + + --print("packet length: ", len) + + if len == 0 then + return nil, nil, "empty packet" + end + + if len > self._max_packet_size then + return nil, nil, "packet size too big: " .. len + end + + local num = strbyte(data, pos) + + --print("recv packet: packet no: ", num) + + self.packet_no = num + + data, err = sock:receive(len) + + --print("receive returned") + + if not data then + return nil, nil, "failed to read packet content: " .. err + end + + --print("packet content: ", _dump(data)) + --print("packet content (ascii): ", data) + + local field_count = strbyte(data, 1) + + local typ + if field_count == 0x00 then + typ = "OK" + elseif field_count == 0xff then + typ = "ERR" + elseif field_count == 0xfe then + typ = "EOF" + else + typ = "DATA" + end + + return data, typ +end + + +local function _from_length_coded_bin(data, pos) + local first = strbyte(data, pos) + + --print("LCB: first: ", first) + + if not first then + return nil, pos + end + + if first >= 0 and first <= 250 then + return first, pos + 1 + end + + if first == 251 then + return null, pos + 1 + end + + if first == 252 then + pos = pos + 1 + return _get_byte2(data, pos) + end + + if first == 253 then + pos = pos + 1 + return _get_byte3(data, pos) + end + + if first == 254 then + pos = pos + 1 + return _get_byte8(data, pos) + end + + return nil, pos + 1 +end + + +local function _from_length_coded_str(data, pos) + local len + len, pos = _from_length_coded_bin(data, pos) + if not len or len == null then + return null, pos + end + + return sub(data, pos, pos + len - 1), pos + len +end + + +local function _parse_ok_packet(packet) + local res = new_tab(0, 5) + local pos + + res.affected_rows, pos = _from_length_coded_bin(packet, 2) + + --print("affected rows: ", res.affected_rows, ", pos:", pos) + + res.insert_id, pos = _from_length_coded_bin(packet, pos) + + --print("insert id: ", res.insert_id, ", pos:", pos) + + res.server_status, pos = _get_byte2(packet, pos) + + --print("server status: ", res.server_status, ", pos:", pos) + + res.warning_count, pos = _get_byte2(packet, pos) + + --print("warning count: ", res.warning_count, ", pos: ", pos) + + local message = _from_length_coded_str(packet, pos) + if message and message ~= null then + res.message = message + end + + --print("message: ", res.message, ", pos:", pos) + + return res +end + + +local function _parse_eof_packet(packet) + local pos = 2 + + local warning_count, pos = _get_byte2(packet, pos) + local status_flags = _get_byte2(packet, pos) + + return warning_count, status_flags +end + + +local function _parse_err_packet(packet) + local errno, pos = _get_byte2(packet, 2) + local marker = sub(packet, pos, pos) + local sqlstate + if marker == '#' then + -- with sqlstate + pos = pos + 1 + sqlstate = sub(packet, pos, pos + 5 - 1) + pos = pos + 5 + end + + local message = sub(packet, pos) + return errno, message, sqlstate +end + + +local function _parse_result_set_header_packet(packet) + local field_count, pos = _from_length_coded_bin(packet, 1) + + local extra + extra = _from_length_coded_bin(packet, pos) + + return field_count, extra +end + +local NOT_NULL_FLAG = 1 +local PRI_KEY_FLAG = 2 +local UNIQUE_KEY_FLAG = 4 +local UNSIGNED_FLAG = 32 +local AUTO_INCREMENT_FLAG = 512 + +local function _parse_field_packet(data) + local col = new_tab(0, 2) + local pos + col.catalog, pos = _from_length_coded_str(data, 1) + col.db, pos = _from_length_coded_str(data, pos) + col.table, pos = _from_length_coded_str(data, pos) + col.orig_table, pos = _from_length_coded_str(data, pos) + col.name, pos = _from_length_coded_str(data, pos) + col.orig_name, pos = _from_length_coded_str(data, pos) + pos = pos + 1 -- ignore the filler + col.charsetnr, pos = _get_byte2(data, pos) + col.length, pos = _get_byte4(data, pos) + col.type = strbyte(data, pos) + pos = pos + 1 + col.flags, pos = _get_byte2(data, pos) + col.allow_null = band(col.flags, NOT_NULL_FLAG) == 0 + col.pri_key = band(col.flags, PRI_KEY_FLAG) ~= 0 + col.unique_key = band(col.flags, UNIQUE_KEY_FLAG) ~= 0 + col.unsigned = band(col.flags, UNSIGNED_FLAG) ~= 0 + col.auto_increment = band(col.flags, AUTO_INCREMENT_FLAG) ~= 0 + col.decimals = strbyte(data, pos) + pos = pos + 1 + local default = sub(data, pos + 2) + if default and default ~= "" then + col.default = default + end + return col +end + + +local function _parse_row_data_packet(data, cols, compact) + local pos = 1 + local ncols = #cols + local row + if compact then + row = new_tab(ncols, 0) + else + row = new_tab(0, ncols) + end + for i = 1, ncols do + local value + value, pos = _from_length_coded_str(data, pos) + local col = cols[i] + local typ = col.type + local name = col.name + + --print("row field value: ", value, ", type: ", typ) + + if value ~= null then + local conv = converters[typ] + if conv then + value = conv(value) + end + end + + if compact then + row[i] = value + elseif value ~= null then + row[name] = value + end + end + + return row +end + + +local function _recv_field_packet(self) + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ ~= 'DATA' then + return nil, "bad field packet type: " .. typ + end + + -- typ == 'DATA' + + return _parse_field_packet(packet) +end + + +function _M.new(self) + local sock, err = tcp() + if not sock then + return nil, err + end + return setmetatable({ sock = sock }, mt) +end + + +function _M.set_timeout(self, timeout) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:settimeout(timeout) +end + + +function _M.connect(self, opts) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local max_packet_size = opts.max_packet_size + if not max_packet_size then + max_packet_size = 1024 * 1024 -- default 1 MB + end + self._max_packet_size = max_packet_size + + local ok, err + + self.compact = opts.compact_arrays + + local database = opts.database or "" + local user = opts.user or "" + + local charset = CHARSET_MAP[opts.charset or "_default"] + if not charset then + return nil, "charset '" .. opts.charset .. "' is not supported" + end + + local pool = opts.pool + + local host = opts.host + if host then + local port = opts.port or 3306 + if not pool then + pool = user .. ":" .. database .. ":" .. host .. ":" .. port + end + + ok, err = sock:connect(host, port, { pool = pool }) + + else + local path = opts.path + if not path then + return nil, 'neither "host" nor "path" options are specified' + end + + if not pool then + pool = user .. ":" .. database .. ":" .. path + end + + ok, err = sock:connect("unix:" .. path, { pool = pool }) + end + + if not ok then + return nil, 'failed to connect: ' .. err + end + + local reused = sock:getreusedtimes() + + if reused and reused > 0 then + self.state = STATE_CONNECTED + return 1 + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == "ERR" then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + self.protocol_ver = strbyte(packet) + + --print("protocol version: ", self.protocol_ver) + + local server_ver, pos = _from_cstring(packet, 2) + if not server_ver then + return nil, "bad handshake initialization packet: bad server version" + end + + --print("server version: ", server_ver) + + self._server_ver = server_ver + + local thread_id, pos = _get_byte4(packet, pos) + + --print("thread id: ", thread_id) + + local scramble = sub(packet, pos, pos + 8 - 1) + if not scramble then + return nil, "1st part of scramble not found" + end + + pos = pos + 9 -- skip filler + + -- two lower bytes + local capabilities -- server capabilities + capabilities, pos = _get_byte2(packet, pos) + + -- print(format("server capabilities: %#x", capabilities)) + + self._server_lang = strbyte(packet, pos) + pos = pos + 1 + + --print("server lang: ", self._server_lang) + + self._server_status, pos = _get_byte2(packet, pos) + + --print("server status: ", self._server_status) + + local more_capabilities + more_capabilities, pos = _get_byte2(packet, pos) + + capabilities = bor(capabilities, lshift(more_capabilities, 16)) + + --print("server capabilities: ", capabilities) + + -- local len = strbyte(packet, pos) + local len = 21 - 8 - 1 + + --print("scramble len: ", len) + + pos = pos + 1 + 10 + + local scramble_part2 = sub(packet, pos, pos + len - 1) + if not scramble_part2 then + return nil, "2nd part of scramble not found" + end + + scramble = scramble .. scramble_part2 + --print("scramble: ", _dump(scramble)) + + local client_flags = 0x3f7cf; + + local ssl_verify = opts.ssl_verify + local use_ssl = opts.ssl or ssl_verify + + if use_ssl then + if band(capabilities, CLIENT_SSL) == 0 then + return nil, "ssl disabled on server" + end + + -- send a SSL Request Packet + local req = _set_byte4(bor(client_flags, CLIENT_SSL)) + .. _set_byte4(self._max_packet_size) + .. strchar(charset) + .. strrep("\0", 23) + + local packet_len = 4 + 4 + 1 + 23 + local bytes, err = _send_packet(self, req, packet_len) + if not bytes then + return nil, "failed to send client authentication packet: " .. err + end + + local ok, err = sock:sslhandshake(false, nil, ssl_verify) + if not ok then + return nil, "failed to do ssl handshake: " .. (err or "") + end + end + + local password = opts.password or "" + + local token = _compute_token(password, scramble) + + --print("token: ", _dump(token)) + + local req = _set_byte4(client_flags) + .. _set_byte4(self._max_packet_size) + .. strchar(charset) + .. strrep("\0", 23) + .. _to_cstring(user) + .. _to_binary_coded_string(token) + .. _to_cstring(database) + + local packet_len = 4 + 4 + 1 + 23 + #user + 1 + + #token + 1 + #database + 1 + + -- print("packet content length: ", packet_len) + -- print("packet content: ", _dump(concat(req, ""))) + + local bytes, err = _send_packet(self, req, packet_len) + if not bytes then + return nil, "failed to send client authentication packet: " .. err + end + + --print("packet sent ", bytes, " bytes") + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, "failed to receive the result packet: " .. err + end + + if typ == 'ERR' then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ == 'EOF' then + return nil, "old pre-4.1 authentication protocol not supported" + end + + if typ ~= 'OK' then + return nil, "bad packet type: " .. typ + end + + self.state = STATE_CONNECTED + + return 1 +end + + +function _M.set_keepalive(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + if self.state ~= STATE_CONNECTED then + return nil, "cannot be reused in the current connection state: " + .. (self.state or "nil") + end + + self.state = nil + return sock:setkeepalive(...) +end + + +function _M.get_reused_times(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + + +function _M.close(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + self.state = nil + + local bytes, err = _send_packet(self, strchar(COM_QUIT), 1) + if not bytes then + return nil, err + end + + return sock:close() +end + + +function _M.server_ver(self) + return self._server_ver +end + + +local function send_query(self, query) + if self.state ~= STATE_CONNECTED then + return nil, "cannot send query in the current context: " + .. (self.state or "nil") + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + self.packet_no = -1 + + local cmd_packet = strchar(COM_QUERY) .. query + local packet_len = 1 + #query + + local bytes, err = _send_packet(self, cmd_packet, packet_len) + if not bytes then + return nil, err + end + + self.state = STATE_COMMAND_SENT + + --print("packet sent ", bytes, " bytes") + + return bytes +end +_M.send_query = send_query + + +local function read_result(self, est_nrows) + if self.state ~= STATE_COMMAND_SENT then + return nil, "cannot read result in the current context: " + .. (self.state or "nil") + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == "ERR" then + self.state = STATE_CONNECTED + + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ == 'OK' then + local res = _parse_ok_packet(packet) + if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return res, "again" + end + + self.state = STATE_CONNECTED + return res + end + + if typ ~= 'DATA' then + self.state = STATE_CONNECTED + + return nil, "packet type " .. typ .. " not supported" + end + + -- typ == 'DATA' + + --print("read the result set header packet") + + local field_count, extra = _parse_result_set_header_packet(packet) + + --print("field count: ", field_count) + + local cols = new_tab(field_count, 0) + for i = 1, field_count do + local col, err, errno, sqlstate = _recv_field_packet(self) + if not col then + return nil, err, errno, sqlstate + end + + cols[i] = col + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ ~= 'EOF' then + return nil, "unexpected packet type " .. typ .. " while eof packet is " + .. "expected" + end + + -- typ == 'EOF' + + local compact = self.compact + + local rows = new_tab(est_nrows or 4, 0) + local i = 0 + while true do + --print("reading a row") + + packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == 'EOF' then + local warning_count, status_flags = _parse_eof_packet(packet) + + --print("status flags: ", status_flags) + + if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return rows, "again", cols + end + + break + end + + -- if typ ~= 'DATA' then + -- return nil, 'bad row packet type: ' .. typ + -- end + + -- typ == 'DATA' + + local row = _parse_row_data_packet(packet, cols, compact) + i = i + 1 + rows[i] = row + end + + self.state = STATE_CONNECTED + + return rows, nil, cols +end +_M.read_result = read_result + + +function _M.query(self, query, est_nrows) + local bytes, err = send_query(self, query) + if not bytes then + return nil, "failed to send query: " .. err + end + + return read_result(self, est_nrows) +end + + +function _M.set_compact_arrays(self, value) + self.compact = value +end + +local qmap = { + ['\0' ] = '\\0', + ['\b' ] = '\\b', + ['\n' ] = '\\n', + ['\r' ] = '\\r', + ['\t' ] = '\\t', + ['\26'] = '\\Z', + ['\\' ] = '\\\\', + ['\'' ] = '\\\'', + ['\"' ] = '\\"', +} +function _M.quote(s) + return s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap) +end + +return _M