1
0
mirror of https://github.com/luapower/mysql.git synced 2025-10-24 13:55:25 +02:00

unimportant

This commit is contained in:
Cosmin Apreutesei
2021-11-29 18:26:00 +02:00
parent ef3eee92f2
commit 1e73103e46
3 changed files with 250 additions and 97 deletions

View File

@@ -381,6 +381,24 @@ local mb_charsets = { --excluding ASCII supersets i.e. utf8* charsets.
gb18030=1,
}
local max_char_widths = {
utf8 = 3,
utf8mb4 = 4,
big5 = 2,
sjis = 2,
euckr = 2,
gb2312 = 2,
gbk = 2,
ucs2 = 2,
cp932 = 2,
ujis = 3, --eukjp
eucjpms = 3,
utf16 = 4,
utf16le = 4,
utf32 = 4,
gb18030 = 2,
}
local buffer_type_names = {
[ 0] = 'decimal',
[ 1] = 'tiny',
@@ -410,28 +428,25 @@ local buffer_type_names = {
[255] = 'geometry',
}
local num_types = {
local bin_types = {
tiny = 'tinyint',
short = 'shortint',
long = 'int',
short = 'smallint',
int24 = 'mediumint',
long = 'int',
longlong = 'bigint',
newdecimal = 'decimal',
}
local bin_types = {
tiny_blob = 'tinyblob',
medium_blob = 'mediumblob',
long_blob = 'longblob',
tiny_blob = 'tinyblob', --always selected as blob
medium_blob = 'mediumblob', --always selected as blob
long_blob = 'longblob', --always selected as blob
blob = 'blob',
var_string = 'varbinary',
string = 'binary',
}
local text_types = {
tiny_blob = 'tinytext',
medium_blob = 'mediumtext',
long_blob = 'longtext',
tiny_blob = 'tinytext', --always selected as text
medium_blob = 'mediumtext', --always selected as text
long_blob = 'longtext', --always selected as text
blob = 'text',
var_string = 'varchar',
string = 'char',
@@ -453,19 +468,23 @@ local string_types = {
}
local int_ranges = {
tinyint = {-(2^ 7-1), 2^ 7, 0, 2^ 8-1},
shortint = {-(2^15-1), 2^15, 0, 2^16-1},
mediumint = {-(2^23-1), 2^23, 0, 2^24-1},
int = {-(2^31-1), 2^31, 0, 2^32-1},
bigint = {-(2^51-1), 2^51, 0, 2^52-1},
tinyint = {1, -(2^ 7-1), 2^ 7, 0, 2^ 8-1},
smallint = {2, -(2^15-1), 2^15, 0, 2^16-1},
mediumint = {3, -(2^23-1), 2^23, 0, 2^24-1},
int = {4, -(2^31-1), 2^31, 0, 2^32-1},
bigint = {8, -(2^51-1), 2^51, 0, 2^52-1},
}
local conn = {}
local conn_mt = {__index = conn}
local to_lua = {
function mysql.isconn(x)
return getmetatable(x) == conn_mt
end
local default_to_lua = {
tinyint = tonumber,
shortint = tonumber,
smallint = tonumber,
mediumint = tonumber,
int = tonumber,
bigint = tonumber,
@@ -474,13 +493,6 @@ local to_lua = {
double = tonumber,
decimal = tonumber,
}
function mysql.to_lua(v, col)
local to_lua = col.to_lua or to_lua[col.type]
if to_lua then
v = to_lua(v)
end
return v
end
local function return_arg1(v) return v end
@@ -867,26 +879,26 @@ end
local UNSIGNED_FLAG = 32
function mysql.num_range(type, unsigned, digits, decimals)
if digits and decimals then
local max = 10^(digits - decimals) - 1 / 10^decimals
local min = unsigned and 0 or -max --unsigned decimals is deprecated!
return min, max
else
local range = int_ranges[type]
if range then
if unsigned then
return range[3], range[4]
else
return range[1], range[2]
end
function mysql.int_range(type, unsigned) --min, max, size
local range = int_ranges[type]
if range then
if unsigned then
return range[4], range[5], range[1]
else
return range[2], range[3], range[1]
end
end
end
function mysql.dec_range(digits, decimals, unsigned) --min, max, digits
local max = 10^(digits - decimals) - 10^-decimals
local min = unsigned and 0 or -max --unsigned decimals is deprecated!
return min, max, digits
end
--NOTE: MySQL doesn't give enough metadata to generate a form in a UI,
--you'll have to query `information_schema` to get the rest like enum values
--and defaults. So we only keep enough info for formatting the values.
--and defaults. So we gather only what we need for display, not for editing.
local function get_field_packet(buf)
local col = {}
local _ = get_name(buf) --always "def"
@@ -896,39 +908,68 @@ local function get_field_packet(buf)
col.name = get_name(buf) --alias column name
col.col = get_name(buf) --name of column in origin table
local _ = get_uint(buf) --0x0c
local collation = get_u16(buf)
col.max_char_w = get_u32(buf)
local collation = get_u16(buf) --connection's collation, not field's.
local display_size = get_u32(buf) --in bytes, not in characters.
local buf_type_code = get_u8(buf)
local flags = get_u16(buf)
local decimals = get_u8(buf)
local buf_type = buffer_type_names[buf_type_code]
if collation == 63 then
local type = bin_types[buf_type]
if not type then
type = num_types[buf_type]
if type then
col.decimals = decimals
end
local mysql_type
if collation == 63 then --binary
mysql_type = bin_types[buf_type] or buf_type
local unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil --for val decoding
if mysql_type == 'tinyint' or mysql_type == 'smallint'
or mysql_type == 'mediumint' or mysql_type == 'int'
or mysql_type == 'bigint'
then
col.type = 'number'
col.decimals = 0
col.unsigned = unsigned
elseif mysql_type == 'decimal' then
col.type = 'number'
col.decimals = decimals
col.unsigned = unsigned
elseif mysql_type == 'float' then
col.type = 'number'
elseif mysql_type == 'double' then
col.type = 'number'
elseif mysql_type == 'year' then
col.type = 'number'
elseif mysql_type == 'timestamp' then
col.type = 'date'
col.has_time = true
elseif mysql_type == 'date' then
col.type = 'date'
elseif mysql_type == 'datetime' then
col.type = 'date'
col.has_time = true
elseif mysql_type == 'binary' then
col.padded = true
end
col.type = type or buf_type
col.display_width = display_size
else
col.type = text_types[buf_type]
col.collation = collation_names[collation]
col.charset = col.collation and col.collation:match'^[^_]+'
mysql_type = text_types[buf_type] or buf_type
local collation = collation_names[collation]
local charset = collation and collation:match'^[^_]+'
col.mysql_display_collation = collation
col.mysql_display_charset = charset
col.padded = mysql_type == 'char' or nil
col.display_width = math.ceil(display_size / (max_char_widths[charset] or 1))
end
col.buffer_type = buf_type
col.buffer_type_code = buf_type_code
col.unsigned = band(flags, UNSIGNED_FLAG) ~= 0 or nil
col.min, col.max = mysql.num_range(col.type, col.unsigned, nil, col.decimals)
col.mysql_display_type = mysql_type
col.mysql_buffer_type = buf_type --for param encoding
col.mysql_buffer_type_code = buf_type_code --for val decoding
return col
end
local function recv_field_packets(self, field_count, field_attrs)
local function recv_field_packets(self, field_count, field_attrs, to_lua)
local fields = {}
to_lua = to_lua or self.to_lua
for i = 1, field_count do
local typ, buf = recv_packet(self)
checkp(self, typ == 'DATA', 'bad packet type')
local field = get_field_packet(buf)
field.to_lua = to_lua or default_to_lua[field.mysql_type]
field.index = i
fields[i] = field
fields[field.name] = field
@@ -969,6 +1010,10 @@ function mysql.note (...) mysql.log('note', ...) end
function mysql.connect(opt)
if mysql.isconn(opt) then --pass-through
return opt
end
local host = opt.host
local port = opt.port or 3306
@@ -1115,7 +1160,7 @@ local function read_result(self, opt)
local field_count = get_uint(buf)
local extra = buf_len(buf) > 0 and get_uint(buf) or nil
local cols = recv_field_packets(self, field_count, opt and opt.field_attrs)
local cols = recv_field_packets(self, field_count, opt and opt.field_attrs, opt and opt.to_lua)
local compact = opt and opt.compact
local to_array = opt and opt.to_array and #cols == 1
@@ -1123,7 +1168,6 @@ local function read_result(self, opt)
local datetime_format = opt and opt.datetime_format or self.datetime_format
local date_format = opt and opt.date_format or self.date_format
local time_format = opt and opt.time_format or self.time_format
local to_lua = opt and opt.to_lua or self.to_lua
local rows = {}
local i = 0
@@ -1155,7 +1199,7 @@ local function read_result(self, opt)
local is_null = band(nulls[null_byte], shl(1, null_bit)) ~= 0
local v
if not is_null then
local bt = col.buffer_type
local bt = col.mysql_buffer_type
local unsigned = col.unsigned
if string_types[bt] then
v = get_str(buf)
@@ -1193,7 +1237,10 @@ local function read_result(self, opt)
for i, col in ipairs(cols) do
local v = get_str(buf)
if v ~= nil then
v = to_lua(v, col)
local to_lua = col.to_lua
if to_lua then
v = to_lua(v, col)
end
else
v = null_value
end
@@ -1268,8 +1315,8 @@ function conn:prepare(query, opt)
local param_count = get_u16(buf)
buf(1) --filler
stmt.warning_count = get_u16(buf)
stmt.params = recv_field_packets(self, param_count, opt and opt.param_attrs)
stmt.cols = recv_field_packets(self, col_count, opt and opt.field_attrs)
stmt.params = recv_field_packets(self, param_count, opt and opt.param_attrs, opt and opt.to_lua)
stmt.cols = recv_field_packets(self, col_count , opt and opt.field_attrs, opt and opt.to_lua)
stmt.cursor = assert(cursor_types[opt and opt.cursor or 'none'])
return stmt
end
@@ -1309,13 +1356,13 @@ function stmt:exec(...)
set_bytes(buf, nulls, nulls_len)
set_u8(buf, 1) --new-params-bound-flag
for i, param in ipairs(stmt.params) do
set_u8(buf, param.buffer_type_code)
set_u8(buf, param.mysql_buffer_type_code)
set_u8(buf, param.unsigned and 0x80 or 0)
end
for i, param in ipairs(stmt.params) do
local val = select(i, ...)
if val ~= nil then
local bt = param.buffer_type
local bt = param.mysql_buffer_type
local unsigned = param.unsigned
if string_types[bt] then
set_str(buf, tostring(val))
@@ -1390,34 +1437,14 @@ local qmap = {
['\26'] = '\\Z',
['\"' ] = '\\"',
}
local function esc_utf8(s)
return s:gsub('[\\\'%z\b\n\r\t\26\"]', qmap)
end
mysql.esc_utf8 = esc_utf8
function conn:esc(s)
--MBCS that are not ASCII supersets need decoding for correct quoting.
assert(self.charset_is_ascii_superset, 'NYI')
return s:gsub('[\\\'%z\b\n\r\t\26\"]', qmap)
end
if not ... then --demo
local sock = require'sock'
local pp = require'pp'
sock.run(function()
local conn = assert(mysql.connect{
host = '127.0.0.1',
port = 3307,
user = 'root',
password = 'abcd12',
schema = 'sp',
collation = 'server',
})
print(conn.charset, conn.collation)
--pp(conn:query'select * from val where val = 1')
local stmt = assert(conn:prepare('select min_price from vari where val = ?'))
assert(stmt:exec())
pp(conn:read_result({datetime_format = '*t'}))
assert(stmt:free())
end)
return esc_utf8(s)
end
return mysql

View File

@@ -1,8 +1,10 @@
## `local mysql = require'mysql_client'`
MySQL client protocol in Lua.
Stolen from OpenResty, modified to work with [sock] and added prepared statements.
MySQL client protocol in Lua. Ripped from OpenResty, modified to work with
[sock], added prepared statements, better interpretation of field metadata
(consistent with [sqlpp], [schema] and [xrowset][x-widges]), and other minor
changes.
## Example
@@ -81,6 +83,7 @@ The `options` arg can contain:
* `to_array = true` -- return an array of values for single-column results.
* `null_value = val` -- value to use for `null` (defaults to `nil`).
* `to_lua = f(v, col) -> v` -- custom value converter.
* `field_attrs = {name -> attr}` -- extra field attributes.
For queries that return a result set, it returns an array of rows.
For other queries it returns a Lua table with information such as
@@ -99,8 +102,10 @@ the `sqlstate` return value contains the standard SQL error code that consists
of 5 characters. Note that, the `errcode` and `sqlstate` might be `nil`
if MySQL does not return them.
NOTE: 64 bit integers and decimals are converted to Lua numbers by default.
__NOTE:__ decimals and 64 bit integers are converted to Lua numbers by default.
That limits the useful range of number types to 15 significant digits.
If you have other needs, provide your own `to_lua` (which you can set at
module, connection and query level, and even per field with `field_attrs`).
### `cn:query(query, [options]) -> res,nil,cols | nil,err,errcode,sqlstate`
@@ -111,7 +116,6 @@ You should always check if the `err` return value is `again` in case of
success because this method will only call [read_result](#read_result)
once for you.
### `cn:prepare(query, [opt]) -> stmt`
Prepare a statement. Options can contain:
@@ -130,10 +134,15 @@ Free statement.
The MySQL server version string.
### `mysql.esc_utf8(s) -> s`
Escape string to be used inside SQL string literals. Only works on connections
for which the charset is ASCII or an ASCII superset (ascii, utf8).
### `cn:esc(s) -> s`
Escape string to be used inside SQL string literals. This only works if current
collation is known (ses `collation` arg on `connect()`).
Escape string to be used inside SQL string literals. This only works
if the current collation is known (see `collation` arg on `connect()`).
### Multiple result set support

117
mysql_client_test.lua Normal file
View File

@@ -0,0 +1,117 @@
local mysql = require'mysql_client'
local sock = require'sock'
local pp = require'pp'
require'$'
sock.run(function()
local conn = assert(mysql.connect{
host = '127.0.0.1',
port = 3307,
user = 'root',
password = 'root',
schema = 'sp',
collation = 'server',
})
assert(conn:query[[
create table if not exists test (
f1 decimal(20, 6),
f2 tinyint(1),
f2b tinyint unsigned,
f3 smallint(2),
f3a mediumint(3),
f4 int(4),
f5 bigint(5),
f6 float(2), /* (2) ignored */
f7 double, /* can't even give (2) here */
f8 timestamp,
f9 date,
f10 time,
f11 datetime,
f12 varchar(100),
f12a varchar(100) not null collate ascii_bin,
f13 char(100),
f14 varbinary(100),
f15 binary(100),
f16 year,
f17 bit(12),
f18 enum('apple', 'bannana'),
f19 set('a', 'b', 'c'),
f20 tinyblob,
f21 mediumblob,
f22 longblob,
f23 blob,
f24 tinytext,
f25 mediumtext,
f26 longtext,
f27 text(5),
f28 varchar(10),
f29 char(10)
);
]])
local function pr(cols, h)
local t = {}
for _,k in ipairs(h) do
add(t, fmt('%20s', k))
end
print(cat(t))
for _,col in ipairs(cols) do
local t = {}
for _,k in ipairs(h) do
local v = col[k]
v = isnum(v) and fmt('%0.17g', v) or v
add(t, fmt('%20s', repl(v, nil, '')))
end
print(cat(t))
end
end
--pp(conn:query'select * from val where val = 1')
local stmt = assert(conn:prepare
--'select cast(123 as tinyint) union select cast(123 as tinyint)')
'select * from test')
-- ('select min_price from vari where val = ?'))
assert(stmt:exec())
local rows, _, cols = conn:read_result({datetime_format = '*t'})
pr(cols, {
'name',
'mysql_display_type',
'type',
'display_width',
'decimals',
'has_time',
'padded',
'mysql_display_charset',
'mysql_display_collation',
'mysql_buffer_type',
})
assert(stmt:free())
local spp = require'sqlpp'.new()
require'sqlpp_mysql'
spp.import'mysql'
local cn = spp.connect(conn)
local rows, cols = cn:query({get_table_defs=1}, 'select * from test')
print()
pr(cols, {
'name',
'mysql_type',
'mysql_display_type',
'type',
'display_width',
'decimals',
'has_time',
'padded',
'mysql_charset',
'mysql_display_charset',
'mysql_collation',
'mysql_display_collation',
'mysql_buffer_type',
})
--cn:close()
end)