unimportant

This commit is contained in:
Cosmin Apreutesei 2021-04-30 01:27:35 +03:00
parent cb7ae6e888
commit 7e7d7501eb
1 changed files with 123 additions and 149 deletions

View File

@ -3,7 +3,7 @@
-- Written by Yichun Zhang (agentzh). BSD license. -- Written by Yichun Zhang (agentzh). BSD license.
local tcp = require'sock'.tcp local tcp = require'sock'.tcp
local sha1 = require'sha1' local sha1 = require'sha1'.sha1
local bit = require'bit' local bit = require'bit'
local sub = string.sub local sub = string.sub
@ -165,7 +165,7 @@ end
local function _from_cstring(data, i) local function _from_cstring(data, i)
local last = strfind(data, "\0", i, true) local last = strfind(data, '\0', i, true)
if not last then if not last then
return nil, nil return nil, nil
end end
@ -175,7 +175,7 @@ end
local function _to_cstring(data) local function _to_cstring(data)
return data .. "\0" return data .. '\0'
end end
@ -188,9 +188,9 @@ local function _dump(data)
local len = #data local len = #data
local bytes = new_tab(len, 0) local bytes = new_tab(len, 0)
for i = 1, len do for i = 1, len do
bytes[i] = format("%x", strbyte(data, i)) bytes[i] = format('%x', strbyte(data, i))
end end
return concat(bytes, " ") return concat(bytes, ' ')
end end
@ -200,13 +200,13 @@ local function _dumphex(data)
for i = 1, len do for i = 1, len do
bytes[i] = tohex(strbyte(data, i), 2) bytes[i] = tohex(strbyte(data, i), 2)
end end
return concat(bytes, " ") return concat(bytes, ' ')
end end
local function _compute_token(password, scramble) local function _compute_token(password, scramble)
if password == "" then if password == '' then
return "" return ''
end end
local stage1 = sha1(password) local stage1 = sha1(password)
@ -227,68 +227,101 @@ local function _send_packet(self, req, size)
self.packet_no = self.packet_no + 1 self.packet_no = self.packet_no + 1
-- print("packet no: ", self.packet_no) -- print('packet no: ', self.packet_no)
local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req
-- print("sending packet: ", _dump(packet)) -- print('sending packet: ', _dump(packet))
-- print("sending packet... of size " .. #packet) -- print('sending packet... of size ' .. #packet)
return sock:send(packet) return sock:send(packet)
end end
--static, auto-growing buffer allocation pattern (ctype must be vla).
local function grow_buffer(ctype)
local vla = ffi.typeof(ctype)
local buf, len = nil, -1
return function(minlen)
if minlen == false then
buf, len = nil, -1
elseif minlen > len then
len = glue.nextpow2(minlen)
buf = vla(len)
end
return buf, len
end
end
local function _recv(self, sz)
local buf = self.buf
if not buf then
buf = grow_buffer'char[?]'
self.buf = buf
end
local buf = buf(sz)
local sock = self.sock
local offset = 0
while sz > 0 do
local n, err = sock:recv(buf + offset, sz)
if not n then return nil, err end
sz = sz - n
offset = offset + n
end
return ffi.string(buf, offset)
end
local function _recv_packet(self) local function _recv_packet(self)
local sock = self.sock local sock = self.sock
local data, err = sock:receive(4) -- packet header local data, err = _recv(self, 4) -- packet header
if not data then if not data then
return nil, nil, "failed to receive packet header: " .. err return nil, nil, 'failed to receive packet header: ' .. err
end end
--print("packet header: ", _dump(data)) --print('packet header: ', _dump(data))
local len, pos = _get_byte3(data, 1) local len, pos = _get_byte3(data, 1)
--print("packet length: ", len) --print('packet length: ', len)
if len == 0 then if len == 0 then
return nil, nil, "empty packet" return nil, nil, 'empty packet'
end end
if len > self._max_packet_size then if len > self._max_packet_size then
return nil, nil, "packet size too big: " .. len return nil, nil, 'packet size too big: ' .. len
end end
local num = strbyte(data, pos) local num = strbyte(data, pos)
--print("recv packet: packet no: ", num) --print('recv packet: packet no: ', num)
self.packet_no = num self.packet_no = num
data, err = sock:receive(len) data, err = _recv(self, len)
--print("receive returned") --print('receive returned')
if not data then if not data then
return nil, nil, "failed to read packet content: " .. err return nil, nil, 'failed to read packet content: ' .. err
end end
--print("packet content: ", _dump(data)) --print('packet content: ', _dump(data))
--print("packet content (ascii): ", data) --print('packet content (ascii): ', data)
local field_count = strbyte(data, 1) local field_count = strbyte(data, 1)
local typ local typ
if field_count == 0x00 then if field_count == 0x00 then
typ = "OK" typ = 'OK'
elseif field_count == 0xff then elseif field_count == 0xff then
typ = "ERR" typ = 'ERR'
elseif field_count == 0xfe then elseif field_count == 0xfe then
typ = "EOF" typ = 'EOF'
else else
typ = "DATA" typ = 'DATA'
end end
return data, typ return data, typ
@ -298,7 +331,7 @@ end
local function _from_length_coded_bin(data, pos) local function _from_length_coded_bin(data, pos)
local first = strbyte(data, pos) local first = strbyte(data, pos)
--print("LCB: first: ", first) --print('LCB: first: ', first)
if not first then if not first then
return nil, pos return nil, pos
@ -348,26 +381,26 @@ local function _parse_ok_packet(packet)
res.affected_rows, pos = _from_length_coded_bin(packet, 2) res.affected_rows, pos = _from_length_coded_bin(packet, 2)
--print("affected rows: ", res.affected_rows, ", pos:", pos) --print('affected rows: ', res.affected_rows, ', pos:', pos)
res.insert_id, pos = _from_length_coded_bin(packet, pos) res.insert_id, pos = _from_length_coded_bin(packet, pos)
--print("insert id: ", res.insert_id, ", pos:", pos) --print('insert id: ', res.insert_id, ', pos:', pos)
res.server_status, pos = _get_byte2(packet, pos) res.server_status, pos = _get_byte2(packet, pos)
--print("server status: ", res.server_status, ", pos:", pos) --print('server status: ', res.server_status, ', pos:', pos)
res.warning_count, pos = _get_byte2(packet, pos) res.warning_count, pos = _get_byte2(packet, pos)
--print("warning count: ", res.warning_count, ", pos: ", pos) --print('warning count: ', res.warning_count, ', pos: ', pos)
local message = _from_length_coded_str(packet, pos) local message = _from_length_coded_str(packet, pos)
if message and message ~= null then if message and message ~= null then
res.message = message res.message = message
end end
--print("message: ", res.message, ", pos:", pos) --print('message: ', res.message, ', pos:', pos)
return res return res
end end
@ -437,7 +470,7 @@ local function _parse_field_packet(data)
col.decimals = strbyte(data, pos) col.decimals = strbyte(data, pos)
pos = pos + 1 pos = pos + 1
local default = sub(data, pos + 2) local default = sub(data, pos + 2)
if default and default ~= "" then if default and default ~= '' then
col.default = default col.default = default
end end
return col return col
@ -460,7 +493,7 @@ local function _parse_row_data_packet(data, cols, compact)
local typ = col.type local typ = col.type
local name = col.name local name = col.name
--print("row field value: ", value, ", type: ", typ) --print('row field value: ', value, ', type: ', typ)
if value ~= null then if value ~= null then
local conv = converters[typ] local conv = converters[typ]
@ -486,13 +519,13 @@ local function _recv_field_packet(self)
return nil, err return nil, err
end end
if typ == "ERR" then if typ == 'ERR' then
local errno, msg, sqlstate = _parse_err_packet(packet) local errno, msg, sqlstate = _parse_err_packet(packet)
return nil, msg, errno, sqlstate return nil, msg, errno, sqlstate
end end
if typ ~= 'DATA' then if typ ~= 'DATA' then
return nil, "bad field packet type: " .. typ return nil, 'bad field packet type: ' .. typ
end end
-- typ == 'DATA' -- typ == 'DATA'
@ -513,7 +546,7 @@ end
function _M.connect(self, opts) function _M.connect(self, opts)
local sock = self.sock local sock = self.sock
if not sock then if not sock then
return nil, "not initialized" return nil, 'not initialized'
end end
local max_packet_size = opts.max_packet_size local max_packet_size = opts.max_packet_size
@ -526,79 +559,52 @@ function _M.connect(self, opts)
self.compact = opts.compact_arrays self.compact = opts.compact_arrays
local database = opts.database or "" local database = opts.database or ''
local user = opts.user or "" local user = opts.user or ''
local charset = CHARSET_MAP[opts.charset or "_default"] local charset = CHARSET_MAP[opts.charset or '_default']
if not charset then if not charset then
return nil, "charset '" .. opts.charset .. "' is not supported" return nil, 'charset \'' .. opts.charset .. '\' is not supported'
end end
local pool = opts.pool
local host = opts.host local host = opts.host
if host then local port = opts.port or 3306
local port = opts.port or 3306 ok, err = sock:connect(host, port)
if not pool then
pool = user .. ":" .. database .. ":" .. host .. ":" .. port
end
ok, err = sock:connect(host, port, { pool = pool })
else
local path = opts.path
if not path then
return nil, 'neither "host" nor "path" options are specified'
end
if not pool then
pool = user .. ":" .. database .. ":" .. path
end
ok, err = sock:connect("unix:" .. path, { pool = pool })
end
if not ok then if not ok then
return nil, 'failed to connect: ' .. err return nil, 'failed to connect: ' .. err
end end
local reused = sock:getreusedtimes()
if reused and reused > 0 then
self.state = STATE_CONNECTED
return 1
end
local packet, typ, err = _recv_packet(self) local packet, typ, err = _recv_packet(self)
if not packet then if not packet then
return nil, err return nil, err
end end
if typ == "ERR" then if typ == 'ERR' then
local errno, msg, sqlstate = _parse_err_packet(packet) local errno, msg, sqlstate = _parse_err_packet(packet)
return nil, msg, errno, sqlstate return nil, msg, errno, sqlstate
end end
self.protocol_ver = strbyte(packet) self.protocol_ver = strbyte(packet)
--print("protocol version: ", self.protocol_ver) --print('protocol version: ', self.protocol_ver)
local server_ver, pos = _from_cstring(packet, 2) local server_ver, pos = _from_cstring(packet, 2)
if not server_ver then if not server_ver then
return nil, "bad handshake initialization packet: bad server version" return nil, 'bad handshake initialization packet: bad server version'
end end
--print("server version: ", server_ver) --print('server version: ', server_ver)
self._server_ver = server_ver self._server_ver = server_ver
local thread_id, pos = _get_byte4(packet, pos) local thread_id, pos = _get_byte4(packet, pos)
--print("thread id: ", thread_id) --print('thread id: ', thread_id)
local scramble = sub(packet, pos, pos + 8 - 1) local scramble = sub(packet, pos, pos + 8 - 1)
if not scramble then if not scramble then
return nil, "1st part of scramble not found" return nil, '1st part of scramble not found'
end end
pos = pos + 9 -- skip filler pos = pos + 9 -- skip filler
@ -607,38 +613,38 @@ function _M.connect(self, opts)
local capabilities -- server capabilities local capabilities -- server capabilities
capabilities, pos = _get_byte2(packet, pos) capabilities, pos = _get_byte2(packet, pos)
-- print(format("server capabilities: %#x", capabilities)) -- print(format('server capabilities: %#x', capabilities))
self._server_lang = strbyte(packet, pos) self._server_lang = strbyte(packet, pos)
pos = pos + 1 pos = pos + 1
--print("server lang: ", self._server_lang) --print('server lang: ', self._server_lang)
self._server_status, pos = _get_byte2(packet, pos) self._server_status, pos = _get_byte2(packet, pos)
--print("server status: ", self._server_status) --print('server status: ', self._server_status)
local more_capabilities local more_capabilities
more_capabilities, pos = _get_byte2(packet, pos) more_capabilities, pos = _get_byte2(packet, pos)
capabilities = bor(capabilities, lshift(more_capabilities, 16)) capabilities = bor(capabilities, lshift(more_capabilities, 16))
--print("server capabilities: ", capabilities) --print('server capabilities: ', capabilities)
-- local len = strbyte(packet, pos) -- local len = strbyte(packet, pos)
local len = 21 - 8 - 1 local len = 21 - 8 - 1
--print("scramble len: ", len) --print('scramble len: ', len)
pos = pos + 1 + 10 pos = pos + 1 + 10
local scramble_part2 = sub(packet, pos, pos + len - 1) local scramble_part2 = sub(packet, pos, pos + len - 1)
if not scramble_part2 then if not scramble_part2 then
return nil, "2nd part of scramble not found" return nil, '2nd part of scramble not found'
end end
scramble = scramble .. scramble_part2 scramble = scramble .. scramble_part2
--print("scramble: ", _dump(scramble)) --print('scramble: ', _dump(scramble))
local client_flags = 0x3f7cf; local client_flags = 0x3f7cf;
@ -647,37 +653,37 @@ function _M.connect(self, opts)
if use_ssl then if use_ssl then
if band(capabilities, CLIENT_SSL) == 0 then if band(capabilities, CLIENT_SSL) == 0 then
return nil, "ssl disabled on server" return nil, 'ssl disabled on server'
end end
-- send a SSL Request Packet -- send a SSL Request Packet
local req = _set_byte4(bor(client_flags, CLIENT_SSL)) local req = _set_byte4(bor(client_flags, CLIENT_SSL))
.. _set_byte4(self._max_packet_size) .. _set_byte4(self._max_packet_size)
.. strchar(charset) .. strchar(charset)
.. strrep("\0", 23) .. strrep('\0', 23)
local packet_len = 4 + 4 + 1 + 23 local packet_len = 4 + 4 + 1 + 23
local bytes, err = _send_packet(self, req, packet_len) local bytes, err = _send_packet(self, req, packet_len)
if not bytes then if not bytes then
return nil, "failed to send client authentication packet: " .. err return nil, 'failed to send client authentication packet: ' .. err
end end
local ok, err = sock:sslhandshake(false, nil, ssl_verify) local ok, err = sock:sslhandshake(false, nil, ssl_verify)
if not ok then if not ok then
return nil, "failed to do ssl handshake: " .. (err or "") return nil, 'failed to do ssl handshake: ' .. (err or '')
end end
end end
local password = opts.password or "" local password = opts.password or ''
local token = _compute_token(password, scramble) local token = _compute_token(password, scramble)
--print("token: ", _dump(token)) --print('token: ', _dump(token))
local req = _set_byte4(client_flags) local req = _set_byte4(client_flags)
.. _set_byte4(self._max_packet_size) .. _set_byte4(self._max_packet_size)
.. strchar(charset) .. strchar(charset)
.. strrep("\0", 23) .. strrep('\0', 23)
.. _to_cstring(user) .. _to_cstring(user)
.. _to_binary_coded_string(token) .. _to_binary_coded_string(token)
.. _to_cstring(database) .. _to_cstring(database)
@ -685,19 +691,19 @@ function _M.connect(self, opts)
local packet_len = 4 + 4 + 1 + 23 + #user + 1 local packet_len = 4 + 4 + 1 + 23 + #user + 1
+ #token + 1 + #database + 1 + #token + 1 + #database + 1
-- print("packet content length: ", packet_len) -- print('packet content length: ', packet_len)
-- print("packet content: ", _dump(concat(req, ""))) -- print('packet content: ', _dump(concat(req, '')))
local bytes, err = _send_packet(self, req, packet_len) local bytes, err = _send_packet(self, req, packet_len)
if not bytes then if not bytes then
return nil, "failed to send client authentication packet: " .. err return nil, 'failed to send client authentication packet: ' .. err
end end
--print("packet sent ", bytes, " bytes") --print('packet sent ', bytes, ' bytes')
local packet, typ, err = _recv_packet(self) local packet, typ, err = _recv_packet(self)
if not packet then if not packet then
return nil, "failed to receive the result packet: " .. err return nil, 'failed to receive the result packet: ' .. err
end end
if typ == 'ERR' then if typ == 'ERR' then
@ -706,11 +712,11 @@ function _M.connect(self, opts)
end end
if typ == 'EOF' then if typ == 'EOF' then
return nil, "old pre-4.1 authentication protocol not supported" return nil, 'old pre-4.1 authentication protocol not supported'
end end
if typ ~= 'OK' then if typ ~= 'OK' then
return nil, "bad packet type: " .. typ return nil, 'bad packet type: ' .. typ
end end
self.state = STATE_CONNECTED self.state = STATE_CONNECTED
@ -718,37 +724,10 @@ function _M.connect(self, opts)
return 1 return 1
end end
function _M.set_keepalive(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
if self.state ~= STATE_CONNECTED then
return nil, "cannot be reused in the current connection state: "
.. (self.state or "nil")
end
self.state = nil
return sock:setkeepalive(...)
end
function _M.get_reused_times(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:getreusedtimes()
end
function _M.close(self) function _M.close(self)
local sock = self.sock local sock = self.sock
if not sock then if not sock then
return nil, "not initialized" return nil, 'not initialized'
end end
self.state = nil self.state = nil
@ -761,21 +740,19 @@ function _M.close(self)
return sock:close() return sock:close()
end end
function _M.server_ver(self) function _M.server_ver(self)
return self._server_ver return self._server_ver
end end
local function send_query(self, query) local function send_query(self, query)
if self.state ~= STATE_CONNECTED then if self.state ~= STATE_CONNECTED then
return nil, "cannot send query in the current context: " return nil, 'cannot send query in the current context: '
.. (self.state or "nil") .. (self.state or 'nil')
end end
local sock = self.sock local sock = self.sock
if not sock then if not sock then
return nil, "not initialized" return nil, 'not initialized'
end end
self.packet_no = -1 self.packet_no = -1
@ -790,22 +767,21 @@ local function send_query(self, query)
self.state = STATE_COMMAND_SENT self.state = STATE_COMMAND_SENT
--print("packet sent ", bytes, " bytes") --print('packet sent ', bytes, ' bytes')
return bytes return bytes
end end
_M.send_query = send_query _M.send_query = send_query
local function read_result(self, est_nrows) local function read_result(self, est_nrows)
if self.state ~= STATE_COMMAND_SENT then if self.state ~= STATE_COMMAND_SENT then
return nil, "cannot read result in the current context: " return nil, 'cannot read result in the current context: '
.. (self.state or "nil") .. (self.state or 'nil')
end end
local sock = self.sock local sock = self.sock
if not sock then if not sock then
return nil, "not initialized" return nil, 'not initialized'
end end
local packet, typ, err = _recv_packet(self) local packet, typ, err = _recv_packet(self)
@ -813,7 +789,7 @@ local function read_result(self, est_nrows)
return nil, err return nil, err
end end
if typ == "ERR" then if typ == 'ERR' then
self.state = STATE_CONNECTED self.state = STATE_CONNECTED
local errno, msg, sqlstate = _parse_err_packet(packet) local errno, msg, sqlstate = _parse_err_packet(packet)
@ -823,7 +799,7 @@ local function read_result(self, est_nrows)
if typ == 'OK' then if typ == 'OK' then
local res = _parse_ok_packet(packet) local res = _parse_ok_packet(packet)
if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
return res, "again" return res, 'again'
end end
self.state = STATE_CONNECTED self.state = STATE_CONNECTED
@ -833,16 +809,16 @@ local function read_result(self, est_nrows)
if typ ~= 'DATA' then if typ ~= 'DATA' then
self.state = STATE_CONNECTED self.state = STATE_CONNECTED
return nil, "packet type " .. typ .. " not supported" return nil, 'packet type ' .. typ .. ' not supported'
end end
-- typ == 'DATA' -- typ == 'DATA'
--print("read the result set header packet") --print('read the result set header packet')
local field_count, extra = _parse_result_set_header_packet(packet) local field_count, extra = _parse_result_set_header_packet(packet)
--print("field count: ", field_count) --print('field count: ', field_count)
local cols = new_tab(field_count, 0) local cols = new_tab(field_count, 0)
for i = 1, field_count do for i = 1, field_count do
@ -860,8 +836,8 @@ local function read_result(self, est_nrows)
end end
if typ ~= 'EOF' then if typ ~= 'EOF' then
return nil, "unexpected packet type " .. typ .. " while eof packet is " return nil, 'unexpected packet type ' .. typ .. ' while eof packet is '
.. "expected" .. 'expected'
end end
-- typ == 'EOF' -- typ == 'EOF'
@ -871,7 +847,7 @@ local function read_result(self, est_nrows)
local rows = new_tab(est_nrows or 4, 0) local rows = new_tab(est_nrows or 4, 0)
local i = 0 local i = 0
while true do while true do
--print("reading a row") --print('reading a row')
packet, typ, err = _recv_packet(self) packet, typ, err = _recv_packet(self)
if not packet then if not packet then
@ -881,10 +857,10 @@ local function read_result(self, est_nrows)
if typ == 'EOF' then if typ == 'EOF' then
local warning_count, status_flags = _parse_eof_packet(packet) local warning_count, status_flags = _parse_eof_packet(packet)
--print("status flags: ", status_flags) --print('status flags: ', status_flags)
if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
return rows, "again", cols return rows, 'again', cols
end end
break break
@ -907,17 +883,15 @@ local function read_result(self, est_nrows)
end end
_M.read_result = read_result _M.read_result = read_result
function _M.query(self, query, est_nrows) function _M.query(self, query, est_nrows)
local bytes, err = send_query(self, query) local bytes, err = send_query(self, query)
if not bytes then if not bytes then
return nil, "failed to send query: " .. err return nil, 'failed to send query: ' .. err
end end
return read_result(self, est_nrows) return read_result(self, est_nrows)
end end
function _M.set_compact_arrays(self, value) function _M.set_compact_arrays(self, value)
self.compact = value self.compact = value
end end