mirror of
https://github.com/luapower/mysql.git
synced 2025-01-06 08:10:24 +01:00
unimportant
This commit is contained in:
parent
cb7ae6e888
commit
7e7d7501eb
272
mysql_client.lua
272
mysql_client.lua
@ -3,7 +3,7 @@
|
|||||||
-- Written by Yichun Zhang (agentzh). BSD license.
|
-- Written by Yichun Zhang (agentzh). BSD license.
|
||||||
|
|
||||||
local tcp = require'sock'.tcp
|
local tcp = require'sock'.tcp
|
||||||
local sha1 = require'sha1'
|
local sha1 = require'sha1'.sha1
|
||||||
local bit = require'bit'
|
local bit = require'bit'
|
||||||
|
|
||||||
local sub = string.sub
|
local sub = string.sub
|
||||||
@ -165,7 +165,7 @@ end
|
|||||||
|
|
||||||
|
|
||||||
local function _from_cstring(data, i)
|
local function _from_cstring(data, i)
|
||||||
local last = strfind(data, "\0", i, true)
|
local last = strfind(data, '\0', i, true)
|
||||||
if not last then
|
if not last then
|
||||||
return nil, nil
|
return nil, nil
|
||||||
end
|
end
|
||||||
@ -175,7 +175,7 @@ end
|
|||||||
|
|
||||||
|
|
||||||
local function _to_cstring(data)
|
local function _to_cstring(data)
|
||||||
return data .. "\0"
|
return data .. '\0'
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@ -188,9 +188,9 @@ local function _dump(data)
|
|||||||
local len = #data
|
local len = #data
|
||||||
local bytes = new_tab(len, 0)
|
local bytes = new_tab(len, 0)
|
||||||
for i = 1, len do
|
for i = 1, len do
|
||||||
bytes[i] = format("%x", strbyte(data, i))
|
bytes[i] = format('%x', strbyte(data, i))
|
||||||
end
|
end
|
||||||
return concat(bytes, " ")
|
return concat(bytes, ' ')
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@ -200,13 +200,13 @@ local function _dumphex(data)
|
|||||||
for i = 1, len do
|
for i = 1, len do
|
||||||
bytes[i] = tohex(strbyte(data, i), 2)
|
bytes[i] = tohex(strbyte(data, i), 2)
|
||||||
end
|
end
|
||||||
return concat(bytes, " ")
|
return concat(bytes, ' ')
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
local function _compute_token(password, scramble)
|
local function _compute_token(password, scramble)
|
||||||
if password == "" then
|
if password == '' then
|
||||||
return ""
|
return ''
|
||||||
end
|
end
|
||||||
|
|
||||||
local stage1 = sha1(password)
|
local stage1 = sha1(password)
|
||||||
@ -227,68 +227,101 @@ local function _send_packet(self, req, size)
|
|||||||
|
|
||||||
self.packet_no = self.packet_no + 1
|
self.packet_no = self.packet_no + 1
|
||||||
|
|
||||||
-- print("packet no: ", self.packet_no)
|
-- print('packet no: ', self.packet_no)
|
||||||
|
|
||||||
local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req
|
local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req
|
||||||
|
|
||||||
-- print("sending packet: ", _dump(packet))
|
-- print('sending packet: ', _dump(packet))
|
||||||
|
|
||||||
-- print("sending packet... of size " .. #packet)
|
-- print('sending packet... of size ' .. #packet)
|
||||||
|
|
||||||
return sock:send(packet)
|
return sock:send(packet)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--static, auto-growing buffer allocation pattern (ctype must be vla).
|
||||||
|
local function grow_buffer(ctype)
|
||||||
|
local vla = ffi.typeof(ctype)
|
||||||
|
local buf, len = nil, -1
|
||||||
|
return function(minlen)
|
||||||
|
if minlen == false then
|
||||||
|
buf, len = nil, -1
|
||||||
|
elseif minlen > len then
|
||||||
|
len = glue.nextpow2(minlen)
|
||||||
|
buf = vla(len)
|
||||||
|
end
|
||||||
|
return buf, len
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function _recv(self, sz)
|
||||||
|
local buf = self.buf
|
||||||
|
if not buf then
|
||||||
|
buf = grow_buffer'char[?]'
|
||||||
|
self.buf = buf
|
||||||
|
end
|
||||||
|
local buf = buf(sz)
|
||||||
|
local sock = self.sock
|
||||||
|
local offset = 0
|
||||||
|
while sz > 0 do
|
||||||
|
local n, err = sock:recv(buf + offset, sz)
|
||||||
|
if not n then return nil, err end
|
||||||
|
sz = sz - n
|
||||||
|
offset = offset + n
|
||||||
|
end
|
||||||
|
return ffi.string(buf, offset)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
local function _recv_packet(self)
|
local function _recv_packet(self)
|
||||||
local sock = self.sock
|
local sock = self.sock
|
||||||
|
|
||||||
local data, err = sock:receive(4) -- packet header
|
local data, err = _recv(self, 4) -- packet header
|
||||||
if not data then
|
if not data then
|
||||||
return nil, nil, "failed to receive packet header: " .. err
|
return nil, nil, 'failed to receive packet header: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
--print("packet header: ", _dump(data))
|
--print('packet header: ', _dump(data))
|
||||||
|
|
||||||
local len, pos = _get_byte3(data, 1)
|
local len, pos = _get_byte3(data, 1)
|
||||||
|
|
||||||
--print("packet length: ", len)
|
--print('packet length: ', len)
|
||||||
|
|
||||||
if len == 0 then
|
if len == 0 then
|
||||||
return nil, nil, "empty packet"
|
return nil, nil, 'empty packet'
|
||||||
end
|
end
|
||||||
|
|
||||||
if len > self._max_packet_size then
|
if len > self._max_packet_size then
|
||||||
return nil, nil, "packet size too big: " .. len
|
return nil, nil, 'packet size too big: ' .. len
|
||||||
end
|
end
|
||||||
|
|
||||||
local num = strbyte(data, pos)
|
local num = strbyte(data, pos)
|
||||||
|
|
||||||
--print("recv packet: packet no: ", num)
|
--print('recv packet: packet no: ', num)
|
||||||
|
|
||||||
self.packet_no = num
|
self.packet_no = num
|
||||||
|
|
||||||
data, err = sock:receive(len)
|
data, err = _recv(self, len)
|
||||||
|
|
||||||
--print("receive returned")
|
--print('receive returned')
|
||||||
|
|
||||||
if not data then
|
if not data then
|
||||||
return nil, nil, "failed to read packet content: " .. err
|
return nil, nil, 'failed to read packet content: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
--print("packet content: ", _dump(data))
|
--print('packet content: ', _dump(data))
|
||||||
--print("packet content (ascii): ", data)
|
--print('packet content (ascii): ', data)
|
||||||
|
|
||||||
local field_count = strbyte(data, 1)
|
local field_count = strbyte(data, 1)
|
||||||
|
|
||||||
local typ
|
local typ
|
||||||
if field_count == 0x00 then
|
if field_count == 0x00 then
|
||||||
typ = "OK"
|
typ = 'OK'
|
||||||
elseif field_count == 0xff then
|
elseif field_count == 0xff then
|
||||||
typ = "ERR"
|
typ = 'ERR'
|
||||||
elseif field_count == 0xfe then
|
elseif field_count == 0xfe then
|
||||||
typ = "EOF"
|
typ = 'EOF'
|
||||||
else
|
else
|
||||||
typ = "DATA"
|
typ = 'DATA'
|
||||||
end
|
end
|
||||||
|
|
||||||
return data, typ
|
return data, typ
|
||||||
@ -298,7 +331,7 @@ end
|
|||||||
local function _from_length_coded_bin(data, pos)
|
local function _from_length_coded_bin(data, pos)
|
||||||
local first = strbyte(data, pos)
|
local first = strbyte(data, pos)
|
||||||
|
|
||||||
--print("LCB: first: ", first)
|
--print('LCB: first: ', first)
|
||||||
|
|
||||||
if not first then
|
if not first then
|
||||||
return nil, pos
|
return nil, pos
|
||||||
@ -348,26 +381,26 @@ local function _parse_ok_packet(packet)
|
|||||||
|
|
||||||
res.affected_rows, pos = _from_length_coded_bin(packet, 2)
|
res.affected_rows, pos = _from_length_coded_bin(packet, 2)
|
||||||
|
|
||||||
--print("affected rows: ", res.affected_rows, ", pos:", pos)
|
--print('affected rows: ', res.affected_rows, ', pos:', pos)
|
||||||
|
|
||||||
res.insert_id, pos = _from_length_coded_bin(packet, pos)
|
res.insert_id, pos = _from_length_coded_bin(packet, pos)
|
||||||
|
|
||||||
--print("insert id: ", res.insert_id, ", pos:", pos)
|
--print('insert id: ', res.insert_id, ', pos:', pos)
|
||||||
|
|
||||||
res.server_status, pos = _get_byte2(packet, pos)
|
res.server_status, pos = _get_byte2(packet, pos)
|
||||||
|
|
||||||
--print("server status: ", res.server_status, ", pos:", pos)
|
--print('server status: ', res.server_status, ', pos:', pos)
|
||||||
|
|
||||||
res.warning_count, pos = _get_byte2(packet, pos)
|
res.warning_count, pos = _get_byte2(packet, pos)
|
||||||
|
|
||||||
--print("warning count: ", res.warning_count, ", pos: ", pos)
|
--print('warning count: ', res.warning_count, ', pos: ', pos)
|
||||||
|
|
||||||
local message = _from_length_coded_str(packet, pos)
|
local message = _from_length_coded_str(packet, pos)
|
||||||
if message and message ~= null then
|
if message and message ~= null then
|
||||||
res.message = message
|
res.message = message
|
||||||
end
|
end
|
||||||
|
|
||||||
--print("message: ", res.message, ", pos:", pos)
|
--print('message: ', res.message, ', pos:', pos)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
end
|
end
|
||||||
@ -437,7 +470,7 @@ local function _parse_field_packet(data)
|
|||||||
col.decimals = strbyte(data, pos)
|
col.decimals = strbyte(data, pos)
|
||||||
pos = pos + 1
|
pos = pos + 1
|
||||||
local default = sub(data, pos + 2)
|
local default = sub(data, pos + 2)
|
||||||
if default and default ~= "" then
|
if default and default ~= '' then
|
||||||
col.default = default
|
col.default = default
|
||||||
end
|
end
|
||||||
return col
|
return col
|
||||||
@ -460,7 +493,7 @@ local function _parse_row_data_packet(data, cols, compact)
|
|||||||
local typ = col.type
|
local typ = col.type
|
||||||
local name = col.name
|
local name = col.name
|
||||||
|
|
||||||
--print("row field value: ", value, ", type: ", typ)
|
--print('row field value: ', value, ', type: ', typ)
|
||||||
|
|
||||||
if value ~= null then
|
if value ~= null then
|
||||||
local conv = converters[typ]
|
local conv = converters[typ]
|
||||||
@ -486,13 +519,13 @@ local function _recv_field_packet(self)
|
|||||||
return nil, err
|
return nil, err
|
||||||
end
|
end
|
||||||
|
|
||||||
if typ == "ERR" then
|
if typ == 'ERR' then
|
||||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||||
return nil, msg, errno, sqlstate
|
return nil, msg, errno, sqlstate
|
||||||
end
|
end
|
||||||
|
|
||||||
if typ ~= 'DATA' then
|
if typ ~= 'DATA' then
|
||||||
return nil, "bad field packet type: " .. typ
|
return nil, 'bad field packet type: ' .. typ
|
||||||
end
|
end
|
||||||
|
|
||||||
-- typ == 'DATA'
|
-- typ == 'DATA'
|
||||||
@ -513,7 +546,7 @@ end
|
|||||||
function _M.connect(self, opts)
|
function _M.connect(self, opts)
|
||||||
local sock = self.sock
|
local sock = self.sock
|
||||||
if not sock then
|
if not sock then
|
||||||
return nil, "not initialized"
|
return nil, 'not initialized'
|
||||||
end
|
end
|
||||||
|
|
||||||
local max_packet_size = opts.max_packet_size
|
local max_packet_size = opts.max_packet_size
|
||||||
@ -526,79 +559,52 @@ function _M.connect(self, opts)
|
|||||||
|
|
||||||
self.compact = opts.compact_arrays
|
self.compact = opts.compact_arrays
|
||||||
|
|
||||||
local database = opts.database or ""
|
local database = opts.database or ''
|
||||||
local user = opts.user or ""
|
local user = opts.user or ''
|
||||||
|
|
||||||
local charset = CHARSET_MAP[opts.charset or "_default"]
|
local charset = CHARSET_MAP[opts.charset or '_default']
|
||||||
if not charset then
|
if not charset then
|
||||||
return nil, "charset '" .. opts.charset .. "' is not supported"
|
return nil, 'charset \'' .. opts.charset .. '\' is not supported'
|
||||||
end
|
end
|
||||||
|
|
||||||
local pool = opts.pool
|
|
||||||
|
|
||||||
local host = opts.host
|
local host = opts.host
|
||||||
if host then
|
local port = opts.port or 3306
|
||||||
local port = opts.port or 3306
|
ok, err = sock:connect(host, port)
|
||||||
if not pool then
|
|
||||||
pool = user .. ":" .. database .. ":" .. host .. ":" .. port
|
|
||||||
end
|
|
||||||
|
|
||||||
ok, err = sock:connect(host, port, { pool = pool })
|
|
||||||
|
|
||||||
else
|
|
||||||
local path = opts.path
|
|
||||||
if not path then
|
|
||||||
return nil, 'neither "host" nor "path" options are specified'
|
|
||||||
end
|
|
||||||
|
|
||||||
if not pool then
|
|
||||||
pool = user .. ":" .. database .. ":" .. path
|
|
||||||
end
|
|
||||||
|
|
||||||
ok, err = sock:connect("unix:" .. path, { pool = pool })
|
|
||||||
end
|
|
||||||
|
|
||||||
if not ok then
|
if not ok then
|
||||||
return nil, 'failed to connect: ' .. err
|
return nil, 'failed to connect: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
local reused = sock:getreusedtimes()
|
|
||||||
|
|
||||||
if reused and reused > 0 then
|
|
||||||
self.state = STATE_CONNECTED
|
|
||||||
return 1
|
|
||||||
end
|
|
||||||
|
|
||||||
local packet, typ, err = _recv_packet(self)
|
local packet, typ, err = _recv_packet(self)
|
||||||
if not packet then
|
if not packet then
|
||||||
return nil, err
|
return nil, err
|
||||||
end
|
end
|
||||||
|
|
||||||
if typ == "ERR" then
|
if typ == 'ERR' then
|
||||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||||
return nil, msg, errno, sqlstate
|
return nil, msg, errno, sqlstate
|
||||||
end
|
end
|
||||||
|
|
||||||
self.protocol_ver = strbyte(packet)
|
self.protocol_ver = strbyte(packet)
|
||||||
|
|
||||||
--print("protocol version: ", self.protocol_ver)
|
--print('protocol version: ', self.protocol_ver)
|
||||||
|
|
||||||
local server_ver, pos = _from_cstring(packet, 2)
|
local server_ver, pos = _from_cstring(packet, 2)
|
||||||
if not server_ver then
|
if not server_ver then
|
||||||
return nil, "bad handshake initialization packet: bad server version"
|
return nil, 'bad handshake initialization packet: bad server version'
|
||||||
end
|
end
|
||||||
|
|
||||||
--print("server version: ", server_ver)
|
--print('server version: ', server_ver)
|
||||||
|
|
||||||
self._server_ver = server_ver
|
self._server_ver = server_ver
|
||||||
|
|
||||||
local thread_id, pos = _get_byte4(packet, pos)
|
local thread_id, pos = _get_byte4(packet, pos)
|
||||||
|
|
||||||
--print("thread id: ", thread_id)
|
--print('thread id: ', thread_id)
|
||||||
|
|
||||||
local scramble = sub(packet, pos, pos + 8 - 1)
|
local scramble = sub(packet, pos, pos + 8 - 1)
|
||||||
if not scramble then
|
if not scramble then
|
||||||
return nil, "1st part of scramble not found"
|
return nil, '1st part of scramble not found'
|
||||||
end
|
end
|
||||||
|
|
||||||
pos = pos + 9 -- skip filler
|
pos = pos + 9 -- skip filler
|
||||||
@ -607,38 +613,38 @@ function _M.connect(self, opts)
|
|||||||
local capabilities -- server capabilities
|
local capabilities -- server capabilities
|
||||||
capabilities, pos = _get_byte2(packet, pos)
|
capabilities, pos = _get_byte2(packet, pos)
|
||||||
|
|
||||||
-- print(format("server capabilities: %#x", capabilities))
|
-- print(format('server capabilities: %#x', capabilities))
|
||||||
|
|
||||||
self._server_lang = strbyte(packet, pos)
|
self._server_lang = strbyte(packet, pos)
|
||||||
pos = pos + 1
|
pos = pos + 1
|
||||||
|
|
||||||
--print("server lang: ", self._server_lang)
|
--print('server lang: ', self._server_lang)
|
||||||
|
|
||||||
self._server_status, pos = _get_byte2(packet, pos)
|
self._server_status, pos = _get_byte2(packet, pos)
|
||||||
|
|
||||||
--print("server status: ", self._server_status)
|
--print('server status: ', self._server_status)
|
||||||
|
|
||||||
local more_capabilities
|
local more_capabilities
|
||||||
more_capabilities, pos = _get_byte2(packet, pos)
|
more_capabilities, pos = _get_byte2(packet, pos)
|
||||||
|
|
||||||
capabilities = bor(capabilities, lshift(more_capabilities, 16))
|
capabilities = bor(capabilities, lshift(more_capabilities, 16))
|
||||||
|
|
||||||
--print("server capabilities: ", capabilities)
|
--print('server capabilities: ', capabilities)
|
||||||
|
|
||||||
-- local len = strbyte(packet, pos)
|
-- local len = strbyte(packet, pos)
|
||||||
local len = 21 - 8 - 1
|
local len = 21 - 8 - 1
|
||||||
|
|
||||||
--print("scramble len: ", len)
|
--print('scramble len: ', len)
|
||||||
|
|
||||||
pos = pos + 1 + 10
|
pos = pos + 1 + 10
|
||||||
|
|
||||||
local scramble_part2 = sub(packet, pos, pos + len - 1)
|
local scramble_part2 = sub(packet, pos, pos + len - 1)
|
||||||
if not scramble_part2 then
|
if not scramble_part2 then
|
||||||
return nil, "2nd part of scramble not found"
|
return nil, '2nd part of scramble not found'
|
||||||
end
|
end
|
||||||
|
|
||||||
scramble = scramble .. scramble_part2
|
scramble = scramble .. scramble_part2
|
||||||
--print("scramble: ", _dump(scramble))
|
--print('scramble: ', _dump(scramble))
|
||||||
|
|
||||||
local client_flags = 0x3f7cf;
|
local client_flags = 0x3f7cf;
|
||||||
|
|
||||||
@ -647,37 +653,37 @@ function _M.connect(self, opts)
|
|||||||
|
|
||||||
if use_ssl then
|
if use_ssl then
|
||||||
if band(capabilities, CLIENT_SSL) == 0 then
|
if band(capabilities, CLIENT_SSL) == 0 then
|
||||||
return nil, "ssl disabled on server"
|
return nil, 'ssl disabled on server'
|
||||||
end
|
end
|
||||||
|
|
||||||
-- send a SSL Request Packet
|
-- send a SSL Request Packet
|
||||||
local req = _set_byte4(bor(client_flags, CLIENT_SSL))
|
local req = _set_byte4(bor(client_flags, CLIENT_SSL))
|
||||||
.. _set_byte4(self._max_packet_size)
|
.. _set_byte4(self._max_packet_size)
|
||||||
.. strchar(charset)
|
.. strchar(charset)
|
||||||
.. strrep("\0", 23)
|
.. strrep('\0', 23)
|
||||||
|
|
||||||
local packet_len = 4 + 4 + 1 + 23
|
local packet_len = 4 + 4 + 1 + 23
|
||||||
local bytes, err = _send_packet(self, req, packet_len)
|
local bytes, err = _send_packet(self, req, packet_len)
|
||||||
if not bytes then
|
if not bytes then
|
||||||
return nil, "failed to send client authentication packet: " .. err
|
return nil, 'failed to send client authentication packet: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
local ok, err = sock:sslhandshake(false, nil, ssl_verify)
|
local ok, err = sock:sslhandshake(false, nil, ssl_verify)
|
||||||
if not ok then
|
if not ok then
|
||||||
return nil, "failed to do ssl handshake: " .. (err or "")
|
return nil, 'failed to do ssl handshake: ' .. (err or '')
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local password = opts.password or ""
|
local password = opts.password or ''
|
||||||
|
|
||||||
local token = _compute_token(password, scramble)
|
local token = _compute_token(password, scramble)
|
||||||
|
|
||||||
--print("token: ", _dump(token))
|
--print('token: ', _dump(token))
|
||||||
|
|
||||||
local req = _set_byte4(client_flags)
|
local req = _set_byte4(client_flags)
|
||||||
.. _set_byte4(self._max_packet_size)
|
.. _set_byte4(self._max_packet_size)
|
||||||
.. strchar(charset)
|
.. strchar(charset)
|
||||||
.. strrep("\0", 23)
|
.. strrep('\0', 23)
|
||||||
.. _to_cstring(user)
|
.. _to_cstring(user)
|
||||||
.. _to_binary_coded_string(token)
|
.. _to_binary_coded_string(token)
|
||||||
.. _to_cstring(database)
|
.. _to_cstring(database)
|
||||||
@ -685,19 +691,19 @@ function _M.connect(self, opts)
|
|||||||
local packet_len = 4 + 4 + 1 + 23 + #user + 1
|
local packet_len = 4 + 4 + 1 + 23 + #user + 1
|
||||||
+ #token + 1 + #database + 1
|
+ #token + 1 + #database + 1
|
||||||
|
|
||||||
-- print("packet content length: ", packet_len)
|
-- print('packet content length: ', packet_len)
|
||||||
-- print("packet content: ", _dump(concat(req, "")))
|
-- print('packet content: ', _dump(concat(req, '')))
|
||||||
|
|
||||||
local bytes, err = _send_packet(self, req, packet_len)
|
local bytes, err = _send_packet(self, req, packet_len)
|
||||||
if not bytes then
|
if not bytes then
|
||||||
return nil, "failed to send client authentication packet: " .. err
|
return nil, 'failed to send client authentication packet: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
--print("packet sent ", bytes, " bytes")
|
--print('packet sent ', bytes, ' bytes')
|
||||||
|
|
||||||
local packet, typ, err = _recv_packet(self)
|
local packet, typ, err = _recv_packet(self)
|
||||||
if not packet then
|
if not packet then
|
||||||
return nil, "failed to receive the result packet: " .. err
|
return nil, 'failed to receive the result packet: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
if typ == 'ERR' then
|
if typ == 'ERR' then
|
||||||
@ -706,11 +712,11 @@ function _M.connect(self, opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
if typ == 'EOF' then
|
if typ == 'EOF' then
|
||||||
return nil, "old pre-4.1 authentication protocol not supported"
|
return nil, 'old pre-4.1 authentication protocol not supported'
|
||||||
end
|
end
|
||||||
|
|
||||||
if typ ~= 'OK' then
|
if typ ~= 'OK' then
|
||||||
return nil, "bad packet type: " .. typ
|
return nil, 'bad packet type: ' .. typ
|
||||||
end
|
end
|
||||||
|
|
||||||
self.state = STATE_CONNECTED
|
self.state = STATE_CONNECTED
|
||||||
@ -718,37 +724,10 @@ function _M.connect(self, opts)
|
|||||||
return 1
|
return 1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function _M.set_keepalive(self, ...)
|
|
||||||
local sock = self.sock
|
|
||||||
if not sock then
|
|
||||||
return nil, "not initialized"
|
|
||||||
end
|
|
||||||
|
|
||||||
if self.state ~= STATE_CONNECTED then
|
|
||||||
return nil, "cannot be reused in the current connection state: "
|
|
||||||
.. (self.state or "nil")
|
|
||||||
end
|
|
||||||
|
|
||||||
self.state = nil
|
|
||||||
return sock:setkeepalive(...)
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
function _M.get_reused_times(self)
|
|
||||||
local sock = self.sock
|
|
||||||
if not sock then
|
|
||||||
return nil, "not initialized"
|
|
||||||
end
|
|
||||||
|
|
||||||
return sock:getreusedtimes()
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
function _M.close(self)
|
function _M.close(self)
|
||||||
local sock = self.sock
|
local sock = self.sock
|
||||||
if not sock then
|
if not sock then
|
||||||
return nil, "not initialized"
|
return nil, 'not initialized'
|
||||||
end
|
end
|
||||||
|
|
||||||
self.state = nil
|
self.state = nil
|
||||||
@ -761,21 +740,19 @@ function _M.close(self)
|
|||||||
return sock:close()
|
return sock:close()
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function _M.server_ver(self)
|
function _M.server_ver(self)
|
||||||
return self._server_ver
|
return self._server_ver
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
local function send_query(self, query)
|
local function send_query(self, query)
|
||||||
if self.state ~= STATE_CONNECTED then
|
if self.state ~= STATE_CONNECTED then
|
||||||
return nil, "cannot send query in the current context: "
|
return nil, 'cannot send query in the current context: '
|
||||||
.. (self.state or "nil")
|
.. (self.state or 'nil')
|
||||||
end
|
end
|
||||||
|
|
||||||
local sock = self.sock
|
local sock = self.sock
|
||||||
if not sock then
|
if not sock then
|
||||||
return nil, "not initialized"
|
return nil, 'not initialized'
|
||||||
end
|
end
|
||||||
|
|
||||||
self.packet_no = -1
|
self.packet_no = -1
|
||||||
@ -790,22 +767,21 @@ local function send_query(self, query)
|
|||||||
|
|
||||||
self.state = STATE_COMMAND_SENT
|
self.state = STATE_COMMAND_SENT
|
||||||
|
|
||||||
--print("packet sent ", bytes, " bytes")
|
--print('packet sent ', bytes, ' bytes')
|
||||||
|
|
||||||
return bytes
|
return bytes
|
||||||
end
|
end
|
||||||
_M.send_query = send_query
|
_M.send_query = send_query
|
||||||
|
|
||||||
|
|
||||||
local function read_result(self, est_nrows)
|
local function read_result(self, est_nrows)
|
||||||
if self.state ~= STATE_COMMAND_SENT then
|
if self.state ~= STATE_COMMAND_SENT then
|
||||||
return nil, "cannot read result in the current context: "
|
return nil, 'cannot read result in the current context: '
|
||||||
.. (self.state or "nil")
|
.. (self.state or 'nil')
|
||||||
end
|
end
|
||||||
|
|
||||||
local sock = self.sock
|
local sock = self.sock
|
||||||
if not sock then
|
if not sock then
|
||||||
return nil, "not initialized"
|
return nil, 'not initialized'
|
||||||
end
|
end
|
||||||
|
|
||||||
local packet, typ, err = _recv_packet(self)
|
local packet, typ, err = _recv_packet(self)
|
||||||
@ -813,7 +789,7 @@ local function read_result(self, est_nrows)
|
|||||||
return nil, err
|
return nil, err
|
||||||
end
|
end
|
||||||
|
|
||||||
if typ == "ERR" then
|
if typ == 'ERR' then
|
||||||
self.state = STATE_CONNECTED
|
self.state = STATE_CONNECTED
|
||||||
|
|
||||||
local errno, msg, sqlstate = _parse_err_packet(packet)
|
local errno, msg, sqlstate = _parse_err_packet(packet)
|
||||||
@ -823,7 +799,7 @@ local function read_result(self, est_nrows)
|
|||||||
if typ == 'OK' then
|
if typ == 'OK' then
|
||||||
local res = _parse_ok_packet(packet)
|
local res = _parse_ok_packet(packet)
|
||||||
if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
|
if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
|
||||||
return res, "again"
|
return res, 'again'
|
||||||
end
|
end
|
||||||
|
|
||||||
self.state = STATE_CONNECTED
|
self.state = STATE_CONNECTED
|
||||||
@ -833,16 +809,16 @@ local function read_result(self, est_nrows)
|
|||||||
if typ ~= 'DATA' then
|
if typ ~= 'DATA' then
|
||||||
self.state = STATE_CONNECTED
|
self.state = STATE_CONNECTED
|
||||||
|
|
||||||
return nil, "packet type " .. typ .. " not supported"
|
return nil, 'packet type ' .. typ .. ' not supported'
|
||||||
end
|
end
|
||||||
|
|
||||||
-- typ == 'DATA'
|
-- typ == 'DATA'
|
||||||
|
|
||||||
--print("read the result set header packet")
|
--print('read the result set header packet')
|
||||||
|
|
||||||
local field_count, extra = _parse_result_set_header_packet(packet)
|
local field_count, extra = _parse_result_set_header_packet(packet)
|
||||||
|
|
||||||
--print("field count: ", field_count)
|
--print('field count: ', field_count)
|
||||||
|
|
||||||
local cols = new_tab(field_count, 0)
|
local cols = new_tab(field_count, 0)
|
||||||
for i = 1, field_count do
|
for i = 1, field_count do
|
||||||
@ -860,8 +836,8 @@ local function read_result(self, est_nrows)
|
|||||||
end
|
end
|
||||||
|
|
||||||
if typ ~= 'EOF' then
|
if typ ~= 'EOF' then
|
||||||
return nil, "unexpected packet type " .. typ .. " while eof packet is "
|
return nil, 'unexpected packet type ' .. typ .. ' while eof packet is '
|
||||||
.. "expected"
|
.. 'expected'
|
||||||
end
|
end
|
||||||
|
|
||||||
-- typ == 'EOF'
|
-- typ == 'EOF'
|
||||||
@ -871,7 +847,7 @@ local function read_result(self, est_nrows)
|
|||||||
local rows = new_tab(est_nrows or 4, 0)
|
local rows = new_tab(est_nrows or 4, 0)
|
||||||
local i = 0
|
local i = 0
|
||||||
while true do
|
while true do
|
||||||
--print("reading a row")
|
--print('reading a row')
|
||||||
|
|
||||||
packet, typ, err = _recv_packet(self)
|
packet, typ, err = _recv_packet(self)
|
||||||
if not packet then
|
if not packet then
|
||||||
@ -881,10 +857,10 @@ local function read_result(self, est_nrows)
|
|||||||
if typ == 'EOF' then
|
if typ == 'EOF' then
|
||||||
local warning_count, status_flags = _parse_eof_packet(packet)
|
local warning_count, status_flags = _parse_eof_packet(packet)
|
||||||
|
|
||||||
--print("status flags: ", status_flags)
|
--print('status flags: ', status_flags)
|
||||||
|
|
||||||
if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
|
if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
|
||||||
return rows, "again", cols
|
return rows, 'again', cols
|
||||||
end
|
end
|
||||||
|
|
||||||
break
|
break
|
||||||
@ -907,17 +883,15 @@ local function read_result(self, est_nrows)
|
|||||||
end
|
end
|
||||||
_M.read_result = read_result
|
_M.read_result = read_result
|
||||||
|
|
||||||
|
|
||||||
function _M.query(self, query, est_nrows)
|
function _M.query(self, query, est_nrows)
|
||||||
local bytes, err = send_query(self, query)
|
local bytes, err = send_query(self, query)
|
||||||
if not bytes then
|
if not bytes then
|
||||||
return nil, "failed to send query: " .. err
|
return nil, 'failed to send query: ' .. err
|
||||||
end
|
end
|
||||||
|
|
||||||
return read_result(self, est_nrows)
|
return read_result(self, est_nrows)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function _M.set_compact_arrays(self, value)
|
function _M.set_compact_arrays(self, value)
|
||||||
self.compact = value
|
self.compact = value
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user