From 27382e197d8cd29a5b6bd801bb1d0f7cf720cedd Mon Sep 17 00:00:00 2001 From: Cosmin Apreutesei Date: Sat, 21 Aug 2021 22:08:23 +0300 Subject: [PATCH] switched to a binary protocol; added prepared statements --- mysql_client.lua | 1310 ++++++++++++++++++++++++---------------------- mysql_client.md | 49 +- 2 files changed, 704 insertions(+), 655 deletions(-) diff --git a/mysql_client.lua b/mysql_client.lua index 999bf0b..77f2f1f 100644 --- a/mysql_client.lua +++ b/mysql_client.lua @@ -1,12 +1,13 @@ --- MySQL client protocol in Lua. --- Written by Yichun Zhang (agentzh). BSD license. --- Modified by Cosmin Apreutesei. Pulbic domain. +--MySQL client protocol in Lua. +--Written by Cosmin Apreutesei. Public domain. +--Original code by Yichun Zhang (agentzh). BSD license. local ffi = require'ffi' local bit = require'bit' local sha1 = require'sha1'.sha1 local glue = require'glue' +local errors = require'errors' local sub = string.sub local strbyte = string.byte @@ -16,22 +17,27 @@ local strrep = string.rep local band = bit.band local bxor = bit.bxor local bor = bit.bor -local lshift = bit.lshift -local rshift = bit.rshift +local shl = bit.lshift +local shr = bit.rshift local tohex = bit.tohex local concat = table.concat local buffer = glue.buffer +local dynarray = glue.dynarray local index = glue.index local repl = glue.repl +local update = glue.update -local ok, new_tab = pcall(require, 'table.new') -new_tab = ok and new_tab or function() return {} end +local check_io, check, protect = errors.tcp_protocol_errors'mysql' local mysql = {} -local COM_QUIT = 0x01 -local COM_QUERY = 0x03 +local COM_QUIT = 0x01 +local COM_QUERY = 0x03 +local COM_STMT_PREPARE = 0x16 +local COM_STMT_EXECUTE = 0x17 +local COM_STMT_CLOSE = 0x19 + local CLIENT_SSL = 0x0800 local SERVER_MORE_RESULTS_EXISTS = 8 @@ -386,7 +392,7 @@ local buffer_type_names = { [255] = 'geometry', } -local type_names = { +local num_types = { tiny = 'tinyint', short = 'shortint', long = 'int', @@ -395,7 +401,7 @@ local type_names = { newdecimal = 'decimal', } -local bin_type_names = { +local bin_types = { tiny_blob = 'tinyblob', medium_blob = 'mediumblob', long_blob = 'longblob', @@ -404,7 +410,7 @@ local bin_type_names = { string = 'binary', } -local text_type_names = { +local text_types = { tiny_blob = 'tinytext', medium_blob = 'mediumtext', long_blob = 'longtext', @@ -413,11 +419,25 @@ local text_type_names = { string = 'char', } -local conn = {} -local mt = {__index = conn} +local string_types = { + string=1, + varchar=1, + var_string=1, + enum=1, + set=1, + long_blob=1, + blob=1, + tiny_blob=1, + geometry=1, + bit=1, + decimal=1, + newdecimal=1, +} --- mysql field value type converters -local converters = { +local conn = {} +local conn_mt = {__index = conn} + +local from_text_converters = { tinyint = tonumber, shortint = tonumber, mediumint = tonumber, @@ -428,756 +448,635 @@ local converters = { double = tonumber, } -local function _get_byte2(data, i) - local a, b = strbyte(data, i, i + 1) - return bor(a, lshift(b, 8)), i + 2 +local function return_arg1(v) return v end + +assert(ffi.abi'le') + +local function buf_len(buf) + local _, _, n = buf() + buf(-n) + return n end +local i8_ct = ffi.typeof 'int8_t*' +local i16_ct = ffi.typeof 'int16_t*' +local u16_ct = ffi.typeof'uint16_t*' +local i32_ct = ffi.typeof' int32_t*' +local u32_ct = ffi.typeof'uint32_t*' +local i64_ct = ffi.typeof 'int64_t*' +local u64_ct = ffi.typeof'uint64_t*' +local f64_ct = ffi.typeof'double*' +local f32_ct = ffi.typeof'float*' -local function _get_byte3(data, i) - local a, b, c = strbyte(data, i, i + 2) - return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 +local function get_u8(buf) + local p, i = buf(1) + return p[i] end - -local function _get_byte4(data, i) - local a, b, c, d = strbyte(data, i, i + 3) - return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 +local function get_i8(buf) + local p, i = buf(1) + return ffi.cast(i8_ct, p+i)[0] end - -local function _get_byte8(data, i) - local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) - - -- XXX workaround for the lack of 64-bit support in bitop: - local lo = bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)) - local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) - return lo + hi * 4294967296, i + 8 - - -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), - -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 +local function get_u16(buf) + local p, i = buf(2) + return ffi.cast(u16_ct, p+i)[0] end - -local function _set_byte2(n) - return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) +local function get_i16(buf) + local p, i = buf(2) + return ffi.cast(i16_ct, p+i)[0] end - -local function _set_byte3(n) - return strchar(band(n, 0xff), - band(rshift(n, 8), 0xff), - band(rshift(n, 16), 0xff)) +local function get_u24(buf) + local p, i = buf(3) + local a, b, c = p[i], p[i+1], p[i+2] + return bor(a, shl(b, 8), shl(c, 16)) end - -local function _set_byte4(n) - return strchar(band(n, 0xff), - band(rshift(n, 8), 0xff), - band(rshift(n, 16), 0xff), - band(rshift(n, 24), 0xff)) +local function get_u32(buf) + local p, i = buf(4) + return ffi.cast(u32_ct, p+i)[0] end +local function get_i32(buf) + local p, i = buf(4) + return ffi.cast(i32_ct, p+i)[0] +end -local function _from_cstring(data, i) - local last = data:find('\0', i, true) - if not last then - return nil, nil +local function get_u64(buf) + local p, i = buf(8) + return tonumber(ffi.cast(u64_ct, p+i)[0]) +end + +local function get_i64(buf) + local p, i = buf(8) + return tonumber(ffi.cast(i64_ct, p+i)[0]) +end + +local function get_f64(buf) + local p, i = buf(8) + return tonumber(ffi.cast(f64_ct, p+i)[0]) +end + +local function get_f32(buf) + local p, i = buf(4) + return tonumber(ffi.cast(f32_ct, p+i)[0]) +end + +local function get_uint(buf) --length-encoded int + local c = get_u8(buf) + if c < 0xfb then + return c + elseif c == 0xfb then --NULL string + return nil + elseif c == 0xfc then + return get_u16(buf) + elseif c == 0xfd then + return get_u24(buf) + elseif c == 0xfe then + return get_u64(buf) + else + buf(1/0, 'invalid length-encoded int') end - - return sub(data, i, last - 1), last + 1 end - -local function _to_cstring(data) - return data .. '\0' -end - - -local function _to_binary_coded_string(data) - return strchar(#data) .. data -end - - -local function _dump(data) - local len = #data - local bytes = new_tab(len, 0) - for i = 1, len do - bytes[i] = format('%x', strbyte(data, i)) +local function get_cstring(buf) + local p, i0 = buf(0) + while true do + local _, i = buf(1) + if p[i] == 0 then + return ffi.string(p+i0, i-i0) + end + i = i + 1 end - return concat(bytes, ' ') end +local function get_str(buf) --length-encoded string + local slen = get_uint(buf) + if not slen then return nil end + local p, i = buf(slen) + return ffi.string(p+i, slen) +end -local function _dumphex(data) - local len = #data - local bytes = new_tab(len, 0) - for i = 1, len do - bytes[i] = tohex(strbyte(data, i), 2) +local function get_bytes(buf, len) --fixed-length string + local p, i, len = buf(len) + return ffi.string(p+i, len) +end + +local function get_datetime(buf, date_format) + local len = get_u8(buf) + if len == 0 then + return date_format == '*t' + and {year = 0, month = 0, day = 0} + or date_format and format(date_format, 0, 0, 0, 0, 0, 0, 0) + or '0000-00-00' end - return concat(bytes, ' ') + local y = get_u16(buf) + local m = get_u8(buf) + local d = get_u8(buf) + if len == 4 then + return date_format == '*t' + and {year = y, month = m, day = d} + or format(date_format or '%04d-%02d-%02d', y, m, d, 0, 0, 0, 0) + end + local H = get_u8(buf) + local M = get_u8(buf) + local S = get_u8(buf) + local ms = len == 7 and 0 or get_u32(buf) + return date_format == '*t' + and {year = y, month = m, day = d, hour = H, min = M, sec = S + ms / 10^6} + or format(date_format or (len == 7 + and '%04d-%02d-%02d %02d:%02d:%02d' + or '%04d-%02d-%02d %02d:%02d:%02d.%06d'), + y, m, d, H, M, S, ms) end +local function get_time(buf, time_format) + local len = get_u8(buf) + if len == 0 then + return {days = 0, hour = 0, min = 0, sec = 0} + end + local sign = get_u8(buf) == 1 and -1 or 1 + local days = get_u4(buf) * sign + local H = get_u8(buf) + local M = get_u8(buf) + local S = get_u8(buf) + local ms = len == 8 and 0 or get_u32(buf) + return time_format == '*t' + and {days = days, hour = H, min = M, sec = S + ms / 10^6} + or time_format == '*s' and days * 24 * 3600 + H * 3600 + M * 60 + S + ms / 10^6 + or format(time_format or (len == 8 + and '%dd %02d:%02d:%02d' + or '%dd %02d:%02d:%02d.%06d'), days, H, M, S, ms) +end -local function _compute_token(password, scramble) +local function set_datetime(buf, t) + +end + +local function set_time(buf, t) + +end + +local function set_u8(buf, x) + local p, i = buf(1) + assert(x >= 0 and x < 2^8) + p[i] = x +end + +local function set_i8(buf, x) + local p, i = buf(1) + assert(x >= -127 and x <= 128) + ffi.cast(i8_ct, p+i)[0] = x +end + +local function set_u24(buf, x) + local p, i = buf(3) + assert(x >= 0 and x < 2^24) + p[i+0] = band( x , 0xff) + p[i+1] = band(shr(x, 8), 0xff) + p[i+2] = band(shr(x, 16), 0xff) +end + +local function set_u32(buf, x) + local p, i = buf(4) + assert(x >= 0 and x < 2^32) + ffi.cast(u32_ct, p+i)[0] = x +end + +local function set_i32(buf, x) + local p, i = buf(4) + assert(x >= -(2^31-1) and x <= 2^31) + ffi.cast(i32_ct, p+i)[0] = x +end + +local function set_u64(buf, x) + local p, i = buf(8) + assert(x >= 0 and x <= 2^52) + ffi.cast(u64_ct, p+i)[0] = x +end + +local function set_i64(buf, x) + local p, i = buf(8) + assert(x >= -(2^51-1) and x <= 2^51) + ffi.cast(i64_ct, p+i)[0] = x +end + +local function set_f64(buf, x) + local p, i = buf(8) + ffi.cast(f64_ct, p+i)[0] = x +end + +local function set_f32(buf, x) + local p, i = buf(4) + ffi.cast(f32_ct, p+i)[0] = x +end + +local function set_uint(buf, x) --length-encoded int + assert(x >= 0) + if x < 0xfb then + set_u8(buf, x) + elseif x < 2^16 then + set_u8(buf, 0xfc) + set_u16(buf, x) + elseif x < 2^24 then + set_u8(buf, 0xfd) + set_u24(buf, x) + else + set_u8(buf, 0xfe) + set_u64(buf, x) + end +end + +local function set_cstring(buf, s) + local p, i = buf(#s+1) + ffi.copy(p+i, s) +end + +local function set_bytes(buf, s, len) + len = len or #s + local p, i = buf(len) + ffi.copy(p+i, s, len) +end + +local function set_str(buf, s) + set_uint(#s) + set_bytes(buf, s) +end + +local function set_token(buf, password, scramble) if password == '' then return '' end - local stage1 = sha1(password) local stage2 = sha1(stage1) local stage3 = sha1(scramble .. stage2) local n = #stage1 - local bytes = new_tab(n, 0) + set_u8(buf, n) + local p, pi = buf(n) for i = 1, n do - bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) + p[pi+i-1] = bxor(strbyte(stage3, i), strbyte(stage1, i)) end - - return concat(bytes) end +local function send_buffer(min_capacity) + local arr = dynarray('uint8_t[?]', min_capacity) + local i = 0 + return function(n) + local p = arr(i+n) + i = i + n + return p, i-n + end +end -local function _send_packet(self, req, size) - local sock = self.sock - +local function send_packet(self, send_buf) + local send_buf, send_len = send_buf(0) self.packet_no = self.packet_no + 1 - - -- print('packet no: ', self.packet_no) - - local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req - - -- print('sending packet: ', _dump(packet)) - - -- print('sending packet... of size ' .. #packet) - - return sock:send(packet) + local buf = send_buffer(4) + set_u24(buf, send_len) + set_u8(buf, band(self.packet_no, 0xff)) + check_io(self, self.tcp:send(buf(0))) + check_io(self, self.tcp:send(send_buf, send_len)) end -local function _recv(self, sz) +local function recv(self, sz) local buf = self.buf if not buf then - buf = buffer'char[?]' + buf = buffer'uint8_t[?]' self.buf = buf end local buf = buf(sz) - local ok, err = self.sock:recvall(buf, sz) - if not ok then return nil, err end - return ffi.string(buf, sz) + check_io(self, self.tcp:recvall(buf, sz)) + local i = 0 + return function(n, err) + n = n or sz-i + check(self, i + n <= sz, err or 'short read') + i = i + n + return buf, i-n, n + end end -local function _recv_packet(self) - local sock = self.sock - - local data, err = _recv(self, 4) -- packet header - if not data then - return nil, nil, 'failed to receive packet header: ' .. err +local function recv_packet(self) + local buf = recv(self, 4) --packet header + local len = get_u24(buf) + check(self, len > 0, 'empty packet') + check(self, len <= self.max_packet_size, 'packet too big') + self.packet_no = get_u8(buf) + local buf = recv(self, len) + local field_count = get_u8(buf) + buf(-1) --peek + if field_count == 0x00 then typ = 'OK' + elseif field_count == 0xff then typ = 'ERR' + elseif field_count == 0xfe then typ = 'EOF' + else typ = 'DATA' end - - --print('packet header: ', _dump(data)) - - local len, pos = _get_byte3(data, 1) - - --print('packet length: ', len) - - if len == 0 then - return nil, nil, 'empty packet' - end - - if len > self._max_packet_size then - return nil, nil, 'packet size too big: ' .. len - end - - local num = strbyte(data, pos) - - --print('recv packet: packet no: ', num) - - self.packet_no = num - - data, err = _recv(self, len) - - --print('receive returned') - - if not data then - return nil, nil, 'failed to read packet content: ' .. err - end - - --print('packet content: ', _dump(data)) - --print('packet content (ascii): ', data) - - local field_count = strbyte(data, 1) - - local typ - if field_count == 0x00 then - typ = 'OK' - elseif field_count == 0xff then - typ = 'ERR' - elseif field_count == 0xfe then - typ = 'EOF' - else - typ = 'DATA' - end - - return data, typ + return typ, buf end - -local function _from_length_coded_bin(data, pos) - local first = strbyte(data, pos) - - --print('LCB: first: ', first) - - if not first then - return nil, pos - end - - if first >= 0 and first <= 250 then - return first, pos + 1 - end - - if first == 251 then - return nil, pos + 1 - end - - if first == 252 then - pos = pos + 1 - return _get_byte2(data, pos) - end - - if first == 253 then - pos = pos + 1 - return _get_byte3(data, pos) - end - - if first == 254 then - pos = pos + 1 - return _get_byte8(data, pos) - end - - return nil, pos + 1 +local function get_name(buf) + local s = get_str(buf) + return s ~= '' and s:lower() or nil end - -local function _from_length_coded_str(data, pos) - local len - len, pos = _from_length_coded_bin(data, pos) - if not len then - return nil, pos - end - return sub(data, pos, pos + len - 1), pos + len -end - -local function _parse_ok_packet(packet) - local res = new_tab(0, 5) - local pos - - res.affected_rows, pos = _from_length_coded_bin(packet, 2) - - --print('affected rows: ', res.affected_rows, ', pos:', pos) - - res.insert_id, pos = _from_length_coded_bin(packet, pos) - - if res.insert_id == 0 then - res.insert_id = nil - end - - --print('insert id: ', res.insert_id, ', pos:', pos) - - res.server_status, pos = _get_byte2(packet, pos) - - --print('server status: ', res.server_status, ', pos:', pos) - - res.warning_count, pos = _get_byte2(packet, pos) - - --print('warning count: ', res.warning_count, ', pos: ', pos) - - res.message = _from_length_coded_str(packet, pos) - - --print('message: ', res.message, ', pos:', pos) - - return res -end - - -local function _parse_eof_packet(packet) - local pos = 2 - - local warning_count, pos = _get_byte2(packet, pos) - local status_flags = _get_byte2(packet, pos) - +local function get_eof_packet(buf) + local _ = get_u8(buf) --status: EOF + local warning_count = get_u16(buf) + local status_flags = get_u16(buf) return warning_count, status_flags end +local UNSIGNED_FLAG = 32 -local function _parse_err_packet(packet) - local errno, pos = _get_byte2(packet, 2) - local marker = sub(packet, pos, pos) - local sqlstate - if marker == '#' then - -- with sqlstate - pos = pos + 1 - sqlstate = sub(packet, pos, pos + 5 - 1) - pos = pos + 5 - end - - local message = sub(packet, pos) - message = message:gsub('You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near ', 'Syntax error: ') - return errno, message, sqlstate -end - - -local function _parse_result_set_header_packet(packet) - local field_count, pos = _from_length_coded_bin(packet, 1) - - local extra - extra = _from_length_coded_bin(packet, pos) - - return field_count, extra -end - -local function _parse_field(data, pos) - local s, pos = _from_length_coded_str(data, pos) - s = s and s ~= '' and s:lower() or nil - return s, pos -end - -local charset_bytes = { - utf8 = 3, - utf8mb4 = 4, -} - ---NOTE: MySQL doesn't give enough info to make editable fields in a UI, +--NOTE: MySQL doesn't give enough metadata to generate a form in a UI, --you'll have to query `information_schema` to get the rest like enum values ---and defaults. So we only keep enough info to format read-only fields in a UI. -local function _parse_field_packet(data) - local col = new_tab(0, 16) - local catalog, pos = _parse_field(data, 1) --always "def" - col.schema, pos = _parse_field(data, pos) - col.table, pos = _parse_field(data, pos) - col.origin_table, pos = _parse_field(data, pos) - col.name, pos = _parse_field(data, pos) - col.origin_name, pos = _parse_field(data, pos) - pos = pos + 1 --ignore the filler - local collation, pos = _get_byte2(data, pos) - col.max_char_w, pos = _get_byte4(data, pos) - local buffer_type = buffer_type_names[strbyte(data, pos)] +--and defaults. So we only keep enough info for formatting the values. +local function get_field_packet(buf) + local col = {} + local _ = get_name(buf) --always "def" + col.schema = get_name(buf) + col.table = get_name(buf) + col.origin_table = get_name(buf) + col.name = get_name(buf) + col.origin_name = get_name(buf) + local _ = get_uint(buf) --0x0c + local collation = get_u16(buf) + col.max_char_w = get_u32(buf) + local buf_type_code = get_u8(buf) + local flags = get_u16(buf) + local decimals = get_u8(buf) + local buf_type = buffer_type_names[buf_type_code] if collation == 63 then - col.type = bin_type_names[buffer_type] - or type_names[buffer_type] - or buffer_type + col.type = bin_types[buf_type] + or num_types[buf_type] + or buf_type else - col.type = text_type_names[buffer_type] + col.type = text_types[buf_type] col.collation = collation_names[collation] col.charset = col.collation and col.collation:match'^[^_]+' - col.max_char_w = col.max_char_w / (charset_bytes[col.charset] or 1) end - pos = pos + 1 - local flags, pos = _get_byte2(data, pos) - col.decimals = strbyte(data, pos) --for formatting only, not for editing! - if col.type ~= 'decimal' and col.decimals == 0x1f then --varchar and floats - col.decimals = nil + if col.type == 'decimal' then + col.decimals = decimals end + col.buffer_type = buf_type + col.buffer_type_code = buf_type_code + col.unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil return col end - -local function _parse_row_data_packet(data, cols, compact, to_array, null_value) - local pos = 1 - local ncols = #cols - local row - if not to_array then - if compact then - row = new_tab(ncols, 0) - else - row = new_tab(0, ncols) - end +local function recv_field_packets(self, field_count) + local fields = {} + for i = 1, field_count do + local typ, buf = recv_packet(self) + check(self, typ == 'DATA', 'bad packet type') + local field = get_field_packet(buf) + field.index = i + fields[i] = field + fields[field.name] = field end - for i = 1, ncols do - local value - value, pos = _from_length_coded_str(data, pos) - local col = cols[i] - local typ = col.type - local name = col.name - - --print('row field value: ', value, ', type: ', typ) - - if value ~= nil then - local conv = converters[typ] - if conv then - value = conv(value) - end - else - value = null_value - end - - if to_array then - return value - end - - if compact then - row[i] = value - else - row[name] = value - end + if field_count > 0 then + local typ, buf = recv_packet(self) + check(self, typ == 'EOF', 'bad packet type') + get_eof_packet(buf) end - - return row + return fields end - -local function _recv_field_packet(self) - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end - - if typ == 'ERR' then - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end - - if typ ~= 'DATA' then - return nil, 'bad field packet type: ' .. typ - end - - -- typ == 'DATA' - - return _parse_field_packet(packet) +local function get_err_packet(buf) + local _ = get_u8(buf) + local errno = get_u16(buf) + local marker = get_u8(buf) + local sqlstate = strchar(marker) == '#' and get_bytes(buf, 5) or nil + local message = get_bytes(buf) + message = message:gsub('You have an error in your SQL syntax; ' + ..'check the manual that corresponds to your MySQL server version ' + ..'for the right syntax to use near ', 'Syntax error: ') + return message, errno, sqlstate end +function mysql.connect(opt) -function mysql.new(self, opt) local tcp = opt and opt.tcp or require'sock'.tcp - local sock, err = tcp() - if not sock then - return nil, err - end - return setmetatable({ sock = sock }, mt) -end - - -function conn:connect(opts) - local sock = self.sock - - local max_packet_size = opts.max_packet_size - if not max_packet_size then - max_packet_size = 1024 * 1024 -- default 1 MB - end - self._max_packet_size = max_packet_size + local tcp = check_io(self, tcp()) + local self = setmetatable({tcp = tcp}, conn_mt) + self.max_packet_size = opt.max_packet_size or 16 * 1024 * 1024 --16 MB local ok, err - local database = opts.database or '' - local user = opts.user or '' + local database = opt.database or '' + local user = opt.user or '' local collation = 0 --default - if opts.collation then - collation = assert(collation_codes[opts.collation], 'invalid collation') - elseif opts.charset then - collation = assert(default_collations[opts.charset], 'invalid charset') + if opt.collation then + collation = assert(collation_codes[opt.collation], 'invalid collation') + elseif opt.charset then + collation = assert(default_collations[opt.charset], 'invalid charset') collation = assert(collation_codes[collation]) end - local host = opts.host - local port = opts.port or 3306 - ok, err, errcode = sock:connect(host, port) - - if not ok then - return nil, err, errcode - end - - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end + local host = opt.host + local port = opt.port or 3306 + check_io(self, self.tcp:connect(host, port)) + local typ, buf = recv_packet(self) if typ == 'ERR' then - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate + return nil, get_err_packet(buf) end - - self.protocol_ver = strbyte(packet) - - --print('protocol version: ', self.protocol_ver) - - local server_ver, pos = _from_cstring(packet, 2) - if not server_ver then - return nil, 'bad handshake initialization packet: bad server version' - end - - --print('server version: ', server_ver) - - self._server_ver = server_ver - - local thread_id, pos = _get_byte4(packet, pos) - - --print('thread id: ', thread_id) - - local scramble = sub(packet, pos, pos + 8 - 1) - if not scramble then - return nil, '1st part of scramble not found' - end - - pos = pos + 9 -- skip filler - - -- two lower bytes - local capabilities -- server capabilities - capabilities, pos = _get_byte2(packet, pos) - - -- print(format('server capabilities: %#x', capabilities)) - - self._server_lang = strbyte(packet, pos) - pos = pos + 1 - - --print('server lang: ', self._server_lang) - - self._server_status, pos = _get_byte2(packet, pos) - - --print('server status: ', self._server_status) - - local more_capabilities - more_capabilities, pos = _get_byte2(packet, pos) - - capabilities = bor(capabilities, lshift(more_capabilities, 16)) - - --print('server capabilities: ', capabilities) - - -- local len = strbyte(packet, pos) - local len = 21 - 8 - 1 - - --print('scramble len: ', len) - - pos = pos + 1 + 10 - - local scramble_part2 = sub(packet, pos, pos + len - 1) - if not scramble_part2 then - return nil, '2nd part of scramble not found' - end - + self.protocol_ver = get_u8(buf) + self.server_ver = get_cstring(buf) + self.thread_id = get_u32(buf) + local scramble = get_bytes(buf, 8) + buf(1) --filler + local capabilities = get_u16(buf) + self.server_lang = get_u8(buf) + self.server_status = get_u16(buf) + local more_capabilities = get_u16(buf) + capabilities = bor(capabilities, shl(more_capabilities, 16)) + get_bytes(buf, 1 + 10) + local scramble_part2 = get_bytes(buf, 21 - 8 - 1) scramble = scramble .. scramble_part2 - --print('scramble: ', _dump(scramble)) - - local client_flags = 0x3f7cf; - - local ssl_verify = opts.ssl_verify - local use_ssl = opts.ssl or ssl_verify + local client_flags = 0x3f7cf + local ssl_verify = opt.ssl_verify + local use_ssl = opt.ssl or ssl_verify + local buf = send_buffer(64) if use_ssl then - if band(capabilities, CLIENT_SSL) == 0 then - return nil, 'ssl disabled on server' - end - - -- send a SSL Request Packet - local req = _set_byte4(bor(client_flags, CLIENT_SSL)) - .. _set_byte4(self._max_packet_size) - .. strchar(collation) - .. strrep('\0', 23) - - local packet_len = 4 + 4 + 1 + 23 - local bytes, err = _send_packet(self, req, packet_len) - if not bytes then - return nil, 'failed to send client authentication packet: ' .. err - end - - local ok, err = sock:sslhandshake(false, nil, ssl_verify) - if not ok then - return nil, 'failed to do ssl handshake: ' .. (err or '') - end - end - - local password = opts.password or '' - - local token = _compute_token(password, scramble) - - --print('token: ', _dump(token)) - - local req = _set_byte4(client_flags) - .. _set_byte4(self._max_packet_size) - .. strchar(collation) - .. strrep('\0', 23) - .. _to_cstring(user) - .. _to_binary_coded_string(token) - .. _to_cstring(database) - - local packet_len = 4 + 4 + 1 + 23 + #user + 1 - + #token + 1 + #database + 1 - - -- print('packet content length: ', packet_len) - -- print('packet content: ', _dump(concat(req, ''))) - - local bytes, err = _send_packet(self, req, packet_len) - if not bytes then - return nil, 'failed to send client authentication packet: ' .. err - end - - --print('packet sent ', bytes, ' bytes') - - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, 'failed to receive the result packet: ' .. err + check(self, band(capabilities, CLIENT_SSL) ~= 0, 'ssl disabled on server') + set_u32(buf, bor(client_flags, CLIENT_SSL)) + set_u32(buf, self.max_packet_size) + set_u8(buf, collation) + buf(23) + send_packet(self, buf) + check_io(self, self.tcp:sslhandshake(false, nil, ssl_verify)) end + set_u32(buf, client_flags) + set_u32(buf, self.max_packet_size) + set_u8(buf, collation) + buf(23) + set_cstring(buf, user) + set_token(buf, opt.password or '', scramble) + set_cstring(buf, database) + send_packet(self, buf) + local typ, buf = recv_packet(self) if typ == 'ERR' then - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end - - if typ == 'EOF' then + return nil, get_err_packet(buf) + elseif typ == 'EOF' then return nil, 'old pre-4.1 authentication protocol not supported' end - - if typ ~= 'OK' then - return nil, 'bad packet type: ' .. typ - end - + check(self, typ == 'OK', 'bad packet type') self.state = 'ready' - - return 1 + return self end +conn.connect = protect(conn.connect) function conn:close() - local sock = assert(self.sock) - - self.state = nil - - local bytes, err = _send_packet(self, strchar(COM_QUIT), 1) - if not bytes then - return nil, err + if self.state then + local buf = send_buffer(1) + set_u8(buf, COM_QUIT) + send_packet(self, buf) + check_io(self, self.tcp:close()) + self.state = nil end - - return sock:close() -end - -function conn:server_ver() - return self._server_ver + return true end +conn.close = protect(conn.close) function conn:send_query(query) assert(self.state == 'ready') - local sock = assert(self.sock) - self.packet_no = -1 - - local cmd_packet = strchar(COM_QUERY) .. query - local packet_len = 1 + #query - - local bytes, err = _send_packet(self, cmd_packet, packet_len) - if not bytes then - return nil, err - end - + local buf = send_buffer(1 + #query) + set_u8(buf, COM_QUERY) + set_bytes(buf, query) + send_packet(self, buf) self.state = 'read' - - --print('packet sent ', bytes, ' bytes') - - return bytes + return true end +conn.send_query = protect(conn.send_query) function conn:read_result(opt) - assert(self.state == 'read') - local sock = assert(self.sock) - - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end - + assert(self.state == 'read' or self.state == 'read_binary') + local typ, buf = recv_packet(self) if typ == 'ERR' then self.state = 'ready' - - local errno, msg, sqlstate = _parse_err_packet(packet) - return nil, msg, errno, sqlstate - end - - if typ == 'OK' then - local res = _parse_ok_packet(packet) - if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return nil, get_err_packet(buf) + elseif typ == 'OK' then + local res = {} + res.affected_rows = get_uint(buf) + res.insert_id = get_uint(buf) + res.server_status = get_u16(buf) + res.warning_count = get_u16(buf) + res.message = get_str(buf) + res.insert_id = repl(res.insert_id, 0, nil) + if band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then return res, 'again' + else + self.state = 'ready' + return res end - - self.state = 'ready' - return res end + check(self, typ == 'DATA', 'bad packet type') - if typ ~= 'DATA' then - self.state = 'ready' + local field_count = get_uint(buf) + local extra = buf_len(buf) > 0 and get_uint(buf) or nil - return nil, 'packet type ' .. typ .. ' not supported' - end + local cols = recv_field_packets(self, field_count) - -- typ == 'DATA' + local compact = opt and opt.compact + local to_array = opt and opt.to_array and #cols == 1 + local null_value = opt and opt.null_value + local datetime_format = opt and opt.datetime_format + local date_format = opt and opt.date_format + local time_format = opt and opt.time_format - --print('read the result set header packet') - - local field_count, extra = _parse_result_set_header_packet(packet) - - --print('field count: ', field_count) - - local cols = new_tab(field_count, 0) - for i = 1, field_count do - local col, err, errno, sqlstate = _recv_field_packet(self) - if not col then - return nil, err, errno, sqlstate - end - - col.index = i - cols[i] = col - cols[col.name] = col - end - - local packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end - - if typ ~= 'EOF' then - return nil, 'unexpected packet type ' .. typ .. ' while eof packet is ' - .. 'expected' - end - - -- typ == 'EOF' - - local compact = opt and opt.compact - local to_array = opt and opt.to_array and #cols == 1 - local null_value = opt and opt.null_value - - local rows = new_tab(4, 0) + local rows = {} local i = 0 while true do - --print('reading a row') - - packet, typ, err = _recv_packet(self) - if not packet then - return nil, err - end + local typ, buf = recv_packet(self) if typ == 'EOF' then - local warning_count, status_flags = _parse_eof_packet(packet) - - --print('status flags: ', status_flags) - + local _, status_flags = get_eof_packet(buf) if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then return rows, 'again', cols end - break end - -- if typ ~= 'DATA' then - -- return nil, 'bad row packet type: ' .. typ - -- end + local row = not to_array and {} or nil - -- typ == 'DATA' - - local row = _parse_row_data_packet(packet, cols, compact, to_array, null_value) + if self.state == 'read_binary' then + check(get_u8(buf) == 0, 'invalid row packet') + local nulls_len = math.floor((#cols + 7 + 2) / 8) + local nulls, nulls_offset = buf(nulls_len) + for i, col in ipairs(cols) do + local null_byte = shr(i-1+2, 3) + nulls_offset + local null_bit = band(i-1+2, 7) + local is_null = band(nulls[null_byte], shl(1, null_bit)) ~= 0 + local v + if not is_null then + local bt = col.buffer_type + local unsigned = col.unsigned + if string_types[bt] then + v = get_str(buf) + elseif bt == 'longlong' then + v = unsigned and get_u64(buf) or get_i64(buf) + elseif bt == 'int24' or bt == 'long' then + v = unsigned and get_u32(buf) or get_i32(buf) + elseif bt == 'year' then + v = unsigned and get_u16(buf) or get_i16(buf) + elseif bt == 'tiny' then + v = unsigned and get_u8(buf) or get_i8(buf) + elseif bt == 'double' then + v = get_f64(buf) + elseif bt == 'float' then + v = get_f32(buf) + elseif bt == 'date' or bt == 'datetime' or bt == 'timestamp' then + v = get_datetime(buf, bt == 'date' and date_format or datetime_format) + elseif bt == 'time' then + v = get_time(buf, time_format) + else + check(self, false, 'unsupported param type '..bt) + end + else + v = null_value + end + if to_array then + row = v + elseif compact then + row[i] = v + else + row[col.name] = v + end + end + else + for i, col in ipairs(cols) do + local v = get_str(buf) + if v ~= nil then + local convert = from_text_converters[col.type] + if convert then + v = convert(v) + end + else + v = null_value + end + if to_array then + row = v + elseif compact then + row[i] = v + else + row[col.name] = v + end + end + end i = i + 1 rows[i] = row end self.state = 'ready' - return rows, nil, cols end +conn.read_result = protect(conn.read_result) function conn:query(query, opt) local bytes, err, errcode = self:send_query(query) @@ -1185,6 +1084,137 @@ function conn:query(query, opt) return self:read_result(opt) end +local stmt = {} + +local cursor_types = { + no_cursor = 0x00, + read_only = 0x01, + update = 0x02, + scrollable = 0x04, +} + +function conn:prepare(query, cursor_type) + assert(self.state == 'ready') + self.packet_no = -1 + local buf = send_buffer(1 + #query) + set_u8(buf, COM_STMT_PREPARE) + set_bytes(buf, query) + send_packet(self, buf) + + local typ, buf = recv_packet(self) + if typ == 'ERR' then + return nil, get_err_packet(buf) + end + check(self, typ == 'OK', 'bad packet type') + buf(1) --status + local stmt = update({conn = self}, stmt) + stmt.id = get_u32(buf) + local col_count = get_u16(buf) + local param_count = get_u16(buf) + buf(1) --filler + stmt.warning_count = get_u16(buf) + stmt.params = recv_field_packets(self, param_count) + stmt.cols = recv_field_packets(self, col_count) + stmt.cursor_type = assert(cursor_types[cursor_type or 'no_cursor']) + return stmt +end +conn.prepare = protect(conn.prepare) + +function stmt:free() + local self, stmt = self.conn, self + assert(self.state == 'ready') + self.packet_no = -1 + local buf = send_buffer(5) + set_u8(buf, COM_STMT_CLOSE) + set_u32(buf, stmt.id) + return true +end +stmt.free = protect(stmt.free) + +function stmt:exec(...) + local self, stmt = self.conn, self + assert(self.state == 'ready') + self.packet_no = -1 + local buf = send_buffer(64) + set_u8(buf, COM_STMT_EXECUTE) + set_u32(buf, stmt.id) + set_u8(buf, stmt.cursor_type) + set_u32(buf, 1) --iteration-count, must be 1 + if #stmt.params > 0 then + local nulls_len = math.floor((#stmt.params + 7) / 8) + local nulls = ffi.new('uint8_t[?]', nulls_len) + for i = 1, #stmt.params do + local val = select(i, ...) + if val == nil then + local byte = shr(i-1, 3) + local bit = band(i-1, 7) + nulls[byte] = bor(nulls[byte], shl(1, bit)) + end + end + set_bytes(buf, nulls, nulls_len) + set_u8(buf, 1) --new-params-bound-flag + for i, param in ipairs(stmt.params) do + set_u8(buf, param.buffer_type_code) + set_u8(buf, param.unsigned and 0x80 or 0) + end + for i, param in ipairs(stmt.params) do + local val = select(i, ...) + if val ~= nil then + local bt = param.buffer_type + local unsigned = param.unsigned + if string_types[bt] then + set_str(buf, tostring(val)) + elseif bt == 'longlong' then + if unsigned then + set_u64(buf, val) + else + set_i64(buf, val) + end + elseif bt == 'int24' or bt == 'long' then + if unsigned then + assert(val >= 0 and val < (bt == 'int24' and 2^24 or 2^32)) + set_u32(buf, val) + else + if bt == 'int24' then + assert(val >= -(2^23-1) and val <= 2^23-1) + else + assert(val >= -(2^31-1) and val <= 2^31) + end + set_i32(buf, val) + end + elseif bt == 'year' then + if unsigned then + set_u16(buf, val) + else + set_i16(buf, val) + end + elseif bt == 'tiny' then + if unsigned then + set_u8(buf, val) + else + set_i8(buf, val) + end + elseif bt == 'double' then + set_f64(buf, val) + elseif bt == 'float' then + set_f32(buf, val) + elseif bt == 'date' or bt == 'datetime' or bt == 'timestamp' then + set_datetime(buf, val) + elseif bt == 'time' then + set_time(buf, val) + else + check(self, false, 'unsupported param type '..bt) + end + end + end + + end + send_packet(self, buf) + self.state = 'read_binary' + return true +end +stmt.exec = protect(stmt.exec) + local qmap = { ['\0' ] = '\\0', ['\b' ] = '\\b', @@ -1200,4 +1230,26 @@ function mysql.quote(s) return s:gsub('[%z\b\n\r\t\26\\\'\"]', qmap) end + +if not ... then --demo + + local sock = require'sock' + local pp = require'pp' + sock.run(function() + local conn = mysql.connect{ + host = '127.0.0.1', + port = 3307, + user = 'root', + password = 'abcd12', + database = 'sp', + } + pp(conn:query'select * from val where val = 1') + local stmt = conn:prepare('select * from val where val = ?') + assert(stmt:exec(1)) + pp(conn:read_result({datetime_format = '*t'})) + assert(stmt:free()) + end) + +end + return mysql diff --git a/mysql_client.md b/mysql_client.md index b3ce2bd..99ecbe8 100644 --- a/mysql_client.md +++ b/mysql_client.md @@ -1,25 +1,21 @@ ## `local mysql = require'mysql_client'` -MySQL client protocol in Lua. Stolen from OpenResty and modified to work standalone. - -## Status - -This library is considered production ready. +MySQL client protocol in Lua. +Stolen from OpenResty and modified to work standalone. ## Example ```lua local mysql = require'mysql_client' -local cn = assert(mysql:new()) -assert(cn:connect{ +assert(mysql.connect{ host = '127.0.0.1', port = 3306, - database = 'ngx_test', - user = 'ngx_test', - password = 'ngx_test', - charset = 'utf8', + database = 'foo', + user = 'bar', + password = 'baz', + charset = 'utf8mb4', max_packet_size = 1024 * 1024, }) @@ -35,21 +31,14 @@ local res = assert(cn:query('insert into cats (name) ' print(res.affected_rows, ' rows inserted into table cats ', '(last insert id: ', res.insert_id, ')') -local res = assert(cn:query('select * from cats order by id asc', 10)) - -local cjson = require'cjson' -print(cjson.encode(res)) +require'pp'(assert(cn:query('select * from cats order by id asc', 10))) assert(cn:close()) ``` ## API -### `mysql:new() -> cn | nil,err` - -Creates a MySQL connection object. - -### `cn:connect(options) -> ok | nil,err,errcode,sqlstate` +### `mysql.connect(options) -> ok | nil,err,errcode,sqlstate` Connect to a MySQL server. @@ -129,12 +118,22 @@ You should always check if the `err` return value is `again` in case of success because this method will only call [read_result](#read_result) once for you. -### `cn:server_ver() -> s` -Returns the MySQL server version string, like `"5.1.64"`. +### `cn:prepare(query) -> stmt` -You should only call this method after successfully connecting to a MySQL server, -otherwise `nil` will be returned. +Prepare a statement. + +### `stmt:exec(params...)` + +Execute a statement. Use `cn:read_result()` to get the results. + +### `stmt:free()` + +Free statement. + +### `cn.server_ver` + +The MySQL server version string. ### `mysql.quote(s) -> s` @@ -158,7 +157,5 @@ are suppored. ## TODO -* implement the MySQL binary row data packets. -* implement MySQL server prepare and execute packets. * implement the data compression support in the protocol.