diff --git a/mysql_client.lua b/mysql_client.lua index 8205eb7..ff0aa74 100644 --- a/mysql_client.lua +++ b/mysql_client.lua @@ -381,6 +381,24 @@ local mb_charsets = { --excluding ASCII supersets i.e. utf8* charsets. gb18030=1, } +local max_char_widths = { + utf8 = 3, + utf8mb4 = 4, + big5 = 2, + sjis = 2, + euckr = 2, + gb2312 = 2, + gbk = 2, + ucs2 = 2, + cp932 = 2, + ujis = 3, --eukjp + eucjpms = 3, + utf16 = 4, + utf16le = 4, + utf32 = 4, + gb18030 = 2, +} + local buffer_type_names = { [ 0] = 'decimal', [ 1] = 'tiny', @@ -410,28 +428,25 @@ local buffer_type_names = { [255] = 'geometry', } -local num_types = { +local bin_types = { tiny = 'tinyint', - short = 'shortint', - long = 'int', + short = 'smallint', int24 = 'mediumint', + long = 'int', longlong = 'bigint', newdecimal = 'decimal', -} - -local bin_types = { - tiny_blob = 'tinyblob', - medium_blob = 'mediumblob', - long_blob = 'longblob', + tiny_blob = 'tinyblob', --always selected as blob + medium_blob = 'mediumblob', --always selected as blob + long_blob = 'longblob', --always selected as blob blob = 'blob', var_string = 'varbinary', string = 'binary', } local text_types = { - tiny_blob = 'tinytext', - medium_blob = 'mediumtext', - long_blob = 'longtext', + tiny_blob = 'tinytext', --always selected as text + medium_blob = 'mediumtext', --always selected as text + long_blob = 'longtext', --always selected as text blob = 'text', var_string = 'varchar', string = 'char', @@ -453,19 +468,23 @@ local string_types = { } local int_ranges = { - tinyint = {-(2^ 7-1), 2^ 7, 0, 2^ 8-1}, - shortint = {-(2^15-1), 2^15, 0, 2^16-1}, - mediumint = {-(2^23-1), 2^23, 0, 2^24-1}, - int = {-(2^31-1), 2^31, 0, 2^32-1}, - bigint = {-(2^51-1), 2^51, 0, 2^52-1}, + tinyint = {1, -(2^ 7-1), 2^ 7, 0, 2^ 8-1}, + smallint = {2, -(2^15-1), 2^15, 0, 2^16-1}, + mediumint = {3, -(2^23-1), 2^23, 0, 2^24-1}, + int = {4, -(2^31-1), 2^31, 0, 2^32-1}, + bigint = {8, -(2^51-1), 2^51, 0, 2^52-1}, } local conn = {} local conn_mt = {__index = conn} -local to_lua = { +function mysql.isconn(x) + return getmetatable(x) == conn_mt +end + +local default_to_lua = { tinyint = tonumber, - shortint = tonumber, + smallint = tonumber, mediumint = tonumber, int = tonumber, bigint = tonumber, @@ -474,13 +493,6 @@ local to_lua = { double = tonumber, decimal = tonumber, } -function mysql.to_lua(v, col) - local to_lua = col.to_lua or to_lua[col.type] - if to_lua then - v = to_lua(v) - end - return v -end local function return_arg1(v) return v end @@ -867,26 +879,26 @@ end local UNSIGNED_FLAG = 32 -function mysql.num_range(type, unsigned, digits, decimals) - if digits and decimals then - local max = 10^(digits - decimals) - 1 / 10^decimals - local min = unsigned and 0 or -max --unsigned decimals is deprecated! - return min, max - else - local range = int_ranges[type] - if range then - if unsigned then - return range[3], range[4] - else - return range[1], range[2] - end +function mysql.int_range(type, unsigned) --min, max, size + local range = int_ranges[type] + if range then + if unsigned then + return range[4], range[5], range[1] + else + return range[2], range[3], range[1] end end end +function mysql.dec_range(digits, decimals, unsigned) --min, max, digits + local max = 10^(digits - decimals) - 10^-decimals + local min = unsigned and 0 or -max --unsigned decimals is deprecated! + return min, max, digits +end + --NOTE: MySQL doesn't give enough metadata to generate a form in a UI, --you'll have to query `information_schema` to get the rest like enum values ---and defaults. So we only keep enough info for formatting the values. +--and defaults. So we gather only what we need for display, not for editing. local function get_field_packet(buf) local col = {} local _ = get_name(buf) --always "def" @@ -896,39 +908,68 @@ local function get_field_packet(buf) col.name = get_name(buf) --alias column name col.col = get_name(buf) --name of column in origin table local _ = get_uint(buf) --0x0c - local collation = get_u16(buf) - col.max_char_w = get_u32(buf) + local collation = get_u16(buf) --connection's collation, not field's. + local display_size = get_u32(buf) --in bytes, not in characters. local buf_type_code = get_u8(buf) local flags = get_u16(buf) local decimals = get_u8(buf) local buf_type = buffer_type_names[buf_type_code] - if collation == 63 then - local type = bin_types[buf_type] - if not type then - type = num_types[buf_type] - if type then - col.decimals = decimals - end + local mysql_type + if collation == 63 then --binary + mysql_type = bin_types[buf_type] or buf_type + local unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil --for val decoding + if mysql_type == 'tinyint' or mysql_type == 'smallint' + or mysql_type == 'mediumint' or mysql_type == 'int' + or mysql_type == 'bigint' + then + col.type = 'number' + col.decimals = 0 + col.unsigned = unsigned + elseif mysql_type == 'decimal' then + col.type = 'number' + col.decimals = decimals + col.unsigned = unsigned + elseif mysql_type == 'float' then + col.type = 'number' + elseif mysql_type == 'double' then + col.type = 'number' + elseif mysql_type == 'year' then + col.type = 'number' + elseif mysql_type == 'timestamp' then + col.type = 'date' + col.has_time = true + elseif mysql_type == 'date' then + col.type = 'date' + elseif mysql_type == 'datetime' then + col.type = 'date' + col.has_time = true + elseif mysql_type == 'binary' then + col.padded = true end - col.type = type or buf_type + col.display_width = display_size else - col.type = text_types[buf_type] - col.collation = collation_names[collation] - col.charset = col.collation and col.collation:match'^[^_]+' + mysql_type = text_types[buf_type] or buf_type + local collation = collation_names[collation] + local charset = collation and collation:match'^[^_]+' + col.mysql_display_collation = collation + col.mysql_display_charset = charset + col.padded = mysql_type == 'char' or nil + col.display_width = math.ceil(display_size / (max_char_widths[charset] or 1)) end - col.buffer_type = buf_type - col.buffer_type_code = buf_type_code - col.unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil - col.min, col.max = mysql.num_range(col.type, col.unsigned, nil, col.decimals) + col.mysql_display_type = mysql_type + col.mysql_buffer_type = buf_type --for param encoding + col.mysql_buffer_type_code = buf_type_code --for val decoding return col end -local function recv_field_packets(self, field_count, field_attrs) +local function recv_field_packets(self, field_count, field_attrs, to_lua) local fields = {} + to_lua = to_lua or self.to_lua for i = 1, field_count do local typ, buf = recv_packet(self) checkp(self, typ == 'DATA', 'bad packet type') local field = get_field_packet(buf) + field.to_lua = to_lua or default_to_lua[field.mysql_type] field.index = i fields[i] = field fields[field.name] = field @@ -969,6 +1010,10 @@ function mysql.note (...) mysql.log('note', ...) end function mysql.connect(opt) + if mysql.isconn(opt) then --pass-through + return opt + end + local host = opt.host local port = opt.port or 3306 @@ -1115,7 +1160,7 @@ local function read_result(self, opt) local field_count = get_uint(buf) local extra = buf_len(buf) > 0 and get_uint(buf) or nil - local cols = recv_field_packets(self, field_count, opt and opt.field_attrs) + local cols = recv_field_packets(self, field_count, opt and opt.field_attrs, opt and opt.to_lua) local compact = opt and opt.compact local to_array = opt and opt.to_array and #cols == 1 @@ -1123,7 +1168,6 @@ local function read_result(self, opt) local datetime_format = opt and opt.datetime_format or self.datetime_format local date_format = opt and opt.date_format or self.date_format local time_format = opt and opt.time_format or self.time_format - local to_lua = opt and opt.to_lua or self.to_lua local rows = {} local i = 0 @@ -1155,7 +1199,7 @@ local function read_result(self, opt) local is_null = band(nulls[null_byte], shl(1, null_bit)) ~= 0 local v if not is_null then - local bt = col.buffer_type + local bt = col.mysql_buffer_type local unsigned = col.unsigned if string_types[bt] then v = get_str(buf) @@ -1193,7 +1237,10 @@ local function read_result(self, opt) for i, col in ipairs(cols) do local v = get_str(buf) if v ~= nil then - v = to_lua(v, col) + local to_lua = col.to_lua + if to_lua then + v = to_lua(v, col) + end else v = null_value end @@ -1268,8 +1315,8 @@ function conn:prepare(query, opt) local param_count = get_u16(buf) buf(1) --filler stmt.warning_count = get_u16(buf) - stmt.params = recv_field_packets(self, param_count, opt and opt.param_attrs) - stmt.cols = recv_field_packets(self, col_count, opt and opt.field_attrs) + stmt.params = recv_field_packets(self, param_count, opt and opt.param_attrs, opt and opt.to_lua) + stmt.cols = recv_field_packets(self, col_count , opt and opt.field_attrs, opt and opt.to_lua) stmt.cursor = assert(cursor_types[opt and opt.cursor or 'none']) return stmt end @@ -1309,13 +1356,13 @@ function stmt:exec(...) set_bytes(buf, nulls, nulls_len) set_u8(buf, 1) --new-params-bound-flag for i, param in ipairs(stmt.params) do - set_u8(buf, param.buffer_type_code) + set_u8(buf, param.mysql_buffer_type_code) set_u8(buf, param.unsigned and 0x80 or 0) end for i, param in ipairs(stmt.params) do local val = select(i, ...) if val ~= nil then - local bt = param.buffer_type + local bt = param.mysql_buffer_type local unsigned = param.unsigned if string_types[bt] then set_str(buf, tostring(val)) @@ -1390,34 +1437,14 @@ local qmap = { ['\26'] = '\\Z', ['\"' ] = '\\"', } +local function esc_utf8(s) + return s:gsub('[\\\'%z\b\n\r\t\26\"]', qmap) +end +mysql.esc_utf8 = esc_utf8 function conn:esc(s) --MBCS that are not ASCII supersets need decoding for correct quoting. assert(self.charset_is_ascii_superset, 'NYI') - return s:gsub('[\\\'%z\b\n\r\t\26\"]', qmap) -end - - -if not ... then --demo - - local sock = require'sock' - local pp = require'pp' - sock.run(function() - local conn = assert(mysql.connect{ - host = '127.0.0.1', - port = 3307, - user = 'root', - password = 'abcd12', - schema = 'sp', - collation = 'server', - }) - print(conn.charset, conn.collation) - --pp(conn:query'select * from val where val = 1') - local stmt = assert(conn:prepare('select min_price from vari where val = ?')) - assert(stmt:exec()) - pp(conn:read_result({datetime_format = '*t'})) - assert(stmt:free()) - end) - + return esc_utf8(s) end return mysql diff --git a/mysql_client.md b/mysql_client.md index 21e89b6..8b9018a 100644 --- a/mysql_client.md +++ b/mysql_client.md @@ -1,8 +1,10 @@ ## `local mysql = require'mysql_client'` -MySQL client protocol in Lua. -Stolen from OpenResty, modified to work with [sock] and added prepared statements. +MySQL client protocol in Lua. Ripped from OpenResty, modified to work with +[sock], added prepared statements, better interpretation of field metadata +(consistent with [sqlpp], [schema] and [xrowset][x-widges]), and other minor +changes. ## Example @@ -81,6 +83,7 @@ The `options` arg can contain: * `to_array = true` -- return an array of values for single-column results. * `null_value = val` -- value to use for `null` (defaults to `nil`). * `to_lua = f(v, col) -> v` -- custom value converter. + * `field_attrs = {name -> attr}` -- extra field attributes. For queries that return a result set, it returns an array of rows. For other queries it returns a Lua table with information such as @@ -99,8 +102,10 @@ the `sqlstate` return value contains the standard SQL error code that consists of 5 characters. Note that, the `errcode` and `sqlstate` might be `nil` if MySQL does not return them. -NOTE: 64 bit integers and decimals are converted to Lua numbers by default. +__NOTE:__ decimals and 64 bit integers are converted to Lua numbers by default. That limits the useful range of number types to 15 significant digits. +If you have other needs, provide your own `to_lua` (which you can set at +module, connection and query level, and even per field with `field_attrs`). ### `cn:query(query, [options]) -> res,nil,cols | nil,err,errcode,sqlstate` @@ -111,7 +116,6 @@ You should always check if the `err` return value is `again` in case of success because this method will only call [read_result](#read_result) once for you. - ### `cn:prepare(query, [opt]) -> stmt` Prepare a statement. Options can contain: @@ -130,10 +134,15 @@ Free statement. The MySQL server version string. +### `mysql.esc_utf8(s) -> s` + +Escape string to be used inside SQL string literals. Only works on connections +for which the charset is ASCII or an ASCII superset (ascii, utf8). + ### `cn:esc(s) -> s` -Escape string to be used inside SQL string literals. This only works if current -collation is known (ses `collation` arg on `connect()`). +Escape string to be used inside SQL string literals. This only works +if the current collation is known (see `collation` arg on `connect()`). ### Multiple result set support diff --git a/mysql_client_test.lua b/mysql_client_test.lua new file mode 100644 index 0000000..9710671 --- /dev/null +++ b/mysql_client_test.lua @@ -0,0 +1,117 @@ + +local mysql = require'mysql_client' +local sock = require'sock' +local pp = require'pp' +require'$' + +sock.run(function() + + local conn = assert(mysql.connect{ + host = '127.0.0.1', + port = 3307, + user = 'root', + password = 'root', + schema = 'sp', + collation = 'server', + }) + + assert(conn:query[[ + create table if not exists test ( + f1 decimal(20, 6), + f2 tinyint(1), + f2b tinyint unsigned, + f3 smallint(2), + f3a mediumint(3), + f4 int(4), + f5 bigint(5), + f6 float(2), /* (2) ignored */ + f7 double, /* can't even give (2) here */ + f8 timestamp, + f9 date, + f10 time, + f11 datetime, + f12 varchar(100), + f12a varchar(100) not null collate ascii_bin, + f13 char(100), + f14 varbinary(100), + f15 binary(100), + f16 year, + f17 bit(12), + f18 enum('apple', 'bannana'), + f19 set('a', 'b', 'c'), + f20 tinyblob, + f21 mediumblob, + f22 longblob, + f23 blob, + f24 tinytext, + f25 mediumtext, + f26 longtext, + f27 text(5), + f28 varchar(10), + f29 char(10) + ); + ]]) + + local function pr(cols, h) + local t = {} + for _,k in ipairs(h) do + add(t, fmt('%20s', k)) + end + print(cat(t)) + for _,col in ipairs(cols) do + local t = {} + for _,k in ipairs(h) do + local v = col[k] + v = isnum(v) and fmt('%0.17g', v) or v + add(t, fmt('%20s', repl(v, nil, ''))) + end + print(cat(t)) + end + end + + --pp(conn:query'select * from val where val = 1') + local stmt = assert(conn:prepare + --'select cast(123 as tinyint) union select cast(123 as tinyint)') + 'select * from test') + -- ('select min_price from vari where val = ?')) + assert(stmt:exec()) + local rows, _, cols = conn:read_result({datetime_format = '*t'}) + pr(cols, { + 'name', + 'mysql_display_type', + 'type', + 'display_width', + 'decimals', + 'has_time', + 'padded', + 'mysql_display_charset', + 'mysql_display_collation', + 'mysql_buffer_type', + }) + assert(stmt:free()) + + local spp = require'sqlpp'.new() + require'sqlpp_mysql' + spp.import'mysql' + local cn = spp.connect(conn) + local rows, cols = cn:query({get_table_defs=1}, 'select * from test') + print() + pr(cols, { + 'name', + 'mysql_type', + 'mysql_display_type', + 'type', + 'display_width', + 'decimals', + 'has_time', + 'padded', + 'mysql_charset', + 'mysql_display_charset', + 'mysql_collation', + 'mysql_display_collation', + 'mysql_buffer_type', + }) + --cn:close() + +end) +