diff --git a/mysql_client.lua b/mysql_client.lua index 652e379..19e540a 100644 --- a/mysql_client.lua +++ b/mysql_client.lua @@ -3,7 +3,7 @@ -- Written by Yichun Zhang (agentzh). BSD license. local tcp = require'sock'.tcp -local sha1 = require'sha1' +local sha1 = require'sha1'.sha1 local bit = require'bit' local sub = string.sub @@ -165,7 +165,7 @@ end local function _from_cstring(data, i) - local last = strfind(data, "\0", i, true) + local last = strfind(data, '\0', i, true) if not last then return nil, nil end @@ -175,7 +175,7 @@ end local function _to_cstring(data) - return data .. "\0" + return data .. '\0' end @@ -188,9 +188,9 @@ local function _dump(data) local len = #data local bytes = new_tab(len, 0) for i = 1, len do - bytes[i] = format("%x", strbyte(data, i)) + bytes[i] = format('%x', strbyte(data, i)) end - return concat(bytes, " ") + return concat(bytes, ' ') end @@ -200,13 +200,13 @@ local function _dumphex(data) for i = 1, len do bytes[i] = tohex(strbyte(data, i), 2) end - return concat(bytes, " ") + return concat(bytes, ' ') end local function _compute_token(password, scramble) - if password == "" then - return "" + if password == '' then + return '' end local stage1 = sha1(password) @@ -227,68 +227,101 @@ local function _send_packet(self, req, size) self.packet_no = self.packet_no + 1 - -- print("packet no: ", self.packet_no) + -- print('packet no: ', self.packet_no) local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req - -- print("sending packet: ", _dump(packet)) + -- print('sending packet: ', _dump(packet)) - -- print("sending packet... of size " .. #packet) + -- print('sending packet... of size ' .. #packet) return sock:send(packet) end +--static, auto-growing buffer allocation pattern (ctype must be vla). +local function grow_buffer(ctype) + local vla = ffi.typeof(ctype) + local buf, len = nil, -1 + return function(minlen) + if minlen == false then + buf, len = nil, -1 + elseif minlen > len then + len = glue.nextpow2(minlen) + buf = vla(len) + end + return buf, len + end +end + +local function _recv(self, sz) + local buf = self.buf + if not buf then + buf = grow_buffer'char[?]' + self.buf = buf + end + local buf = buf(sz) + local sock = self.sock + local offset = 0 + while sz > 0 do + local n, err = sock:recv(buf + offset, sz) + if not n then return nil, err end + sz = sz - n + offset = offset + n + end + return ffi.string(buf, offset) +end + local function _recv_packet(self) local sock = self.sock - local data, err = sock:receive(4) -- packet header + local data, err = _recv(self, 4) -- packet header if not data then - return nil, nil, "failed to receive packet header: " .. err + return nil, nil, 'failed to receive packet header: ' .. err end - --print("packet header: ", _dump(data)) + --print('packet header: ', _dump(data)) local len, pos = _get_byte3(data, 1) - --print("packet length: ", len) + --print('packet length: ', len) if len == 0 then - return nil, nil, "empty packet" + return nil, nil, 'empty packet' end if len > self._max_packet_size then - return nil, nil, "packet size too big: " .. len + return nil, nil, 'packet size too big: ' .. len end local num = strbyte(data, pos) - --print("recv packet: packet no: ", num) + --print('recv packet: packet no: ', num) self.packet_no = num - data, err = sock:receive(len) + data, err = _recv(self, len) - --print("receive returned") + --print('receive returned') if not data then - return nil, nil, "failed to read packet content: " .. err + return nil, nil, 'failed to read packet content: ' .. err end - --print("packet content: ", _dump(data)) - --print("packet content (ascii): ", data) + --print('packet content: ', _dump(data)) + --print('packet content (ascii): ', data) local field_count = strbyte(data, 1) local typ if field_count == 0x00 then - typ = "OK" + typ = 'OK' elseif field_count == 0xff then - typ = "ERR" + typ = 'ERR' elseif field_count == 0xfe then - typ = "EOF" + typ = 'EOF' else - typ = "DATA" + typ = 'DATA' end return data, typ @@ -298,7 +331,7 @@ end local function _from_length_coded_bin(data, pos) local first = strbyte(data, pos) - --print("LCB: first: ", first) + --print('LCB: first: ', first) if not first then return nil, pos @@ -348,26 +381,26 @@ local function _parse_ok_packet(packet) res.affected_rows, pos = _from_length_coded_bin(packet, 2) - --print("affected rows: ", res.affected_rows, ", pos:", pos) + --print('affected rows: ', res.affected_rows, ', pos:', pos) res.insert_id, pos = _from_length_coded_bin(packet, pos) - --print("insert id: ", res.insert_id, ", pos:", pos) + --print('insert id: ', res.insert_id, ', pos:', pos) res.server_status, pos = _get_byte2(packet, pos) - --print("server status: ", res.server_status, ", pos:", pos) + --print('server status: ', res.server_status, ', pos:', pos) res.warning_count, pos = _get_byte2(packet, pos) - --print("warning count: ", res.warning_count, ", pos: ", pos) + --print('warning count: ', res.warning_count, ', pos: ', pos) local message = _from_length_coded_str(packet, pos) if message and message ~= null then res.message = message end - --print("message: ", res.message, ", pos:", pos) + --print('message: ', res.message, ', pos:', pos) return res end @@ -437,7 +470,7 @@ local function _parse_field_packet(data) col.decimals = strbyte(data, pos) pos = pos + 1 local default = sub(data, pos + 2) - if default and default ~= "" then + if default and default ~= '' then col.default = default end return col @@ -460,7 +493,7 @@ local function _parse_row_data_packet(data, cols, compact) local typ = col.type local name = col.name - --print("row field value: ", value, ", type: ", typ) + --print('row field value: ', value, ', type: ', typ) if value ~= null then local conv = converters[typ] @@ -486,13 +519,13 @@ local function _recv_field_packet(self) return nil, err end - if typ == "ERR" then + if typ == 'ERR' then local errno, msg, sqlstate = _parse_err_packet(packet) return nil, msg, errno, sqlstate end if typ ~= 'DATA' then - return nil, "bad field packet type: " .. typ + return nil, 'bad field packet type: ' .. typ end -- typ == 'DATA' @@ -513,7 +546,7 @@ end function _M.connect(self, opts) local sock = self.sock if not sock then - return nil, "not initialized" + return nil, 'not initialized' end local max_packet_size = opts.max_packet_size @@ -526,79 +559,52 @@ function _M.connect(self, opts) self.compact = opts.compact_arrays - local database = opts.database or "" - local user = opts.user or "" + local database = opts.database or '' + local user = opts.user or '' - local charset = CHARSET_MAP[opts.charset or "_default"] + local charset = CHARSET_MAP[opts.charset or '_default'] if not charset then - return nil, "charset '" .. opts.charset .. "' is not supported" + return nil, 'charset \'' .. opts.charset .. '\' is not supported' end - local pool = opts.pool - local host = opts.host - if host then - local port = opts.port or 3306 - if not pool then - pool = user .. ":" .. database .. ":" .. host .. ":" .. port - end - - ok, err = sock:connect(host, port, { pool = pool }) - - else - local path = opts.path - if not path then - return nil, 'neither "host" nor "path" options are specified' - end - - if not pool then - pool = user .. ":" .. database .. ":" .. path - end - - ok, err = sock:connect("unix:" .. path, { pool = pool }) - end + local port = opts.port or 3306 + ok, err = sock:connect(host, port) if not ok then return nil, 'failed to connect: ' .. err end - local reused = sock:getreusedtimes() - - if reused and reused > 0 then - self.state = STATE_CONNECTED - return 1 - end - local packet, typ, err = _recv_packet(self) if not packet then return nil, err end - if typ == "ERR" then + if typ == 'ERR' then local errno, msg, sqlstate = _parse_err_packet(packet) return nil, msg, errno, sqlstate end self.protocol_ver = strbyte(packet) - --print("protocol version: ", self.protocol_ver) + --print('protocol version: ', self.protocol_ver) local server_ver, pos = _from_cstring(packet, 2) if not server_ver then - return nil, "bad handshake initialization packet: bad server version" + return nil, 'bad handshake initialization packet: bad server version' end - --print("server version: ", server_ver) + --print('server version: ', server_ver) self._server_ver = server_ver local thread_id, pos = _get_byte4(packet, pos) - --print("thread id: ", thread_id) + --print('thread id: ', thread_id) local scramble = sub(packet, pos, pos + 8 - 1) if not scramble then - return nil, "1st part of scramble not found" + return nil, '1st part of scramble not found' end pos = pos + 9 -- skip filler @@ -607,38 +613,38 @@ function _M.connect(self, opts) local capabilities -- server capabilities capabilities, pos = _get_byte2(packet, pos) - -- print(format("server capabilities: %#x", capabilities)) + -- print(format('server capabilities: %#x', capabilities)) self._server_lang = strbyte(packet, pos) pos = pos + 1 - --print("server lang: ", self._server_lang) + --print('server lang: ', self._server_lang) self._server_status, pos = _get_byte2(packet, pos) - --print("server status: ", self._server_status) + --print('server status: ', self._server_status) local more_capabilities more_capabilities, pos = _get_byte2(packet, pos) capabilities = bor(capabilities, lshift(more_capabilities, 16)) - --print("server capabilities: ", capabilities) + --print('server capabilities: ', capabilities) -- local len = strbyte(packet, pos) local len = 21 - 8 - 1 - --print("scramble len: ", len) + --print('scramble len: ', len) pos = pos + 1 + 10 local scramble_part2 = sub(packet, pos, pos + len - 1) if not scramble_part2 then - return nil, "2nd part of scramble not found" + return nil, '2nd part of scramble not found' end scramble = scramble .. scramble_part2 - --print("scramble: ", _dump(scramble)) + --print('scramble: ', _dump(scramble)) local client_flags = 0x3f7cf; @@ -647,37 +653,37 @@ function _M.connect(self, opts) if use_ssl then if band(capabilities, CLIENT_SSL) == 0 then - return nil, "ssl disabled on server" + return nil, 'ssl disabled on server' end -- send a SSL Request Packet local req = _set_byte4(bor(client_flags, CLIENT_SSL)) .. _set_byte4(self._max_packet_size) .. strchar(charset) - .. strrep("\0", 23) + .. strrep('\0', 23) local packet_len = 4 + 4 + 1 + 23 local bytes, err = _send_packet(self, req, packet_len) if not bytes then - return nil, "failed to send client authentication packet: " .. err + return nil, 'failed to send client authentication packet: ' .. err end local ok, err = sock:sslhandshake(false, nil, ssl_verify) if not ok then - return nil, "failed to do ssl handshake: " .. (err or "") + return nil, 'failed to do ssl handshake: ' .. (err or '') end end - local password = opts.password or "" + local password = opts.password or '' local token = _compute_token(password, scramble) - --print("token: ", _dump(token)) + --print('token: ', _dump(token)) local req = _set_byte4(client_flags) .. _set_byte4(self._max_packet_size) .. strchar(charset) - .. strrep("\0", 23) + .. strrep('\0', 23) .. _to_cstring(user) .. _to_binary_coded_string(token) .. _to_cstring(database) @@ -685,19 +691,19 @@ function _M.connect(self, opts) local packet_len = 4 + 4 + 1 + 23 + #user + 1 + #token + 1 + #database + 1 - -- print("packet content length: ", packet_len) - -- print("packet content: ", _dump(concat(req, ""))) + -- print('packet content length: ', packet_len) + -- print('packet content: ', _dump(concat(req, ''))) local bytes, err = _send_packet(self, req, packet_len) if not bytes then - return nil, "failed to send client authentication packet: " .. err + return nil, 'failed to send client authentication packet: ' .. err end - --print("packet sent ", bytes, " bytes") + --print('packet sent ', bytes, ' bytes') local packet, typ, err = _recv_packet(self) if not packet then - return nil, "failed to receive the result packet: " .. err + return nil, 'failed to receive the result packet: ' .. err end if typ == 'ERR' then @@ -706,11 +712,11 @@ function _M.connect(self, opts) end if typ == 'EOF' then - return nil, "old pre-4.1 authentication protocol not supported" + return nil, 'old pre-4.1 authentication protocol not supported' end if typ ~= 'OK' then - return nil, "bad packet type: " .. typ + return nil, 'bad packet type: ' .. typ end self.state = STATE_CONNECTED @@ -718,37 +724,10 @@ function _M.connect(self, opts) return 1 end - -function _M.set_keepalive(self, ...) - local sock = self.sock - if not sock then - return nil, "not initialized" - end - - if self.state ~= STATE_CONNECTED then - return nil, "cannot be reused in the current connection state: " - .. (self.state or "nil") - end - - self.state = nil - return sock:setkeepalive(...) -end - - -function _M.get_reused_times(self) - local sock = self.sock - if not sock then - return nil, "not initialized" - end - - return sock:getreusedtimes() -end - - function _M.close(self) local sock = self.sock if not sock then - return nil, "not initialized" + return nil, 'not initialized' end self.state = nil @@ -761,21 +740,19 @@ function _M.close(self) return sock:close() end - function _M.server_ver(self) return self._server_ver end - local function send_query(self, query) if self.state ~= STATE_CONNECTED then - return nil, "cannot send query in the current context: " - .. (self.state or "nil") + return nil, 'cannot send query in the current context: ' + .. (self.state or 'nil') end local sock = self.sock if not sock then - return nil, "not initialized" + return nil, 'not initialized' end self.packet_no = -1 @@ -790,22 +767,21 @@ local function send_query(self, query) self.state = STATE_COMMAND_SENT - --print("packet sent ", bytes, " bytes") + --print('packet sent ', bytes, ' bytes') return bytes end _M.send_query = send_query - local function read_result(self, est_nrows) if self.state ~= STATE_COMMAND_SENT then - return nil, "cannot read result in the current context: " - .. (self.state or "nil") + return nil, 'cannot read result in the current context: ' + .. (self.state or 'nil') end local sock = self.sock if not sock then - return nil, "not initialized" + return nil, 'not initialized' end local packet, typ, err = _recv_packet(self) @@ -813,7 +789,7 @@ local function read_result(self, est_nrows) return nil, err end - if typ == "ERR" then + if typ == 'ERR' then self.state = STATE_CONNECTED local errno, msg, sqlstate = _parse_err_packet(packet) @@ -823,7 +799,7 @@ local function read_result(self, est_nrows) if typ == 'OK' then local res = _parse_ok_packet(packet) if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then - return res, "again" + return res, 'again' end self.state = STATE_CONNECTED @@ -833,16 +809,16 @@ local function read_result(self, est_nrows) if typ ~= 'DATA' then self.state = STATE_CONNECTED - return nil, "packet type " .. typ .. " not supported" + return nil, 'packet type ' .. typ .. ' not supported' end -- typ == 'DATA' - --print("read the result set header packet") + --print('read the result set header packet') local field_count, extra = _parse_result_set_header_packet(packet) - --print("field count: ", field_count) + --print('field count: ', field_count) local cols = new_tab(field_count, 0) for i = 1, field_count do @@ -860,8 +836,8 @@ local function read_result(self, est_nrows) end if typ ~= 'EOF' then - return nil, "unexpected packet type " .. typ .. " while eof packet is " - .. "expected" + return nil, 'unexpected packet type ' .. typ .. ' while eof packet is ' + .. 'expected' end -- typ == 'EOF' @@ -871,7 +847,7 @@ local function read_result(self, est_nrows) local rows = new_tab(est_nrows or 4, 0) local i = 0 while true do - --print("reading a row") + --print('reading a row') packet, typ, err = _recv_packet(self) if not packet then @@ -881,10 +857,10 @@ local function read_result(self, est_nrows) if typ == 'EOF' then local warning_count, status_flags = _parse_eof_packet(packet) - --print("status flags: ", status_flags) + --print('status flags: ', status_flags) if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then - return rows, "again", cols + return rows, 'again', cols end break @@ -907,17 +883,15 @@ local function read_result(self, est_nrows) end _M.read_result = read_result - function _M.query(self, query, est_nrows) local bytes, err = send_query(self, query) if not bytes then - return nil, "failed to send query: " .. err + return nil, 'failed to send query: ' .. err end return read_result(self, est_nrows) end - function _M.set_compact_arrays(self, value) self.compact = value end