1
0
mirror of https://github.com/luapower/mysql.git synced 2025-01-04 07:10:25 +01:00

unimportant

This commit is contained in:
Cosmin Apreutesei 2021-11-29 18:26:00 +02:00
parent ef3eee92f2
commit 1e73103e46
3 changed files with 250 additions and 97 deletions

View File

@ -381,6 +381,24 @@ local mb_charsets = { --excluding ASCII supersets i.e. utf8* charsets.
gb18030=1, gb18030=1,
} }
local max_char_widths = {
utf8 = 3,
utf8mb4 = 4,
big5 = 2,
sjis = 2,
euckr = 2,
gb2312 = 2,
gbk = 2,
ucs2 = 2,
cp932 = 2,
ujis = 3, --eukjp
eucjpms = 3,
utf16 = 4,
utf16le = 4,
utf32 = 4,
gb18030 = 2,
}
local buffer_type_names = { local buffer_type_names = {
[ 0] = 'decimal', [ 0] = 'decimal',
[ 1] = 'tiny', [ 1] = 'tiny',
@ -410,28 +428,25 @@ local buffer_type_names = {
[255] = 'geometry', [255] = 'geometry',
} }
local num_types = { local bin_types = {
tiny = 'tinyint', tiny = 'tinyint',
short = 'shortint', short = 'smallint',
long = 'int',
int24 = 'mediumint', int24 = 'mediumint',
long = 'int',
longlong = 'bigint', longlong = 'bigint',
newdecimal = 'decimal', newdecimal = 'decimal',
} tiny_blob = 'tinyblob', --always selected as blob
medium_blob = 'mediumblob', --always selected as blob
local bin_types = { long_blob = 'longblob', --always selected as blob
tiny_blob = 'tinyblob',
medium_blob = 'mediumblob',
long_blob = 'longblob',
blob = 'blob', blob = 'blob',
var_string = 'varbinary', var_string = 'varbinary',
string = 'binary', string = 'binary',
} }
local text_types = { local text_types = {
tiny_blob = 'tinytext', tiny_blob = 'tinytext', --always selected as text
medium_blob = 'mediumtext', medium_blob = 'mediumtext', --always selected as text
long_blob = 'longtext', long_blob = 'longtext', --always selected as text
blob = 'text', blob = 'text',
var_string = 'varchar', var_string = 'varchar',
string = 'char', string = 'char',
@ -453,19 +468,23 @@ local string_types = {
} }
local int_ranges = { local int_ranges = {
tinyint = {-(2^ 7-1), 2^ 7, 0, 2^ 8-1}, tinyint = {1, -(2^ 7-1), 2^ 7, 0, 2^ 8-1},
shortint = {-(2^15-1), 2^15, 0, 2^16-1}, smallint = {2, -(2^15-1), 2^15, 0, 2^16-1},
mediumint = {-(2^23-1), 2^23, 0, 2^24-1}, mediumint = {3, -(2^23-1), 2^23, 0, 2^24-1},
int = {-(2^31-1), 2^31, 0, 2^32-1}, int = {4, -(2^31-1), 2^31, 0, 2^32-1},
bigint = {-(2^51-1), 2^51, 0, 2^52-1}, bigint = {8, -(2^51-1), 2^51, 0, 2^52-1},
} }
local conn = {} local conn = {}
local conn_mt = {__index = conn} local conn_mt = {__index = conn}
local to_lua = { function mysql.isconn(x)
return getmetatable(x) == conn_mt
end
local default_to_lua = {
tinyint = tonumber, tinyint = tonumber,
shortint = tonumber, smallint = tonumber,
mediumint = tonumber, mediumint = tonumber,
int = tonumber, int = tonumber,
bigint = tonumber, bigint = tonumber,
@ -474,13 +493,6 @@ local to_lua = {
double = tonumber, double = tonumber,
decimal = tonumber, decimal = tonumber,
} }
function mysql.to_lua(v, col)
local to_lua = col.to_lua or to_lua[col.type]
if to_lua then
v = to_lua(v)
end
return v
end
local function return_arg1(v) return v end local function return_arg1(v) return v end
@ -867,26 +879,26 @@ end
local UNSIGNED_FLAG = 32 local UNSIGNED_FLAG = 32
function mysql.num_range(type, unsigned, digits, decimals) function mysql.int_range(type, unsigned) --min, max, size
if digits and decimals then
local max = 10^(digits - decimals) - 1 / 10^decimals
local min = unsigned and 0 or -max --unsigned decimals is deprecated!
return min, max
else
local range = int_ranges[type] local range = int_ranges[type]
if range then if range then
if unsigned then if unsigned then
return range[3], range[4] return range[4], range[5], range[1]
else else
return range[1], range[2] return range[2], range[3], range[1]
end
end end
end end
end end
function mysql.dec_range(digits, decimals, unsigned) --min, max, digits
local max = 10^(digits - decimals) - 10^-decimals
local min = unsigned and 0 or -max --unsigned decimals is deprecated!
return min, max, digits
end
--NOTE: MySQL doesn't give enough metadata to generate a form in a UI, --NOTE: MySQL doesn't give enough metadata to generate a form in a UI,
--you'll have to query `information_schema` to get the rest like enum values --you'll have to query `information_schema` to get the rest like enum values
--and defaults. So we only keep enough info for formatting the values. --and defaults. So we gather only what we need for display, not for editing.
local function get_field_packet(buf) local function get_field_packet(buf)
local col = {} local col = {}
local _ = get_name(buf) --always "def" local _ = get_name(buf) --always "def"
@ -896,39 +908,68 @@ local function get_field_packet(buf)
col.name = get_name(buf) --alias column name col.name = get_name(buf) --alias column name
col.col = get_name(buf) --name of column in origin table col.col = get_name(buf) --name of column in origin table
local _ = get_uint(buf) --0x0c local _ = get_uint(buf) --0x0c
local collation = get_u16(buf) local collation = get_u16(buf) --connection's collation, not field's.
col.max_char_w = get_u32(buf) local display_size = get_u32(buf) --in bytes, not in characters.
local buf_type_code = get_u8(buf) local buf_type_code = get_u8(buf)
local flags = get_u16(buf) local flags = get_u16(buf)
local decimals = get_u8(buf) local decimals = get_u8(buf)
local buf_type = buffer_type_names[buf_type_code] local buf_type = buffer_type_names[buf_type_code]
if collation == 63 then local mysql_type
local type = bin_types[buf_type] if collation == 63 then --binary
if not type then mysql_type = bin_types[buf_type] or buf_type
type = num_types[buf_type] local unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil --for val decoding
if type then if mysql_type == 'tinyint' or mysql_type == 'smallint'
or mysql_type == 'mediumint' or mysql_type == 'int'
or mysql_type == 'bigint'
then
col.type = 'number'
col.decimals = 0
col.unsigned = unsigned
elseif mysql_type == 'decimal' then
col.type = 'number'
col.decimals = decimals col.decimals = decimals
col.unsigned = unsigned
elseif mysql_type == 'float' then
col.type = 'number'
elseif mysql_type == 'double' then
col.type = 'number'
elseif mysql_type == 'year' then
col.type = 'number'
elseif mysql_type == 'timestamp' then
col.type = 'date'
col.has_time = true
elseif mysql_type == 'date' then
col.type = 'date'
elseif mysql_type == 'datetime' then
col.type = 'date'
col.has_time = true
elseif mysql_type == 'binary' then
col.padded = true
end end
end col.display_width = display_size
col.type = type or buf_type
else else
col.type = text_types[buf_type] mysql_type = text_types[buf_type] or buf_type
col.collation = collation_names[collation] local collation = collation_names[collation]
col.charset = col.collation and col.collation:match'^[^_]+' local charset = collation and collation:match'^[^_]+'
col.mysql_display_collation = collation
col.mysql_display_charset = charset
col.padded = mysql_type == 'char' or nil
col.display_width = math.ceil(display_size / (max_char_widths[charset] or 1))
end end
col.buffer_type = buf_type col.mysql_display_type = mysql_type
col.buffer_type_code = buf_type_code col.mysql_buffer_type = buf_type --for param encoding
col.unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil col.mysql_buffer_type_code = buf_type_code --for val decoding
col.min, col.max = mysql.num_range(col.type, col.unsigned, nil, col.decimals)
return col return col
end end
local function recv_field_packets(self, field_count, field_attrs) local function recv_field_packets(self, field_count, field_attrs, to_lua)
local fields = {} local fields = {}
to_lua = to_lua or self.to_lua
for i = 1, field_count do for i = 1, field_count do
local typ, buf = recv_packet(self) local typ, buf = recv_packet(self)
checkp(self, typ == 'DATA', 'bad packet type') checkp(self, typ == 'DATA', 'bad packet type')
local field = get_field_packet(buf) local field = get_field_packet(buf)
field.to_lua = to_lua or default_to_lua[field.mysql_type]
field.index = i field.index = i
fields[i] = field fields[i] = field
fields[field.name] = field fields[field.name] = field
@ -969,6 +1010,10 @@ function mysql.note (...) mysql.log('note', ...) end
function mysql.connect(opt) function mysql.connect(opt)
if mysql.isconn(opt) then --pass-through
return opt
end
local host = opt.host local host = opt.host
local port = opt.port or 3306 local port = opt.port or 3306
@ -1115,7 +1160,7 @@ local function read_result(self, opt)
local field_count = get_uint(buf) local field_count = get_uint(buf)
local extra = buf_len(buf) > 0 and get_uint(buf) or nil local extra = buf_len(buf) > 0 and get_uint(buf) or nil
local cols = recv_field_packets(self, field_count, opt and opt.field_attrs) local cols = recv_field_packets(self, field_count, opt and opt.field_attrs, opt and opt.to_lua)
local compact = opt and opt.compact local compact = opt and opt.compact
local to_array = opt and opt.to_array and #cols == 1 local to_array = opt and opt.to_array and #cols == 1
@ -1123,7 +1168,6 @@ local function read_result(self, opt)
local datetime_format = opt and opt.datetime_format or self.datetime_format local datetime_format = opt and opt.datetime_format or self.datetime_format
local date_format = opt and opt.date_format or self.date_format local date_format = opt and opt.date_format or self.date_format
local time_format = opt and opt.time_format or self.time_format local time_format = opt and opt.time_format or self.time_format
local to_lua = opt and opt.to_lua or self.to_lua
local rows = {} local rows = {}
local i = 0 local i = 0
@ -1155,7 +1199,7 @@ local function read_result(self, opt)
local is_null = band(nulls[null_byte], shl(1, null_bit)) ~= 0 local is_null = band(nulls[null_byte], shl(1, null_bit)) ~= 0
local v local v
if not is_null then if not is_null then
local bt = col.buffer_type local bt = col.mysql_buffer_type
local unsigned = col.unsigned local unsigned = col.unsigned
if string_types[bt] then if string_types[bt] then
v = get_str(buf) v = get_str(buf)
@ -1193,7 +1237,10 @@ local function read_result(self, opt)
for i, col in ipairs(cols) do for i, col in ipairs(cols) do
local v = get_str(buf) local v = get_str(buf)
if v ~= nil then if v ~= nil then
local to_lua = col.to_lua
if to_lua then
v = to_lua(v, col) v = to_lua(v, col)
end
else else
v = null_value v = null_value
end end
@ -1268,8 +1315,8 @@ function conn:prepare(query, opt)
local param_count = get_u16(buf) local param_count = get_u16(buf)
buf(1) --filler buf(1) --filler
stmt.warning_count = get_u16(buf) stmt.warning_count = get_u16(buf)
stmt.params = recv_field_packets(self, param_count, opt and opt.param_attrs) stmt.params = recv_field_packets(self, param_count, opt and opt.param_attrs, opt and opt.to_lua)
stmt.cols = recv_field_packets(self, col_count, opt and opt.field_attrs) stmt.cols = recv_field_packets(self, col_count , opt and opt.field_attrs, opt and opt.to_lua)
stmt.cursor = assert(cursor_types[opt and opt.cursor or 'none']) stmt.cursor = assert(cursor_types[opt and opt.cursor or 'none'])
return stmt return stmt
end end
@ -1309,13 +1356,13 @@ function stmt:exec(...)
set_bytes(buf, nulls, nulls_len) set_bytes(buf, nulls, nulls_len)
set_u8(buf, 1) --new-params-bound-flag set_u8(buf, 1) --new-params-bound-flag
for i, param in ipairs(stmt.params) do for i, param in ipairs(stmt.params) do
set_u8(buf, param.buffer_type_code) set_u8(buf, param.mysql_buffer_type_code)
set_u8(buf, param.unsigned and 0x80 or 0) set_u8(buf, param.unsigned and 0x80 or 0)
end end
for i, param in ipairs(stmt.params) do for i, param in ipairs(stmt.params) do
local val = select(i, ...) local val = select(i, ...)
if val ~= nil then if val ~= nil then
local bt = param.buffer_type local bt = param.mysql_buffer_type
local unsigned = param.unsigned local unsigned = param.unsigned
if string_types[bt] then if string_types[bt] then
set_str(buf, tostring(val)) set_str(buf, tostring(val))
@ -1390,34 +1437,14 @@ local qmap = {
['\26'] = '\\Z', ['\26'] = '\\Z',
['\"' ] = '\\"', ['\"' ] = '\\"',
} }
local function esc_utf8(s)
return s:gsub('[\\\'%z\b\n\r\t\26\"]', qmap)
end
mysql.esc_utf8 = esc_utf8
function conn:esc(s) function conn:esc(s)
--MBCS that are not ASCII supersets need decoding for correct quoting. --MBCS that are not ASCII supersets need decoding for correct quoting.
assert(self.charset_is_ascii_superset, 'NYI') assert(self.charset_is_ascii_superset, 'NYI')
return s:gsub('[\\\'%z\b\n\r\t\26\"]', qmap) return esc_utf8(s)
end
if not ... then --demo
local sock = require'sock'
local pp = require'pp'
sock.run(function()
local conn = assert(mysql.connect{
host = '127.0.0.1',
port = 3307,
user = 'root',
password = 'abcd12',
schema = 'sp',
collation = 'server',
})
print(conn.charset, conn.collation)
--pp(conn:query'select * from val where val = 1')
local stmt = assert(conn:prepare('select min_price from vari where val = ?'))
assert(stmt:exec())
pp(conn:read_result({datetime_format = '*t'}))
assert(stmt:free())
end)
end end
return mysql return mysql

View File

@ -1,8 +1,10 @@
## `local mysql = require'mysql_client'` ## `local mysql = require'mysql_client'`
MySQL client protocol in Lua. MySQL client protocol in Lua. Ripped from OpenResty, modified to work with
Stolen from OpenResty, modified to work with [sock] and added prepared statements. [sock], added prepared statements, better interpretation of field metadata
(consistent with [sqlpp], [schema] and [xrowset][x-widges]), and other minor
changes.
## Example ## Example
@ -81,6 +83,7 @@ The `options` arg can contain:
* `to_array = true` -- return an array of values for single-column results. * `to_array = true` -- return an array of values for single-column results.
* `null_value = val` -- value to use for `null` (defaults to `nil`). * `null_value = val` -- value to use for `null` (defaults to `nil`).
* `to_lua = f(v, col) -> v` -- custom value converter. * `to_lua = f(v, col) -> v` -- custom value converter.
* `field_attrs = {name -> attr}` -- extra field attributes.
For queries that return a result set, it returns an array of rows. For queries that return a result set, it returns an array of rows.
For other queries it returns a Lua table with information such as For other queries it returns a Lua table with information such as
@ -99,8 +102,10 @@ the `sqlstate` return value contains the standard SQL error code that consists
of 5 characters. Note that, the `errcode` and `sqlstate` might be `nil` of 5 characters. Note that, the `errcode` and `sqlstate` might be `nil`
if MySQL does not return them. if MySQL does not return them.
NOTE: 64 bit integers and decimals are converted to Lua numbers by default. __NOTE:__ decimals and 64 bit integers are converted to Lua numbers by default.
That limits the useful range of number types to 15 significant digits. That limits the useful range of number types to 15 significant digits.
If you have other needs, provide your own `to_lua` (which you can set at
module, connection and query level, and even per field with `field_attrs`).
### `cn:query(query, [options]) -> res,nil,cols | nil,err,errcode,sqlstate` ### `cn:query(query, [options]) -> res,nil,cols | nil,err,errcode,sqlstate`
@ -111,7 +116,6 @@ You should always check if the `err` return value is `again` in case of
success because this method will only call [read_result](#read_result) success because this method will only call [read_result](#read_result)
once for you. once for you.
### `cn:prepare(query, [opt]) -> stmt` ### `cn:prepare(query, [opt]) -> stmt`
Prepare a statement. Options can contain: Prepare a statement. Options can contain:
@ -130,10 +134,15 @@ Free statement.
The MySQL server version string. The MySQL server version string.
### `mysql.esc_utf8(s) -> s`
Escape string to be used inside SQL string literals. Only works on connections
for which the charset is ASCII or an ASCII superset (ascii, utf8).
### `cn:esc(s) -> s` ### `cn:esc(s) -> s`
Escape string to be used inside SQL string literals. This only works if current Escape string to be used inside SQL string literals. This only works
collation is known (ses `collation` arg on `connect()`). if the current collation is known (see `collation` arg on `connect()`).
### Multiple result set support ### Multiple result set support

117
mysql_client_test.lua Normal file
View File

@ -0,0 +1,117 @@
local mysql = require'mysql_client'
local sock = require'sock'
local pp = require'pp'
require'$'
sock.run(function()
local conn = assert(mysql.connect{
host = '127.0.0.1',
port = 3307,
user = 'root',
password = 'root',
schema = 'sp',
collation = 'server',
})
assert(conn:query[[
create table if not exists test (
f1 decimal(20, 6),
f2 tinyint(1),
f2b tinyint unsigned,
f3 smallint(2),
f3a mediumint(3),
f4 int(4),
f5 bigint(5),
f6 float(2), /* (2) ignored */
f7 double, /* can't even give (2) here */
f8 timestamp,
f9 date,
f10 time,
f11 datetime,
f12 varchar(100),
f12a varchar(100) not null collate ascii_bin,
f13 char(100),
f14 varbinary(100),
f15 binary(100),
f16 year,
f17 bit(12),
f18 enum('apple', 'bannana'),
f19 set('a', 'b', 'c'),
f20 tinyblob,
f21 mediumblob,
f22 longblob,
f23 blob,
f24 tinytext,
f25 mediumtext,
f26 longtext,
f27 text(5),
f28 varchar(10),
f29 char(10)
);
]])
local function pr(cols, h)
local t = {}
for _,k in ipairs(h) do
add(t, fmt('%20s', k))
end
print(cat(t))
for _,col in ipairs(cols) do
local t = {}
for _,k in ipairs(h) do
local v = col[k]
v = isnum(v) and fmt('%0.17g', v) or v
add(t, fmt('%20s', repl(v, nil, '')))
end
print(cat(t))
end
end
--pp(conn:query'select * from val where val = 1')
local stmt = assert(conn:prepare
--'select cast(123 as tinyint) union select cast(123 as tinyint)')
'select * from test')
-- ('select min_price from vari where val = ?'))
assert(stmt:exec())
local rows, _, cols = conn:read_result({datetime_format = '*t'})
pr(cols, {
'name',
'mysql_display_type',
'type',
'display_width',
'decimals',
'has_time',
'padded',
'mysql_display_charset',
'mysql_display_collation',
'mysql_buffer_type',
})
assert(stmt:free())
local spp = require'sqlpp'.new()
require'sqlpp_mysql'
spp.import'mysql'
local cn = spp.connect(conn)
local rows, cols = cn:query({get_table_defs=1}, 'select * from test')
print()
pr(cols, {
'name',
'mysql_type',
'mysql_display_type',
'type',
'display_width',
'decimals',
'has_time',
'padded',
'mysql_charset',
'mysql_display_charset',
'mysql_collation',
'mysql_display_collation',
'mysql_buffer_type',
})
--cn:close()
end)