-- MySQL client protocol in Lua.
-- Written by Yichun Zhang (agentzh). BSD license.
-- Modified by Cosmin Apreutesei. Public domain.

local sha1 = require'sha1'.sha1
local bit = require'bit'
-- ffi.typeof() and ffi.string() are used by the receive buffer below; the
-- ffi module is not a global in LuaJIT, so it must be required explicitly.
local ffi = require'ffi'

local sub = string.sub
local strbyte = string.byte
local strchar = string.char
local strfind = string.find
local format = string.format
local strrep = string.rep
-- Unique sentinel used for SQL NULL values (a fresh function can never be
-- equal to any string or number decoded from the wire).
local null = function() end
local band = bit.band
local bxor = bit.bxor
local bor = bit.bor
local lshift = bit.lshift
local rshift = bit.rshift
local tohex = bit.tohex
local concat = table.concat
local unpack = unpack
local setmetatable = setmetatable
local error = error
local tonumber = tonumber

local ok, new_tab = pcall(require, 'table.new')
new_tab = ok and new_tab or function() return {} end

local mysql = {}

-- Expose the NULL sentinel so callers can recognize NULL columns in
-- compact rows (non-compact rows simply omit NULL fields).
mysql.null = null

-- constants

local STATE_CONNECTED = 1
local STATE_COMMAND_SENT = 2

local COM_QUIT = 0x01
local COM_QUERY = 0x03
local CLIENT_SSL = 0x0800

local SERVER_MORE_RESULTS_EXISTS = 8

-- 16MB - 1, the default max allowed packet size used by libmysqlclient
local FULL_PACKET_SIZE = 16777215

-- the following charset map is generated from the following mysql query:
--   SELECT CHARACTER_SET_NAME, ID
--   FROM information_schema.collations
--   WHERE IS_DEFAULT = 'Yes' ORDER BY id;
local CHARSET_MAP = {
  _default = 0,
  big5 = 1,
  dec8 = 3,
  cp850 = 4,
  hp8 = 6,
  koi8r = 7,
  latin1 = 8,
  latin2 = 9,
  swe7 = 10,
  ascii = 11,
  ujis = 12,
  sjis = 13,
  hebrew = 16,
  tis620 = 18,
  euckr = 19,
  koi8u = 22,
  gb2312 = 24,
  greek = 25,
  cp1250 = 26,
  gbk = 28,
  latin5 = 30,
  armscii8 = 32,
  utf8 = 33,
  ucs2 = 35,
  cp866 = 36,
  keybcs2 = 37,
  macce = 38,
  macroman = 39,
  cp852 = 40,
  latin7 = 41,
  utf8mb4 = 45,
  cp1251 = 51,
  utf16 = 54,
  utf16le = 56,
  cp1256 = 57,
  cp1257 = 59,
  utf32 = 60,
  binary = 63,
  geostd8 = 92,
  cp932 = 95,
  eucjpms = 97,
  gb18030 = 248,
}

local conn = {}
local mt = {__index = conn}

-- mysql field value type converters, keyed by column type code
local converters = new_tab(0, 9)

for i = 0x01, 0x05 do
  -- tiny, short, long, float, double
  converters[i] = tonumber
end
converters[0x00] = tonumber -- decimal
-- converters[0x08] = tonumber -- long long
converters[0x09] = tonumber -- int24
converters[0x0d] = tonumber -- year
converters[0xf6] = tonumber -- newdecimal

-- Little-endian integer readers. Each returns (value, next_pos).

local function _get_byte2(data, i)
  local a, b = strbyte(data, i, i + 1)
  return bor(a, lshift(b, 8)), i + 2
end

local function _get_byte3(data, i)
  local a, b, c = strbyte(data, i, i + 2)
  return bor(a, lshift(b, 8), lshift(c, 16)), i + 3
end

-- Plain arithmetic instead of bit ops: bit.bor/lshift produce signed 32-bit
-- results, so any value with bit 31 set (e.g. 0xffffffff) came out negative.
local function _get_byte4(data, i)
  local a, b, c, d = strbyte(data, i, i + 3)
  return a + b * 0x100 + c * 0x10000 + d * 0x1000000, i + 4
end

-- Unsigned 64-bit read; exact only up to 2^53 since Lua numbers are doubles.
local function _get_byte8(data, i)
  local lo = _get_byte4(data, i)
  local hi = _get_byte4(data, i + 4)
  return lo + hi * 4294967296, i + 8
end

-- Little-endian integer writers.

local function _set_byte2(n)
  return strchar(band(n, 0xff), band(rshift(n, 8), 0xff))
end

local function _set_byte3(n)
  return strchar(band(n, 0xff),
                 band(rshift(n, 8), 0xff),
                 band(rshift(n, 16), 0xff))
end

local function _set_byte4(n)
  return strchar(band(n, 0xff),
                 band(rshift(n, 8), 0xff),
                 band(rshift(n, 16), 0xff),
                 band(rshift(n, 24), 0xff))
end

-- Read a NUL-terminated string starting at i; returns (str, pos_after_nul)
-- or (nil, nil) when no terminator is found.
local function _from_cstring(data, i)
  local last = strfind(data, '\0', i, true)
  if not last then
    return nil, nil
  end
  return sub(data, i, last - 1), last + 1
end

local function _to_cstring(data)
  return data .. '\0'
end

-- One-byte length prefix: only valid for strings of at most 250 bytes.
local function _to_binary_coded_string(data)
  return strchar(#data) .. data
end

-- Debugging helpers: space-separated hex dumps of a byte string.

local function _dump(data)
  local len = #data
  local bytes = new_tab(len, 0)
  for i = 1, len do
    bytes[i] = format('%x', strbyte(data, i))
  end
  return concat(bytes, ' ')
end

local function _dumphex(data)
  local len = #data
  local bytes = new_tab(len, 0)
  for i = 1, len do
    bytes[i] = tohex(strbyte(data, i), 2)
  end
  return concat(bytes, ' ')
end

-- Auth token: SHA1(password) XOR SHA1(scramble .. SHA1(SHA1(password))).
local function _compute_token(password, scramble)
  if password == '' then
    return ''
  end

  local stage1 = sha1(password)
  local stage2 = sha1(stage1)
  local stage3 = sha1(scramble .. stage2)
  local n = #stage1
  local bytes = new_tab(n, 0)
  for i = 1, n do
    bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i)))
  end

  return concat(bytes)
end

-- Frame `req` with a 3-byte length and the next sequence id, then send it.
local function _send_packet(self, req, size)
  local sock = self.sock

  self.packet_no = self.packet_no + 1

  local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req

  return sock:send(packet)
end

--static, auto-growing buffer allocation pattern (ctype must be vla).
local max, ceil, log = math.max, math.ceil, math.log

local function nextpow2(x)
  return max(0, 2^(ceil(log(x) / log(2))))
end

local function grow_buffer(ctype)
  local vla = ffi.typeof(ctype)
  local buf, len = nil, -1
  return function(minlen)
    if minlen == false then
      buf, len = nil, -1
    elseif minlen > len then
      len = nextpow2(minlen)
      buf = vla(len)
    end
    return buf, len
  end
end

-- Read exactly `sz` bytes from the socket into the connection's reusable
-- buffer and return them as a Lua string, or nil, err.
local function _recv(self, sz)
  local getbuf = self.buf
  if not getbuf then
    getbuf = grow_buffer'char[?]'
    self.buf = getbuf
  end
  local buf = getbuf(sz)
  local sock = self.sock
  local offset = 0
  while sz > 0 do
    local n, err = sock:recv(buf + offset, sz)
    if not n then
      return nil, err
    end
    -- a zero-length read means the peer closed the connection; without this
    -- check the loop would spin forever
    if n == 0 then
      return nil, 'closed'
    end
    sz = sz - n
    offset = offset + n
  end
  return ffi.string(buf, offset)
end

-- Receive one protocol packet. Returns (payload, typ) where typ is one of
-- 'OK', 'ERR', 'EOF', 'DATA' (classified by the first payload byte), or
-- nil, nil, err on failure.
local function _recv_packet(self)
  local data, err = _recv(self, 4) -- packet header
  if not data then
    return nil, nil, 'failed to receive packet header: ' .. err
  end

  local len, pos = _get_byte3(data, 1)

  if len == 0 then
    return nil, nil, 'empty packet'
  end

  if len > self._max_packet_size then
    return nil, nil, 'packet size too big: ' .. len
  end

  local num = strbyte(data, pos)
  self.packet_no = num

  data, err = _recv(self, len)
  if not data then
    return nil, nil, 'failed to read packet content: ' .. err
  end

  local field_count = strbyte(data, 1)

  local typ
  if field_count == 0x00 then
    typ = 'OK'
  elseif field_count == 0xff then
    typ = 'ERR'
  elseif field_count == 0xfe then
    typ = 'EOF'
  else
    typ = 'DATA'
  end

  return data, typ
end

-- Length-coded binary (length-encoded integer). Returns (value, next_pos);
-- value is `null` for the 251 marker and nil past the end of data.
local function _from_length_coded_bin(data, pos)
  local first = strbyte(data, pos)

  if not first then
    return nil, pos
  end

  if first >= 0 and first <= 250 then
    return first, pos + 1
  end

  if first == 251 then
    return null, pos + 1
  end

  if first == 252 then
    pos = pos + 1
    return _get_byte2(data, pos)
  end

  if first == 253 then
    pos = pos + 1
    return _get_byte3(data, pos)
  end

  if first == 254 then
    pos = pos + 1
    return _get_byte8(data, pos)
  end

  return nil, pos + 1
end

-- Length-coded string; returns (str | null, next_pos).
local function _from_length_coded_str(data, pos)
  local len
  len, pos = _from_length_coded_bin(data, pos)
  if not len or len == null then
    return null, pos
  end

  return sub(data, pos, pos + len - 1), pos + len
end

local function _parse_ok_packet(packet)
  local res = new_tab(0, 5)
  local pos

  res.affected_rows, pos = _from_length_coded_bin(packet, 2)
  res.insert_id, pos = _from_length_coded_bin(packet, pos)
  res.server_status, pos = _get_byte2(packet, pos)
  res.warning_count, pos = _get_byte2(packet, pos)

  local message = _from_length_coded_str(packet, pos)
  if message and message ~= null then
    res.message = message
  end

  return res
end

-- Returns (warning_count, status_flags).
local function _parse_eof_packet(packet)
  local warning_count, pos = _get_byte2(packet, 2)
  local status_flags = _get_byte2(packet, pos)

  return warning_count, status_flags
end

-- Returns (errno, message, sqlstate|nil).
local function _parse_err_packet(packet)
  local errno, pos = _get_byte2(packet, 2)
  local marker = sub(packet, pos, pos)
  local sqlstate
  if marker == '#' then
    -- with sqlstate
    pos = pos + 1
    sqlstate = sub(packet, pos, pos + 5 - 1)
    pos = pos + 5
  end

  local message = sub(packet, pos)
  return errno, message, sqlstate
end

local function _parse_result_set_header_packet(packet)
  local field_count, pos = _from_length_coded_bin(packet, 1)

  local extra
  extra = _from_length_coded_bin(packet, pos)

  return field_count, extra
end

local NOT_NULL_FLAG = 1
local PRI_KEY_FLAG = 2
local UNIQUE_KEY_FLAG = 4
local UNSIGNED_FLAG = 32
local AUTO_INCREMENT_FLAG = 512

-- Decode a column definition packet into a table describing the column.
local function _parse_field_packet(data)
  local col = new_tab(0, 2)
  local pos
  col.catalog, pos = _from_length_coded_str(data, 1)
  col.db, pos = _from_length_coded_str(data, pos)
  col.table, pos = _from_length_coded_str(data, pos)
  col.orig_table, pos = _from_length_coded_str(data, pos)
  col.name, pos = _from_length_coded_str(data, pos)
  col.orig_name, pos = _from_length_coded_str(data, pos)

  pos = pos + 1 -- ignore the filler

  col.charsetnr, pos = _get_byte2(data, pos)
  col.length, pos = _get_byte4(data, pos)
  col.type = strbyte(data, pos)
  pos = pos + 1

  col.flags, pos = _get_byte2(data, pos)
  col.allow_null = band(col.flags, NOT_NULL_FLAG) == 0
  col.pri_key = band(col.flags, PRI_KEY_FLAG) ~= 0
  col.unique_key = band(col.flags, UNIQUE_KEY_FLAG) ~= 0
  col.unsigned = band(col.flags, UNSIGNED_FLAG) ~= 0
  col.auto_increment = band(col.flags, AUTO_INCREMENT_FLAG) ~= 0

  col.decimals = strbyte(data, pos)
  pos = pos + 1

  local default = sub(data, pos + 2)
  if default and default ~= '' then
    col.default = default
  end

  return col
end

-- Decode one text-protocol row. compact=true yields an array (NULLs kept as
-- `null`); otherwise a table keyed by column name with NULL fields omitted.
local function _parse_row_data_packet(data, cols, compact)
  local pos = 1
  local ncols = #cols
  local row
  if compact then
    row = new_tab(ncols, 0)
  else
    row = new_tab(0, ncols)
  end
  for i = 1, ncols do
    local value
    value, pos = _from_length_coded_str(data, pos)
    local col = cols[i]
    local typ = col.type
    local name = col.name

    if value ~= null then
      local conv = converters[typ]
      if conv then
        value = conv(value)
      end
    end

    if compact then
      row[i] = value
    elseif value ~= null then
      row[name] = value
    end
  end

  return row
end

local function _recv_field_packet(self)
  local packet, typ, err = _recv_packet(self)
  if not packet then
    return nil, err
  end

  if typ == 'ERR' then
    local errno, msg, sqlstate = _parse_err_packet(packet)
    return nil, msg, errno, sqlstate
  end

  if typ ~= 'DATA' then
    return nil, 'bad field packet type: ' .. typ
  end

  -- typ == 'DATA'

  return _parse_field_packet(packet)
end

--- Create a connection object (call as `mysql:new([opt])`).
-- opt.tcp: optional socket constructor; defaults to require'sock'.tcp.
-- Returns the connection or nil, err.
function mysql.new(self, opt)
  local tcp = opt and opt.tcp or require'sock'.tcp
  local sock, err = tcp()
  if not sock then
    return nil, err
  end
  return setmetatable({ sock = sock }, mt)
end

--- Connect and authenticate.
-- opts: host, port (default 3306), user, password, database,
-- charset (a CHARSET_MAP name), max_packet_size (default 1 MB),
-- ssl, ssl_verify.
-- Returns 1 on success, or nil, err[, errno, sqlstate].
function conn:connect(opts)
  opts = opts or {}
  local sock = self.sock

  local max_packet_size = opts.max_packet_size
  if not max_packet_size then
    max_packet_size = 1024 * 1024 -- default 1 MB
  end
  self._max_packet_size = max_packet_size

  local database = opts.database or ''
  local user = opts.user or ''

  local charset = CHARSET_MAP[opts.charset or '_default']
  if not charset then
    return nil, 'charset \'' .. opts.charset .. '\' is not supported'
  end

  local host = opts.host
  local port = opts.port or 3306
  local ok, err = sock:connect(host, port)
  if not ok then
    return nil, 'failed to connect: ' .. err
  end

  -- server handshake (initialization) packet
  local packet, typ
  packet, typ, err = _recv_packet(self)
  if not packet then
    return nil, err
  end

  if typ == 'ERR' then
    local errno, msg, sqlstate = _parse_err_packet(packet)
    return nil, msg, errno, sqlstate
  end

  self.protocol_ver = strbyte(packet)

  local server_ver, pos = _from_cstring(packet, 2)
  if not server_ver then
    return nil, 'bad handshake initialization packet: bad server version'
  end
  self._server_ver = server_ver

  pos = pos + 4 -- skip the 4-byte thread id

  -- string.sub never returns nil, so check the length to detect truncation
  local scramble = sub(packet, pos, pos + 8 - 1)
  if #scramble < 8 then
    return nil, '1st part of scramble not found'
  end

  pos = pos + 9 -- skip filler

  -- two lower bytes of the server capabilities
  local capabilities
  capabilities, pos = _get_byte2(packet, pos)

  self._server_lang = strbyte(packet, pos)
  pos = pos + 1

  self._server_status, pos = _get_byte2(packet, pos)

  -- two upper bytes of the server capabilities
  local more_capabilities
  more_capabilities, pos = _get_byte2(packet, pos)
  capabilities = bor(capabilities, lshift(more_capabilities, 16))

  local len = 21 - 8 - 1
  pos = pos + 1 + 10 -- skip auth-data length byte and reserved bytes

  local scramble_part2 = sub(packet, pos, pos + len - 1)
  if #scramble_part2 < len then
    return nil, '2nd part of scramble not found'
  end

  scramble = scramble .. scramble_part2

  local client_flags = 0x3f7cf

  local ssl_verify = opts.ssl_verify
  local use_ssl = opts.ssl or ssl_verify

  if use_ssl then
    if band(capabilities, CLIENT_SSL) == 0 then
      return nil, 'ssl disabled on server'
    end

    -- send a SSL Request Packet
    local req = _set_byte4(bor(client_flags, CLIENT_SSL))
      .. _set_byte4(self._max_packet_size)
      .. strchar(charset)
      .. strrep('\0', 23)

    local packet_len = 4 + 4 + 1 + 23
    local bytes, err = _send_packet(self, req, packet_len)
    if not bytes then
      return nil, 'failed to send client authentication packet: ' .. err
    end

    local ok, err = sock:sslhandshake(false, nil, ssl_verify)
    if not ok then
      return nil, 'failed to do ssl handshake: ' .. (err or '')
    end
  end

  local password = opts.password or ''

  local token = _compute_token(password, scramble)

  local req = _set_byte4(client_flags)
    .. _set_byte4(self._max_packet_size)
    .. strchar(charset)
    .. strrep('\0', 23)
    .. _to_cstring(user)
    .. _to_binary_coded_string(token)
    .. _to_cstring(database)

  local packet_len = 4 + 4 + 1 + 23 + #user + 1 + #token + 1 + #database + 1

  local bytes
  bytes, err = _send_packet(self, req, packet_len)
  if not bytes then
    return nil, 'failed to send client authentication packet: ' .. err
  end

  packet, typ, err = _recv_packet(self)
  if not packet then
    return nil, 'failed to receive the result packet: ' .. err
  end

  if typ == 'ERR' then
    local errno, msg, sqlstate = _parse_err_packet(packet)
    return nil, msg, errno, sqlstate
  end

  if typ == 'EOF' then
    return nil, 'old pre-4.1 authentication protocol not supported'
  end

  if typ ~= 'OK' then
    return nil, 'bad packet type: ' .. typ
  end

  self.state = STATE_CONNECTED

  return 1
end

--- Send COM_QUIT and close the socket. Returns sock:close()'s result, or
-- nil, err when the quit packet could not be sent.
function conn:close()
  local sock = assert(self.sock)

  self.state = nil

  -- COM_QUIT starts a new command, whose sequence id restarts at 0; this
  -- also avoids `nil + 1` when close() is called before connect().
  self.packet_no = -1

  local bytes, err = _send_packet(self, strchar(COM_QUIT), 1)
  if not bytes then
    sock:close() -- don't leak the socket when the quit packet can't be sent
    return nil, err
  end

  return sock:close()
end

function conn:server_ver()
  return self._server_ver
end

--- Send a COM_QUERY packet. Returns bytes sent, or nil, err.
function conn:send_query(query)
  assert(self.state == STATE_CONNECTED)
  local sock = assert(self.sock)

  self.packet_no = -1

  local cmd_packet = strchar(COM_QUERY) .. query
  local packet_len = 1 + #query

  local bytes, err = _send_packet(self, cmd_packet, packet_len)
  if not bytes then
    return nil, err
  end

  self.state = STATE_COMMAND_SENT

  return bytes
end

--- Read one result of a previously sent query.
-- est_nrows: row-table size hint (or the string 'compact').
-- compact: 'compact' to get rows as arrays instead of name-keyed tables.
-- Returns: OK-packet table for non-SELECTs; `rows, nil, cols` for result
-- sets; a second value of 'again' means more results follow; on error
-- nil, err[, errno, sqlstate].
function conn:read_result(est_nrows, compact)
  assert(self.state == STATE_COMMAND_SENT)
  local sock = assert(self.sock)

  compact = compact == 'compact'
  if est_nrows == 'compact' then
    -- must be nil, not the `null` sentinel: it is passed to new_tab() below
    est_nrows = nil
    compact = true
  end

  local packet, typ, err = _recv_packet(self)
  if not packet then
    return nil, err
  end

  if typ == 'ERR' then
    self.state = STATE_CONNECTED

    local errno, msg, sqlstate = _parse_err_packet(packet)
    return nil, msg, errno, sqlstate
  end

  if typ == 'OK' then
    local res = _parse_ok_packet(packet)
    if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
      return res, 'again'
    end

    self.state = STATE_CONNECTED
    return res
  end

  if typ ~= 'DATA' then
    self.state = STATE_CONNECTED

    return nil, 'packet type ' .. typ .. ' not supported'
  end

  -- typ == 'DATA': result set header, then column definitions

  local field_count, extra = _parse_result_set_header_packet(packet)

  local cols = new_tab(field_count, 0)
  for i = 1, field_count do
    local col, err, errno, sqlstate = _recv_field_packet(self)
    if not col then
      return nil, err, errno, sqlstate
    end

    cols[i] = col
  end

  packet, typ, err = _recv_packet(self)
  if not packet then
    return nil, err
  end

  if typ ~= 'EOF' then
    return nil, 'unexpected packet type ' .. typ .. ' while eof packet is '
      .. 'expected'
  end

  -- typ == 'EOF': row data follows until the terminating EOF

  local rows = new_tab(est_nrows or 4, 0)
  local i = 0
  while true do
    packet, typ, err = _recv_packet(self)
    if not packet then
      return nil, err
    end

    if typ == 'EOF' then
      local _, status_flags = _parse_eof_packet(packet)

      if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
        return rows, 'again', cols
      end

      break
    end

    if typ == 'ERR' then
      -- the server aborted the result set; 0xff cannot start a valid
      -- length-coded row value, so this is a genuine error packet
      self.state = STATE_CONNECTED

      local errno, msg, sqlstate = _parse_err_packet(packet)
      return nil, msg, errno, sqlstate
    end

    i = i + 1
    rows[i] = _parse_row_data_packet(packet, cols, compact)
  end

  self.state = STATE_CONNECTED

  return rows, nil, cols
end

--- send_query() + read_result(). `compact` is optional and passed through.
function conn:query(query, est_nrows, compact)
  local bytes, err, errcode = self:send_query(query)
  if not bytes then
    return nil, err, errcode
  end

  return self:read_result(est_nrows, compact)
end

local qmap = {
  ['\0' ] = '\\0',
  ['\b' ] = '\\b',
  ['\n' ] = '\\n',
  ['\r' ] = '\\r',
  ['\t' ] = '\\t',
  ['\26'] = '\\Z',
  ['\\' ] = '\\\\',
  ['\'' ] = '\\\'',
  ['\"' ] = '\\"',
}

--- Backslash-escape a string for use inside a quoted SQL literal (does not
-- add the surrounding quotes). Returns only the escaped string: the extra
-- parentheses drop gsub's replacement count, which would otherwise leak
-- into calls like conn:query(mysql.quote(s)) as est_nrows.
function mysql.quote(s)
  return (s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap))
end

return mysql