mirror of
https://github.com/luapower/mysql.git
synced 2024-11-14 06:30:17 +01:00
1244 lines
36 KiB
Lua
1244 lines
36 KiB
Lua
--mysql ffi binding by Cosmin Apreutesei. supports mysql Connector/C 6.1. based on mySQL 5.7 manual.
|
|
local ffi = require'ffi'
|
|
local bit = require'bit'
|
|
require'mysql_h'
|
|
local C = ffi.load(ffi.abi'win' and 'libmysql' or 'mysql')
|
|
local M = {C = C}
|
|
|
|
--we compare NULL pointers against NULL instead of nil for compatibility with luaffi.
|
|
local NULL = ffi.cast('void*', nil)
|
|
|
|
local function ptr(p) --convert NULLs to nil
|
|
if p == NULL then return nil end
|
|
return p
|
|
end
|
|
|
|
local function cstring(data) --convert null-term non-empty C strings to lua strings
|
|
if data == NULL or data[0] == 0 then return nil end
|
|
return ffi.string(data)
|
|
end
|
|
|
|
--error reporting
|
|
|
|
local function myerror(mysql)
|
|
local err = cstring(C.mysql_error(mysql))
|
|
if not err then return end
|
|
error(string.format('mysql error: %s', err))
|
|
end
|
|
|
|
local function checkz(mysql, ret)
|
|
if ret == 0 then return end
|
|
myerror(mysql)
|
|
end
|
|
|
|
local function checkh(mysql, ret)
|
|
if ret ~= NULL then return ret end
|
|
myerror(mysql)
|
|
end
|
|
|
|
local function enum(e, prefix)
|
|
return assert(type(e) == 'string' and (prefix and C[prefix..e] or C[e]) or e, 'invalid enum value')
|
|
end
|
|
|
|
--client library info
|
|
|
|
function M.thread_safe()
|
|
return C.mysql_thread_safe() == 1
|
|
end
|
|
|
|
function M.client_info()
|
|
return cstring(C.mysql_get_client_info())
|
|
end
|
|
|
|
M.client_version = C.mysql_get_client_version
|
|
|
|
--connections
|
|
|
|
local function bool_ptr(b)
|
|
return ffi.new('my_bool[1]', b or false)
|
|
end
|
|
|
|
local function uint_bool_ptr(b)
|
|
return ffi.new('uint32_t[1]', b or false)
|
|
end
|
|
|
|
local function uint_ptr(i)
|
|
return ffi.new('uint32_t[1]', i)
|
|
end
|
|
|
|
local function proto_ptr(proto) --proto is 'MYSQL_PROTOCOL_*' or mysql.C.MYSQL_PROTOCOL_*
|
|
return ffi.new('uint32_t[1]', enum(proto))
|
|
end
|
|
|
|
local function ignore_arg()
|
|
return nil
|
|
end
|
|
|
|
local option_encoders = {
|
|
MYSQL_ENABLE_CLEARTEXT_PLUGIN = bool_ptr,
|
|
MYSQL_OPT_LOCAL_INFILE = uint_bool_ptr,
|
|
MYSQL_OPT_PROTOCOL = proto_ptr,
|
|
MYSQL_OPT_READ_TIMEOUT = uint_ptr,
|
|
MYSQL_OPT_WRITE_TIMEOUT = uint_ptr,
|
|
MYSQL_OPT_USE_REMOTE_CONNECTION = ignore_arg,
|
|
MYSQL_OPT_USE_EMBEDDED_CONNECTION = ignore_arg,
|
|
MYSQL_OPT_GUESS_CONNECTION = ignore_arg,
|
|
MYSQL_SECURE_AUTH = bool_ptr,
|
|
MYSQL_REPORT_DATA_TRUNCATION = bool_ptr,
|
|
MYSQL_OPT_RECONNECT = bool_ptr,
|
|
MYSQL_OPT_SSL_VERIFY_SERVER_CERT = bool_ptr,
|
|
MYSQL_ENABLE_CLEARTEXT_PLUGIN = bool_ptr,
|
|
MYSQL_OPT_CAN_HANDLE_EXPIRED_PASSWORDS = bool_ptr,
|
|
}
|
|
|
|
function M.connect(t, ...)
|
|
local host, user, pass, db, charset, port
|
|
local unix_socket, flags, options, attrs
|
|
local key, cert, ca, capath, cipher
|
|
if type(t) == 'string' then
|
|
host, user, pass, db, charset, port = t, ...
|
|
else
|
|
host, user, pass, db, charset, port = t.host, t.user, t.pass, t.db, t.charset, t.port
|
|
unix_socket, flags, options, attrs = t.unix_socket, t.flags, t.options, t.attrs
|
|
key, cert, ca, capath, cipher = t.key, t.cert, t.ca, t.capath, t.cipher
|
|
end
|
|
port = port or 0
|
|
|
|
local client_flag = 0
|
|
if type(flags) == 'number' then
|
|
client_flag = flags
|
|
elseif flags then
|
|
for k,v in pairs(flags) do
|
|
local flag = enum(k, 'MYSQL_') --'CLIENT_*' or mysql.C.MYSQL_CLIENT_* enum
|
|
client_flag = v and bit.bor(client_flag, flag) or bit.band(client_flag, bit.bnot(flag))
|
|
end
|
|
end
|
|
|
|
local mysql = assert(C.mysql_init(nil))
|
|
ffi.gc(mysql, C.mysql_close)
|
|
|
|
if options then
|
|
for k,v in pairs(options) do
|
|
local opt = enum(k) --'MYSQL_OPT_*' or mysql.C.MYSQL_OPT_* enum
|
|
local encoder = option_encoders[k]
|
|
if encoder then v = encoder(v) end
|
|
assert(C.mysql_options(mysql, opt, ffi.cast('const void*', v)) == 0, 'invalid option')
|
|
end
|
|
end
|
|
|
|
if attrs then
|
|
for k,v in pairs(attrs) do
|
|
assert(C.mysql_options4(mysql, C.MYSQL_OPT_CONNECT_ATTR_ADD, k, v) == 0)
|
|
end
|
|
end
|
|
|
|
if key then
|
|
checkz(mysql, C.mysql_ssl_set(mysql, key, cert, ca, capath, cipher))
|
|
end
|
|
|
|
checkh(mysql, C.mysql_real_connect(mysql, host, user, pass, db, port, unix_socket, client_flag))
|
|
|
|
if charset then mysql:set_charset(charset) end
|
|
|
|
return mysql
|
|
end
|
|
|
|
local conn = {} --connection methods
|
|
|
|
function conn.close(mysql)
|
|
C.mysql_close(mysql)
|
|
ffi.gc(mysql, nil)
|
|
end
|
|
|
|
function conn.set_charset(mysql, charset)
|
|
checkz(mysql, C.mysql_set_character_set(mysql, charset))
|
|
end
|
|
|
|
function conn.select_db(mysql, db)
|
|
checkz(mysql, C.mysql_select_db(mysql, db))
|
|
end
|
|
|
|
function conn.change_user(mysql, user, pass, db)
|
|
checkz(mysql, C.mysql_change_user(mysql, user, pass, db))
|
|
end
|
|
|
|
function conn.set_multiple_statements(mysql, yes)
|
|
checkz(mysql, C.mysql_set_server_option(mysql, yes and C.MYSQL_OPTION_MULTI_STATEMENTS_ON or
|
|
C.MYSQL_OPTION_MULTI_STATEMENTS_OFF))
|
|
end
|
|
|
|
--connection info
|
|
|
|
function conn.charset(mysql)
|
|
return cstring(C.mysql_character_set_name(mysql))
|
|
end
|
|
|
|
function conn.charset_info(mysql)
|
|
local info = ffi.new'MY_CHARSET_INFO'
|
|
checkz(C.mysql_get_character_set_info(mysql, info))
|
|
assert(info.name ~= NULL)
|
|
assert(info.csname ~= NULL)
|
|
return {
|
|
number = info.number,
|
|
state = info.state,
|
|
name = cstring(info.csname), --csname and name are inverted from the spec
|
|
collation = cstring(info.name),
|
|
comment = cstring(info.comment),
|
|
dir = cstring(info.dir),
|
|
mbminlen = info.mbminlen,
|
|
mbmaxlen = info.mbmaxlen,
|
|
}
|
|
end
|
|
|
|
function conn.ping(mysql)
|
|
local ret = C.mysql_ping(mysql)
|
|
if ret == 0 then
|
|
return true
|
|
elseif C.mysql_error(mysql) == C.MYSQL_CR_SERVER_GONE_ERROR then
|
|
return false
|
|
end
|
|
myerror(mysql)
|
|
end
|
|
|
|
conn.thread_id = C.mysql_thread_id
|
|
|
|
function conn.stat(mysql)
|
|
return cstring(checkh(mysql, C.mysql_stat(mysql)))
|
|
end
|
|
|
|
function conn.server_info(mysql)
|
|
return cstring(checkh(mysql, C.mysql_get_server_info(mysql)))
|
|
end
|
|
|
|
function conn.host_info(mysql)
|
|
return cstring(checkh(mysql, C.mysql_get_host_info(mysql)))
|
|
end
|
|
|
|
conn.server_version = C.mysql_get_server_version
|
|
conn.proto_info = C.mysql_get_proto_info
|
|
|
|
|
|
function conn.ssl_cipher(mysql)
|
|
return cstring(C.mysql_get_ssl_cipher(mysql))
|
|
end
|
|
|
|
--transactions
|
|
|
|
function conn.commit(mysql) checkz(mysql, C.mysql_commit(mysql)) end
|
|
function conn.rollback(mysql) checkz(mysql, C.mysql_rollback(mysql)) end
|
|
function conn.set_autocommit(mysql, yes) checkz(mysql, C.mysql_autocommit(mysql, yes == nil or yes)) end
|
|
|
|
--queries
|
|
|
|
function conn.escape_tobuffer(mysql, data, size, buf, sz)
|
|
size = size or #data
|
|
assert(sz >= size * 2 + 1)
|
|
return C.mysql_real_escape_string(mysql, buf, data, size)
|
|
end
|
|
|
|
function conn.escape(mysql, data, size)
|
|
size = size or #data
|
|
local sz = size * 2 + 1
|
|
local buf = ffi.new('uint8_t[?]', sz)
|
|
sz = conn.escape_tobuffer(mysql, data, size, buf, sz)
|
|
return ffi.string(buf, sz)
|
|
end
|
|
|
|
function conn.query(mysql, data, size)
|
|
checkz(mysql, C.mysql_real_query(mysql, data, size or #data))
|
|
end
|
|
|
|
--query info
|
|
|
|
conn.field_count = C.mysql_field_count
|
|
|
|
local minus1_uint64 = ffi.cast('uint64_t', ffi.cast('int64_t', -1))
|
|
function conn.affected_rows(mysql)
|
|
local n = C.mysql_affected_rows(mysql)
|
|
if n == minus1_uint64 then myerror(mysql) end
|
|
return tonumber(n)
|
|
end
|
|
|
|
conn.insert_id = C.mysql_insert_id
|
|
|
|
function conn.errno(conn)
|
|
local err = C.mysql_errno(conn)
|
|
if err == 0 then return end
|
|
return err
|
|
end
|
|
|
|
function conn.sqlstate(mysql)
|
|
return cstring(C.mysql_sqlstate(mysql))
|
|
end
|
|
|
|
conn.warning_count = C.mysql_warning_count
|
|
|
|
function conn.info(mysql)
|
|
return cstring(C.mysql_info(mysql))
|
|
end
|
|
|
|
--query results
|
|
|
|
function conn.next_result(mysql) --multiple statement queries return multiple results
|
|
local ret = C.mysql_next_result(mysql)
|
|
if ret == 0 then return true end
|
|
if ret == -1 then return false end
|
|
myerror(mysql)
|
|
end
|
|
|
|
function conn.more_results(mysql)
|
|
return C.mysql_more_results(mysql) == 1
|
|
end
|
|
|
|
local function result_function(func)
|
|
return function(mysql)
|
|
local res = checkh(mysql, func(mysql))
|
|
return ffi.gc(res, C.mysql_free_result)
|
|
end
|
|
end
|
|
|
|
conn.store_result = result_function(C.mysql_store_result)
|
|
conn.use_result = result_function(C.mysql_use_result)
|
|
|
|
local res = {} --result methods
|
|
|
|
function res.free(res)
|
|
C.mysql_free_result(res)
|
|
ffi.gc(res, nil)
|
|
end
|
|
|
|
function res.row_count(res)
|
|
return tonumber(C.mysql_num_rows(res))
|
|
end
|
|
|
|
res.field_count = C.mysql_num_fields
|
|
|
|
function res.eof(res)
|
|
return C.mysql_eof(res) ~= 0
|
|
end
|
|
|
|
--field info
|
|
|
|
local field_type_names = {
|
|
[C.MYSQL_TYPE_DECIMAL] = 'decimal', --DECIMAL or NUMERIC
|
|
[C.MYSQL_TYPE_TINY] = 'tinyint',
|
|
[C.MYSQL_TYPE_SHORT] = 'smallint',
|
|
[C.MYSQL_TYPE_LONG] = 'int',
|
|
[C.MYSQL_TYPE_FLOAT] = 'float',
|
|
[C.MYSQL_TYPE_DOUBLE] = 'double', --DOUBLE or REAL
|
|
[C.MYSQL_TYPE_NULL] = 'null',
|
|
[C.MYSQL_TYPE_TIMESTAMP] = 'timestamp',
|
|
[C.MYSQL_TYPE_LONGLONG] = 'bigint',
|
|
[C.MYSQL_TYPE_INT24] = 'mediumint',
|
|
[C.MYSQL_TYPE_DATE] = 'date', --pre mysql 5.0, storage = 4 bytes
|
|
[C.MYSQL_TYPE_TIME] = 'time',
|
|
[C.MYSQL_TYPE_DATETIME] = 'datetime',
|
|
[C.MYSQL_TYPE_YEAR] = 'year',
|
|
[C.MYSQL_TYPE_NEWDATE] = 'date', --mysql 5.0+, storage = 3 bytes
|
|
[C.MYSQL_TYPE_VARCHAR] = 'varchar',
|
|
[C.MYSQL_TYPE_BIT] = 'bit',
|
|
[C.MYSQL_TYPE_TIMESTAMP2] = 'timestamp', --mysql 5.6+, can store fractional seconds
|
|
[C.MYSQL_TYPE_DATETIME2] = 'datetime', --mysql 5.6+, can store fractional seconds
|
|
[C.MYSQL_TYPE_TIME2] = 'time', --mysql 5.6+, can store fractional seconds
|
|
[C.MYSQL_TYPE_NEWDECIMAL] = 'decimal', --mysql 5.0+, Precision math DECIMAL or NUMERIC
|
|
[C.MYSQL_TYPE_ENUM] = 'enum',
|
|
[C.MYSQL_TYPE_SET] = 'set',
|
|
[C.MYSQL_TYPE_TINY_BLOB] = 'tinyblob',
|
|
[C.MYSQL_TYPE_MEDIUM_BLOB] = 'mediumblob',
|
|
[C.MYSQL_TYPE_LONG_BLOB] = 'longblob',
|
|
[C.MYSQL_TYPE_BLOB] = 'text', --TEXT or BLOB
|
|
[C.MYSQL_TYPE_VAR_STRING] = 'varchar', --VARCHAR or VARBINARY
|
|
[C.MYSQL_TYPE_STRING] = 'char', --CHAR or BINARY
|
|
[C.MYSQL_TYPE_GEOMETRY] = 'spatial', --Spatial field
|
|
}
|
|
|
|
local binary_field_type_names = {
|
|
[C.MYSQL_TYPE_BLOB] = 'blob',
|
|
[C.MYSQL_TYPE_VAR_STRING] = 'varbinary',
|
|
[C.MYSQL_TYPE_STRING] = 'binary',
|
|
}
|
|
|
|
local field_flag_names = {
|
|
[C.MYSQL_NOT_NULL_FLAG] = 'not_null',
|
|
[C.MYSQL_PRI_KEY_FLAG] = 'pri_key',
|
|
[C.MYSQL_UNIQUE_KEY_FLAG] = 'unique_key',
|
|
[C.MYSQL_MULTIPLE_KEY_FLAG] = 'key',
|
|
[C.MYSQL_BLOB_FLAG] = 'is_blob',
|
|
[C.MYSQL_UNSIGNED_FLAG] = 'unsigned',
|
|
[C.MYSQL_ZEROFILL_FLAG] = 'zerofill',
|
|
[C.MYSQL_BINARY_FLAG] = 'is_binary',
|
|
[C.MYSQL_ENUM_FLAG] = 'is_enum',
|
|
[C.MYSQL_AUTO_INCREMENT_FLAG] = 'autoincrement',
|
|
[C.MYSQL_TIMESTAMP_FLAG] = 'is_timestamp',
|
|
[C.MYSQL_SET_FLAG] = 'is_set',
|
|
[C.MYSQL_NO_DEFAULT_VALUE_FLAG] = 'no_default',
|
|
[C.MYSQL_ON_UPDATE_NOW_FLAG] = 'on_update_now',
|
|
[C.MYSQL_NUM_FLAG] = 'is_number',
|
|
}
|
|
|
|
local function field_type_name(info)
|
|
local type_flag = tonumber(info.type)
|
|
local field_type = field_type_names[type_flag]
|
|
--charsetnr 63 changes CHAR into BINARY, VARCHAR into VARBYNARY, TEXT into BLOB
|
|
field_type = info.charsetnr == 63 and binary_field_type_names[type_flag] or field_type
|
|
return field_type
|
|
end
|
|
|
|
--convenience field type fetcher (less garbage)
|
|
function res.field_type(res, i)
|
|
assert(i >= 1 and i <= res:field_count(), 'index out of range')
|
|
local info = C.mysql_fetch_field_direct(res, i-1)
|
|
local unsigned = bit.bor(info.flags, C.MYSQL_UNSIGNED_FLAG) ~= 0
|
|
return field_type_name(info), info.length, unsigned, info.decimals
|
|
end
|
|
|
|
function res.field_info(res, i)
|
|
assert(i >= 1 and i <= res:field_count(), 'index out of range')
|
|
local info = C.mysql_fetch_field_direct(res, i-1)
|
|
local t = {
|
|
name = cstring(info.name, info.name_length),
|
|
org_name = cstring(info.org_name, info.org_name_length),
|
|
table = cstring(info.table, info.table_length),
|
|
org_table = cstring(info.org_table, info.org_table_length),
|
|
db = cstring(info.db, info.db_length),
|
|
catalog = cstring(info.catalog, info.catalog_length),
|
|
def = cstring(info.def, info.def_length),
|
|
length = info.length,
|
|
max_length = info.max_length,
|
|
decimals = info.decimals,
|
|
charsetnr = info.charsetnr,
|
|
type_flag = tonumber(info.type),
|
|
type = field_type_name(info),
|
|
flags = info.flags,
|
|
extension = ptr(info.extension),
|
|
}
|
|
for flag, name in pairs(field_flag_names) do
|
|
t[name] = bit.band(flag, info.flags) ~= 0
|
|
end
|
|
return t
|
|
end
|
|
|
|
--convenience field name fetcher (less garbage)
|
|
function res.field_name(res, i)
|
|
assert(i >= 1 and i <= res:field_count(), 'index out of range')
|
|
local info = C.mysql_fetch_field_direct(res, i-1)
|
|
return cstring(info.name, info.name_length)
|
|
end
|
|
|
|
--convenience field iterator, shortcut for: for i=1,res:field_count() do local field = res:field_info(i) ... end
|
|
function res.fields(res)
|
|
local n = res:field_count()
|
|
local i = 0
|
|
return function()
|
|
if i == n then return end
|
|
i = i + 1
|
|
return i, res:field_info(i)
|
|
end
|
|
end
|
|
|
|
--row data fetching and parsing
|
|
|
|
ffi.cdef('double strtod(const char*, char**);')
|
|
local function parse_int(data, sz) --using strtod to avoid string creation
|
|
return ffi.C.strtod(data, nil)
|
|
end
|
|
|
|
local function parse_float(data, sz)
|
|
return tonumber(ffi.cast('float', ffi.C.strtod(data, nil))) --because windows is missing strtof()
|
|
end
|
|
|
|
local function parse_double(data, sz)
|
|
return ffi.C.strtod(data, nil)
|
|
end
|
|
|
|
ffi.cdef('int64_t strtoll(const char*, char**, int) ' ..(ffi.os == 'Windows' and ' asm("_strtoi64")' or '') .. ';')
|
|
local function parse_int64(data, sz)
|
|
return ffi.C.strtoll(data, nil, 10)
|
|
end
|
|
|
|
ffi.cdef('uint64_t strtoull(const char*, char**, int) ' ..(ffi.os == 'Windows' and ' asm("_strtoui64")' or '') .. ';')
|
|
local function parse_uint64(data, sz)
|
|
return ffi.C.strtoull(data, nil, 10)
|
|
end
|
|
|
|
local function parse_bit(data, sz)
|
|
data = ffi.cast('uint8_t*', data) --force unsigned
|
|
local n = data[0] --this is the msb: bit fields always come in big endian byte order
|
|
if sz > 6 then --we can cover up to 6 bytes with only Lua numbers
|
|
n = ffi.new('uint64_t', n)
|
|
end
|
|
for i=1,sz-1 do
|
|
n = n * 256 + data[i]
|
|
end
|
|
return n
|
|
end
|
|
|
|
local function parse_date_(data, sz)
|
|
assert(sz >= 10)
|
|
local z = ('0'):byte()
|
|
local year = (data[0] - z) * 1000 + (data[1] - z) * 100 + (data[2] - z) * 10 + (data[3] - z)
|
|
local month = (data[5] - z) * 10 + (data[6] - z)
|
|
local day = (data[8] - z) * 10 + (data[9] - z)
|
|
return year, month, day
|
|
end
|
|
|
|
local function parse_time_(data, sz)
|
|
assert(sz >= 8)
|
|
local z = ('0'):byte()
|
|
local hour = (data[0] - z) * 10 + (data[1] - z)
|
|
local min = (data[3] - z) * 10 + (data[4] - z)
|
|
local sec = (data[6] - z) * 10 + (data[7] - z)
|
|
local frac = 0
|
|
for i = 9, sz-1 do
|
|
frac = frac * 10 + (data[i] - z)
|
|
end
|
|
return hour, min, sec, frac
|
|
end
|
|
|
|
local function format_date(year, month, day)
|
|
return string.format('%04d-%02d-%02d', year, month, day)
|
|
end
|
|
|
|
local function format_time(hour, min, sec, frac)
|
|
if frac and frac ~= 0 then
|
|
return string.format('%02d:%02d:%02d.%d', hour, min, sec, frac)
|
|
else
|
|
return string.format('%02d:%02d:%02d', hour, min, sec)
|
|
end
|
|
end
|
|
|
|
local function datetime_tostring(t)
|
|
local date, time
|
|
if t.year then
|
|
date = format_date(t.year, t.month, t.day)
|
|
end
|
|
if t.sec then
|
|
time = format_time(t.hour, t.min, t.sec, t.frac)
|
|
end
|
|
if date and time then
|
|
return date .. ' ' .. time
|
|
else
|
|
return assert(date or time)
|
|
end
|
|
end
|
|
|
|
local datetime_meta = {__tostring = datetime_tostring}
|
|
local function datetime(t)
|
|
return setmetatable(t, datetime_meta)
|
|
end
|
|
|
|
local function parse_date(data, sz)
|
|
local year, month, day = parse_date_(data, sz)
|
|
return datetime{year = year, month = month, day = day}
|
|
end
|
|
|
|
local function parse_time(data, sz)
|
|
local hour, min, sec, frac = parse_time_(data, sz)
|
|
return datetime{hour = hour, min = min, sec = sec, frac = frac}
|
|
end
|
|
|
|
local function parse_datetime(data, sz)
|
|
local year, month, day = parse_date_(data, sz)
|
|
local hour, min, sec, frac = parse_time_(data + 11, sz - 11)
|
|
return datetime{year = year, month = month, day = day, hour = hour, min = min, sec = sec, frac = frac}
|
|
end
|
|
|
|
local field_decoders = { --other field types not present here are returned as strings, unparsed
|
|
[C.MYSQL_TYPE_TINY] = parse_int,
|
|
[C.MYSQL_TYPE_SHORT] = parse_int,
|
|
[C.MYSQL_TYPE_LONG] = parse_int,
|
|
[C.MYSQL_TYPE_FLOAT] = parse_float,
|
|
[C.MYSQL_TYPE_DOUBLE] = parse_double,
|
|
[C.MYSQL_TYPE_TIMESTAMP] = parse_datetime,
|
|
[C.MYSQL_TYPE_LONGLONG] = parse_int64,
|
|
[C.MYSQL_TYPE_INT24] = parse_int,
|
|
[C.MYSQL_TYPE_DATE] = parse_date,
|
|
[C.MYSQL_TYPE_TIME] = parse_time,
|
|
[C.MYSQL_TYPE_DATETIME] = parse_datetime,
|
|
[C.MYSQL_TYPE_NEWDATE] = parse_date,
|
|
[C.MYSQL_TYPE_TIMESTAMP2] = parse_datetime,
|
|
[C.MYSQL_TYPE_DATETIME2] = parse_datetime,
|
|
[C.MYSQL_TYPE_TIME2] = parse_time,
|
|
[C.MYSQL_TYPE_YEAR] = parse_int,
|
|
[C.MYSQL_TYPE_BIT] = parse_bit,
|
|
}
|
|
|
|
local unsigned_decoders = {
|
|
[C.MYSQL_TYPE_LONGLONG] = parse_uint64,
|
|
}
|
|
|
|
local function mode_flags(mode)
|
|
local assoc = mode and mode:find'a'
|
|
local numeric = not mode or not assoc or mode:find'n'
|
|
local decode = not mode or not mode:find's'
|
|
local packed = mode and mode:find'[an]'
|
|
local fetch_fields = assoc or decode --if assoc we need field_name, if decode we need field_type
|
|
return numeric, assoc, decode, packed, fetch_fields
|
|
end
|
|
|
|
local function fetch_row(res, numeric, assoc, decode, field_count, fields, t)
|
|
local values = C.mysql_fetch_row(res)
|
|
if values == NULL then
|
|
if res.conn ~= NULL then --buffered read: check for errors
|
|
myerror(res.conn)
|
|
end
|
|
return nil
|
|
end
|
|
local sizes = C.mysql_fetch_lengths(res)
|
|
for i=0,field_count-1 do
|
|
local v = values[i]
|
|
if v ~= NULL then
|
|
local decoder
|
|
if decode then
|
|
local ftype = tonumber(fields[i].type)
|
|
local unsigned = bit.bor(fields[i].flags, C.MYSQL_UNSIGNED_FLAG) ~= 0
|
|
decoder = unsigned and unsigned_decoders[ftype] or field_decoders[ftype] or ffi.string
|
|
else
|
|
decoder = ffi.string
|
|
end
|
|
v = decoder(values[i], sizes[i])
|
|
if numeric then
|
|
t[i+1] = v
|
|
end
|
|
if assoc then
|
|
local k = ffi.string(fields[i].name, fields[i].name_length)
|
|
t[k] = v
|
|
end
|
|
end
|
|
end
|
|
return t
|
|
end
|
|
|
|
function res.fetch(res, mode, t)
|
|
local numeric, assoc, decode, packed, fetch_fields = mode_flags(mode)
|
|
local field_count = C.mysql_num_fields(res)
|
|
local fields = fetch_fields and C.mysql_fetch_fields(res)
|
|
local row = fetch_row(res, numeric, assoc, decode, field_count, fields, t or {})
|
|
if packed then
|
|
return row
|
|
else
|
|
return true, unpack(row)
|
|
end
|
|
end
|
|
|
|
function res.rows(res, mode, t)
|
|
local numeric, assoc, decode, packed, fetch_fields = mode_flags(mode)
|
|
local field_count = C.mysql_num_fields(res)
|
|
local fields = fetch_fields and C.mysql_fetch_fields(res)
|
|
local i = 0
|
|
res:seek(1)
|
|
return function()
|
|
local row = fetch_row(res, numeric, assoc, decode, field_count, fields, t or {})
|
|
if not row then return end
|
|
i = i + 1
|
|
if packed then
|
|
return i, row
|
|
else
|
|
return i, unpack(row)
|
|
end
|
|
end
|
|
end
|
|
|
|
res.tell = C.mysql_row_tell
|
|
|
|
function res.seek(res, where) --use in conjunction with res:row_count()
|
|
if type(where) == 'number' then
|
|
C.mysql_data_seek(res, where-1)
|
|
else
|
|
C.mysql_row_seek(res, where)
|
|
end
|
|
end
|
|
|
|
--reflection
|
|
|
|
local function list_function(func)
|
|
return function(mysql, wild)
|
|
local res = checkh(mysql, func(mysql, wild))
|
|
return ffi.gc(res, C.mysql_free_result)
|
|
end
|
|
end
|
|
|
|
conn.list_dbs = list_function(C.mysql_list_dbs)
|
|
conn.list_tables = list_function(C.mysql_list_tables)
|
|
conn.list_processes = result_function(C.mysql_list_processes)
|
|
|
|
--remote control
|
|
|
|
function conn.kill(mysql, pid)
|
|
checkz(mysql, C.mysql_kill(mysql, pid))
|
|
end
|
|
|
|
function conn.shutdown(mysql, level)
|
|
checkz(mysql, C.mysql_shutdown(mysql, enum(level or C.MYSQL_SHUTDOWN_DEFAULT, 'MYSQL_')))
|
|
end
|
|
|
|
function conn.refresh(mysql, t) --options are 'REFRESH_*' or mysql.C.MYSQL_REFRESH_* enums
|
|
local options = 0
|
|
if type(t) == 'number' then
|
|
options = t
|
|
else
|
|
for k,v in pairs(t) do
|
|
if v then
|
|
options = bit.bor(options, enum(k, 'MYSQL_'))
|
|
end
|
|
end
|
|
end
|
|
checkz(mysql, C.mysql_refresh(mysql, options))
|
|
end
|
|
|
|
function conn.dump_debug_info(mysql)
|
|
checkz(mysql, C.mysql_dump_debug_info(mysql))
|
|
end
|
|
|
|
--prepared statements
|
|
|
|
local function sterror(stmt)
|
|
local err = cstring(C.mysql_stmt_error(stmt))
|
|
if not err then return end
|
|
error(string.format('mysql error: %s', err))
|
|
end
|
|
|
|
local function stcheckz(stmt, ret)
|
|
if ret == 0 then return end
|
|
sterror(stmt)
|
|
end
|
|
|
|
local function stcheckbool(stmt, ret)
|
|
if ret == 1 then return end
|
|
sterror(stmt)
|
|
end
|
|
|
|
local function stcheckh(stmt, ret)
|
|
if ret ~= NULL then return ret end
|
|
sterror(stmt)
|
|
end
|
|
|
|
function conn.prepare(mysql, query)
|
|
local stmt = checkh(mysql, C.mysql_stmt_init(mysql))
|
|
ffi.gc(stmt, C.mysql_stmt_close)
|
|
stcheckz(stmt, C.mysql_stmt_prepare(stmt, query, #query))
|
|
return stmt
|
|
end
|
|
|
|
local stmt = {} --statement methods
|
|
|
|
function stmt.close(stmt)
|
|
stcheckbool(stmt, C.mysql_stmt_close(stmt))
|
|
ffi.gc(stmt, nil)
|
|
end
|
|
|
|
function stmt.exec(stmt)
|
|
stcheckz(stmt, C.mysql_stmt_execute(stmt))
|
|
end
|
|
|
|
function stmt.next_result(stmt)
|
|
local ret = C.mysql_stmt_next_result(stmt)
|
|
if ret == 0 then return true end
|
|
if ret == -1 then return false end
|
|
sterror(stmt)
|
|
end
|
|
|
|
function stmt.store_result(stmt)
|
|
stcheckz(stmt, C.mysql_stmt_store_result(stmt))
|
|
end
|
|
|
|
function stmt.free_result(stmt)
|
|
stcheckbool(stmt, C.mysql_stmt_free_result(stmt))
|
|
end
|
|
|
|
function stmt.row_count(stmt)
|
|
return tonumber(C.mysql_stmt_num_rows(stmt))
|
|
end
|
|
|
|
function stmt.affected_rows(stmt)
|
|
local n = C.mysql_stmt_affected_rows(stmt)
|
|
if n == minus1_uint64 then sterror(stmt) end
|
|
return tonumber(n)
|
|
end
|
|
|
|
stmt.insert_id = C.mysql_stmt_insert_id
|
|
stmt.field_count = C.mysql_stmt_field_count
|
|
stmt.param_count = C.mysql_stmt_param_count
|
|
|
|
function stmt.errno(stmt)
|
|
local err = C.mysql_stmt_errno(stmt)
|
|
if err == 0 then return end
|
|
return err
|
|
end
|
|
|
|
function stmt.sqlstate(stmt)
|
|
return cstring(C.mysql_stmt_sqlstate(stmt))
|
|
end
|
|
|
|
function stmt.result_metadata(stmt)
|
|
local res = stcheckh(stmt, C.mysql_stmt_result_metadata(stmt))
|
|
return ffi.gc(res, C.mysql_free_result)
|
|
end
|
|
|
|
function stmt.fields(stmt)
|
|
local res = stmt:result_metadata()
|
|
local fields = res:fields()
|
|
return function()
|
|
local i, info = fields()
|
|
if not i then
|
|
res:free()
|
|
end
|
|
return i, info
|
|
end
|
|
end
|
|
|
|
function stmt.fetch(stmt)
|
|
local ret = C.mysql_stmt_fetch(stmt)
|
|
if ret == 0 then return true end
|
|
if ret == C.MYSQL_NO_DATA then return false end
|
|
if ret == C.MYSQL_DATA_TRUNCATED then return true, 'truncated' end
|
|
sterror(stmt)
|
|
end
|
|
|
|
function stmt.reset(stmt)
|
|
stcheckz(stmt, C.mysql_stmt_reset(stmt))
|
|
end
|
|
|
|
stmt.tell = C.mysql_stmt_row_tell
|
|
|
|
function stmt.seek(stmt, where) --use in conjunction with stmt:row_count()
|
|
if type(where) == 'number' then
|
|
C.mysql_stmt_data_seek(stmt, where-1)
|
|
else
|
|
C.mysql_stmt_row_seek(stmt, where)
|
|
end
|
|
end
|
|
|
|
function stmt.write(stmt, param_number, data, size)
|
|
stcheckz(stmt, C.mysql_stmt_send_long_data(stmt, param_number, data, size or #data))
|
|
end
|
|
|
|
function stmt.update_max_length(stmt)
|
|
local attr = ffi.new'my_bool[1]'
|
|
stcheckz(stmt, C.mysql_stmt_attr_get(stmt, C.STMT_ATTR_UPDATE_MAX_LENGTH, attr))
|
|
return attr[0] == 1
|
|
end
|
|
|
|
function stmt.set_update_max_length(stmt, yes)
|
|
local attr = ffi.new('my_bool[1]', yes == nil or yes)
|
|
stcheckz(stmt, C.mysql_stmt_attr_set(stmt, C.STMT_ATTR_CURSOR_TYPE, attr))
|
|
end
|
|
|
|
function stmt.cursor_type(stmt)
|
|
local attr = ffi.new'uint32_t[1]'
|
|
stcheckz(stmt, C.mysql_stmt_attr_get(stmt, C.STMT_ATTR_CURSOR_TYPE, attr))
|
|
return attr[0]
|
|
end
|
|
|
|
function stmt.set_cursor_type(stmt, cursor_type)
|
|
local attr = ffi.new('uint32_t[1]', enum(cursor_type, 'MYSQL_'))
|
|
stcheckz(stmt, C.mysql_stmt_attr_set(stmt, C.STMT_ATTR_CURSOR_TYPE, attr))
|
|
end
|
|
|
|
function stmt.prefetch_rows(stmt)
|
|
local attr = ffi.new'uint32_t[1]'
|
|
stcheckz(stmt, C.mysql_stmt_attr_get(stmt, C.STMT_ATTR_PREFETCH_ROWS, attr))
|
|
return attr[0]
|
|
end
|
|
|
|
function stmt.set_prefetch_rows(stmt, n)
|
|
local attr = ffi.new('uint32_t[1]', n)
|
|
stcheckz(stmt, C.mysql_stmt_attr_set(stmt, C.STMT_ATTR_PREFETCH_ROWS, attr))
|
|
end
|
|
|
|
--prepared statements / bind buffers
|
|
|
|
--see http://dev.mysql.com/doc/refman/5.7/en/c-api-prepared-statement-type-codes.html
|
|
local bb_types_input = {
|
|
--conversion-free types
|
|
tinyint = C.MYSQL_TYPE_TINY,
|
|
smallint = C.MYSQL_TYPE_SHORT,
|
|
int = C.MYSQL_TYPE_LONG,
|
|
integer = C.MYSQL_TYPE_LONG, --alias of int
|
|
bigint = C.MYSQL_TYPE_LONGLONG,
|
|
float = C.MYSQL_TYPE_FLOAT,
|
|
double = C.MYSQL_TYPE_DOUBLE,
|
|
time = C.MYSQL_TYPE_TIME,
|
|
date = C.MYSQL_TYPE_DATE,
|
|
datetime = C.MYSQL_TYPE_DATETIME,
|
|
timestamp = C.MYSQL_TYPE_TIMESTAMP,
|
|
text = C.MYSQL_TYPE_STRING,
|
|
char = C.MYSQL_TYPE_STRING,
|
|
varchar = C.MYSQL_TYPE_STRING,
|
|
blob = C.MYSQL_TYPE_BLOB,
|
|
binary = C.MYSQL_TYPE_BLOB,
|
|
varbinary = C.MYSQL_TYPE_BLOB,
|
|
null = C.MYSQL_TYPE_NULL,
|
|
--conversion types (can only use one of the above C types)
|
|
mediumint = C.MYSQL_TYPE_LONG,
|
|
real = C.MYSQL_TYPE_DOUBLE,
|
|
decimal = C.MYSQL_TYPE_BLOB,
|
|
numeric = C.MYSQL_TYPE_BLOB,
|
|
year = C.MYSQL_TYPE_SHORT,
|
|
tinyblob = C.MYSQL_TYPE_BLOB,
|
|
tinytext = C.MYSQL_TYPE_BLOB,
|
|
mediumblob = C.MYSQL_TYPE_BLOB,
|
|
mediumtext = C.MYSQL_TYPE_BLOB,
|
|
longblob = C.MYSQL_TYPE_BLOB,
|
|
longtext = C.MYSQL_TYPE_BLOB,
|
|
bit = C.MYSQL_TYPE_LONGLONG, --MYSQL_TYPE_BIT is not available for input params
|
|
set = C.MYSQL_TYPE_BLOB,
|
|
enum = C.MYSQL_TYPE_BLOB,
|
|
}
|
|
|
|
local bb_types_output = {
|
|
--conversion-free types
|
|
tinyint = C.MYSQL_TYPE_TINY,
|
|
smallint = C.MYSQL_TYPE_SHORT,
|
|
mediumint = C.MYSQL_TYPE_INT24, --int32
|
|
int = C.MYSQL_TYPE_LONG,
|
|
integer = C.MYSQL_TYPE_LONG, --alias of int
|
|
bigint = C.MYSQL_TYPE_LONGLONG,
|
|
float = C.MYSQL_TYPE_FLOAT,
|
|
double = C.MYSQL_TYPE_DOUBLE,
|
|
real = C.MYSQL_TYPE_DOUBLE,
|
|
decimal = C.MYSQL_TYPE_NEWDECIMAL, --char[]
|
|
numeric = C.MYSQL_TYPE_NEWDECIMAL, --char[]
|
|
year = C.MYSQL_TYPE_SHORT,
|
|
time = C.MYSQL_TYPE_TIME,
|
|
date = C.MYSQL_TYPE_DATE,
|
|
datetime = C.MYSQL_TYPE_DATETIME,
|
|
timestamp = C.MYSQL_TYPE_TIMESTAMP,
|
|
char = C.MYSQL_TYPE_STRING,
|
|
binary = C.MYSQL_TYPE_STRING,
|
|
varchar = C.MYSQL_TYPE_VAR_STRING,
|
|
varbinary = C.MYSQL_TYPE_VAR_STRING,
|
|
tinyblob = C.MYSQL_TYPE_TINY_BLOB,
|
|
tinytext = C.MYSQL_TYPE_TINY_BLOB,
|
|
blob = C.MYSQL_TYPE_BLOB,
|
|
text = C.MYSQL_TYPE_BLOB,
|
|
mediumblob = C.MYSQL_TYPE_MEDIUM_BLOB,
|
|
mediumtext = C.MYSQL_TYPE_MEDIUM_BLOB,
|
|
longblob = C.MYSQL_TYPE_LONG_BLOB,
|
|
longtext = C.MYSQL_TYPE_LONG_BLOB,
|
|
bit = C.MYSQL_TYPE_BIT,
|
|
--conversion types (can only use one of the above C types)
|
|
null = C.MYSQL_TYPE_TINY,
|
|
set = C.MYSQL_TYPE_BLOB,
|
|
enum = C.MYSQL_TYPE_BLOB,
|
|
}
|
|
|
|
local number_types = {
|
|
[C.MYSQL_TYPE_TINY] = 'int8_t[1]',
|
|
[C.MYSQL_TYPE_SHORT] = 'int16_t[1]',
|
|
[C.MYSQL_TYPE_LONG] = 'int32_t[1]',
|
|
[C.MYSQL_TYPE_INT24] = 'int32_t[1]',
|
|
[C.MYSQL_TYPE_LONGLONG] = 'int64_t[1]',
|
|
[C.MYSQL_TYPE_FLOAT] = 'float[1]',
|
|
[C.MYSQL_TYPE_DOUBLE] = 'double[1]',
|
|
}
|
|
|
|
local uint_types = {
|
|
[C.MYSQL_TYPE_TINY] = 'uint8_t[1]',
|
|
[C.MYSQL_TYPE_SHORT] = 'uint16_t[1]',
|
|
[C.MYSQL_TYPE_LONG] = 'uint32_t[1]',
|
|
[C.MYSQL_TYPE_INT24] = 'uint32_t[1]',
|
|
[C.MYSQL_TYPE_LONGLONG] = 'uint64_t[1]',
|
|
}
|
|
|
|
local time_types = {
|
|
[C.MYSQL_TYPE_TIME] = true,
|
|
[C.MYSQL_TYPE_DATE] = true,
|
|
[C.MYSQL_TYPE_DATETIME] = true,
|
|
[C.MYSQL_TYPE_TIMESTAMP] = true,
|
|
}
|
|
|
|
local time_struct_types = {
|
|
[C.MYSQL_TYPE_TIME] = C.MYSQL_TIMESTAMP_TIME,
|
|
[C.MYSQL_TYPE_DATE] = C.MYSQL_TIMESTAMP_DATE,
|
|
[C.MYSQL_TYPE_DATETIME] = C.MYSQL_TIMESTAMP_DATETIME,
|
|
[C.MYSQL_TYPE_TIMESTAMP] = C.MYSQL_TIMESTAMP_DATETIME,
|
|
}
|
|
|
|
local params = {} --params bind buffer methods
|
|
local params_meta = {__index = params}
|
|
local fields = {} --params bind buffer methods
|
|
local fields_meta = {__index = fields}
|
|
|
|
-- "varchar(200)" -> "varchar", 200; "decimal(10,4)" -> "decimal", 12; "int unsigned" -> "int", nil, true
|
|
local function parse_type(s)
|
|
s = s:lower()
|
|
local unsigned = false
|
|
local rest = s:match'(.-)%s+unsigned$'
|
|
if rest then s, unsigned = rest, true end
|
|
local rest, sz = s:match'^%s*([^%(]+)%s*%(%s*(%d+)[^%)]*%)%s*$'
|
|
if rest then
|
|
s, sz = rest, assert(tonumber(sz), 'invalid type')
|
|
if s == 'decimal' or s == 'numeric' then --make room for the dot and the minus sign
|
|
sz = sz + 2
|
|
end
|
|
end
|
|
return s, sz, unsigned
|
|
end
|
|
|
|
local function bind_buffer(bb_types, meta, types)
|
|
local self = setmetatable({}, meta)
|
|
|
|
self.count = #types
|
|
self.buffer = ffi.new('MYSQL_BIND[?]', #types)
|
|
self.data = {} --data buffers, one for each field
|
|
self.lengths = ffi.new('uint32_t[?]', #types) --length buffers, one for each field
|
|
self.null_flags = ffi.new('my_bool[?]', #types) --null flag buffers, one for each field
|
|
self.error_flags = ffi.new('my_bool[?]', #types) --error (truncation) flag buffers, one for each field
|
|
|
|
for i,typedef in ipairs(types) do
|
|
local stype, size, unsigned = parse_type(typedef)
|
|
local btype = assert(bb_types[stype], 'invalid type')
|
|
local data
|
|
if stype == 'bit' then
|
|
if btype == C.MYSQL_TYPE_LONGLONG then --for input: use unsigned int64 and ignore size
|
|
data = ffi.new'uint64_t[1]'
|
|
self.buffer[i-1].is_unsigned = 1
|
|
size = 0
|
|
elseif btype == C.MYSQL_TYPE_BIT then --for output: use mysql conversion-free type
|
|
size = size or 64 --if missing size, assume maximum
|
|
size = math.ceil(size / 8)
|
|
assert(size >= 1 and size <= 8, 'invalid size')
|
|
data = ffi.new('uint8_t[?]', size)
|
|
end
|
|
elseif number_types[btype] then
|
|
assert(not size, 'fixed size type')
|
|
data = ffi.new(unsigned and uint_types[btype] or number_types[btype])
|
|
self.buffer[i-1].is_unsigned = unsigned
|
|
size = ffi.sizeof(data)
|
|
elseif time_types[btype] then
|
|
assert(not size, 'fixed size type')
|
|
data = ffi.new'MYSQL_TIME'
|
|
data.time_type = time_struct_types[btype]
|
|
size = 0
|
|
elseif btype == C.MYSQL_TYPE_NULL then
|
|
assert(not size, 'fixed size type')
|
|
size = 0
|
|
else
|
|
assert(size, 'missing size')
|
|
data = size > 0 and ffi.new('uint8_t[?]', size) or nil
|
|
end
|
|
self.null_flags[i-1] = true
|
|
self.data[i] = data
|
|
self.lengths[i-1] = 0
|
|
self.buffer[i-1].buffer_type = btype
|
|
self.buffer[i-1].buffer = data
|
|
self.buffer[i-1].buffer_length = size
|
|
self.buffer[i-1].is_null = self.null_flags + (i - 1)
|
|
self.buffer[i-1].error = self.error_flags + (i - 1)
|
|
self.buffer[i-1].length = self.lengths + (i - 1)
|
|
end
|
|
return self
|
|
end
|
|
|
|
local function params_bind_buffer(types)
|
|
return bind_buffer(bb_types_input, params_meta, types)
|
|
end
|
|
|
|
local function fields_bind_buffer(types)
|
|
return bind_buffer(bb_types_output, fields_meta, types)
|
|
end
|
|
|
|
local function bind_check_range(self, i)
|
|
assert(i >= 1 and i <= self.count, 'index out of bounds')
|
|
end
|
|
|
|
--realloc a buffer using supplied size. only for varsize fields.
|
|
function params:realloc(i, size)
|
|
bind_check_range(self, i)
|
|
assert(ffi.istype(data, 'uint8_t[?]'), 'attempt to realloc a fixed size field')
|
|
local data = size > 0 and ffi.new('uint8_t[?]', size) or nil
|
|
self.null_flags[i-1] = true
|
|
self.data[i] = data
|
|
self.lengths[i-1] = 0
|
|
self.buffer[i-1].buffer = data
|
|
self.buffer[i-1].buffer_length = size
|
|
end
|
|
|
|
fields.realloc = params.realloc
|
|
|
|
function fields:get_date(i)
|
|
bind_check_range(self, i)
|
|
local btype = tonumber(self.buffer[i-1].buffer_type)
|
|
local date = btype == C.MYSQL_TYPE_DATE or btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
|
|
local time = btype == C.MYSQL_TYPE_TIME or btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
|
|
assert(date or time, 'not a date/time type')
|
|
if self.null_flags[i-1] == 1 then return nil end
|
|
local tm = self.data[i]
|
|
return
|
|
date and tm.year or nil,
|
|
date and tm.month or nil,
|
|
date and tm.day or nil,
|
|
time and tm.hour or nil,
|
|
time and tm.minute or nil,
|
|
time and tm.second or nil,
|
|
time and tm.second_part or nil
|
|
end
|
|
|
|
function params:set_date(i, year, month, day, hour, min, sec, frac)
|
|
bind_check_range(self, i)
|
|
local tm = self.data[i]
|
|
local btype = tonumber(self.buffer[i-1].buffer_type)
|
|
local date = btype == C.MYSQL_TYPE_DATE or btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
|
|
local time = btype == C.MYSQL_TYPE_TIME or btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
|
|
assert(date or time, 'not a date/time type')
|
|
local tm = self.data[i]
|
|
tm.year = date and math.max(0, math.min(year or 0, 9999)) or 0
|
|
tm.month = date and math.max(1, math.min(month or 0, 12)) or 0
|
|
tm.day = date and math.max(1, math.min(day or 0, 31)) or 0
|
|
tm.hour = time and math.max(0, math.min(hour or 0, 59)) or 0
|
|
tm.minute = time and math.max(0, math.min(min or 0, 59)) or 0
|
|
tm.second = time and math.max(0, math.min(sec or 0, 59)) or 0
|
|
tm.second_part = time and math.max(0, math.min(frac or 0, 999999)) or 0
|
|
self.null_flags[i-1] = false
|
|
end
|
|
|
|
function params:set(i, v, size)
|
|
bind_check_range(self, i)
|
|
v = ptr(v)
|
|
if v == nil then
|
|
self.null_flags[i-1] = true
|
|
return
|
|
end
|
|
local btype = tonumber(self.buffer[i-1].buffer_type)
|
|
if btype == C.MYSQL_TYPE_NULL then
|
|
error('attempt to set a null type param')
|
|
elseif number_types[btype] then --this includes bit type which is LONGLONG
|
|
self.data[i][0] = v
|
|
self.null_flags[i-1] = false
|
|
elseif time_types[btype] then
|
|
self:set_date(i, v.year, v.month, v.day, v.hour, v.min, v.sec, v.frac)
|
|
else --var-sized types and raw bit blobs
|
|
size = size or #v
|
|
local bsize = tonumber(self.buffer[i-1].buffer_length)
|
|
assert(bsize >= size, 'string too long')
|
|
ffi.copy(self.data[i], v, size)
|
|
self.lengths[i-1] = size
|
|
self.null_flags[i-1] = false
|
|
end
|
|
end
|
|
|
|
function fields:get(i)
|
|
bind_check_range(self, i)
|
|
local btype = tonumber(self.buffer[i-1].buffer_type)
|
|
if btype == C.MYSQL_TYPE_NULL or self.null_flags[i-1] == 1 then
|
|
return nil
|
|
end
|
|
if number_types[btype] then
|
|
return self.data[i][0] --ffi converts this to a number or int64 type, which maches result:fetch() decoding
|
|
elseif time_types[btype] then
|
|
local t = self.data[i]
|
|
if t.time_type == C.MYSQL_TIMESTAMP_TIME then
|
|
return datetime{hour = t.hour, min = t.minute, sec = t.second, frac = t.second_part}
|
|
elseif t.time_type == C.MYSQL_TIMESTAMP_DATE then
|
|
return datetime{year = t.year, month = t.month, day = t.day}
|
|
elseif t.time_type == C.MYSQL_TIMESTAMP_DATETIME then
|
|
return datetime{year = t.year, month = t.month, day = t.day,
|
|
hour = t.hour, min = t.minute, sec = t.second, frac = t.second_part}
|
|
else
|
|
error'invalid time'
|
|
end
|
|
else
|
|
local sz = math.min(self.buffer[i-1].buffer_length, self.lengths[i-1])
|
|
if btype == C.MYSQL_TYPE_BIT then
|
|
return parse_bit(self.data[i], sz)
|
|
else
|
|
return ffi.string(self.data[i], sz)
|
|
end
|
|
end
|
|
end
|
|
|
|
function fields:is_null(i) --returns true if the field is null
|
|
bind_check_range(self, i)
|
|
local btype = self.buffer[i-1].buffer_type
|
|
return btype == C.MYSQL_TYPE_NULL or self.null_flags[i-1] == 1
|
|
end
|
|
|
|
function fields:is_truncated(i) --returns true if the field value was truncated
|
|
bind_check_range(self, i)
|
|
return self.error_flags[i-1] == 1
|
|
end
|
|
|
|
local varsize_types = {
|
|
char = true,
|
|
binary = true,
|
|
varchar = true,
|
|
varbinary = true,
|
|
tinyblob = true,
|
|
tinytext = true,
|
|
blob = true,
|
|
text = true,
|
|
mediumblob = true,
|
|
mediumtext = true,
|
|
longblob = true,
|
|
longtext = true,
|
|
bit = true,
|
|
set = true,
|
|
enum = true,
|
|
}
|
|
|
|
function stmt.bind_result_types(stmt, maxsize)
|
|
local types = {}
|
|
local field_count = stmt:field_count()
|
|
local res = stmt:result_metadata()
|
|
for i=1,field_count do
|
|
local ftype, size, unsigned, decimals = res:field_type(i)
|
|
if ftype == 'decimal' then
|
|
ftype = string.format('%s(%d,%d)', ftype, size-2, decimals)
|
|
elseif varsize_types[ftype] then
|
|
size = math.min(size, maxsize or 65535)
|
|
ftype = string.format('%s(%d)', ftype, size)
|
|
end
|
|
ftype = unsigned and ftype..' unsigned' or ftype
|
|
types[i] = ftype
|
|
end
|
|
res:free()
|
|
return types
|
|
end
|
|
|
|
function stmt.bind_params(stmt, ...)
|
|
local types = type(...) == 'string' and {...} or ... or {}
|
|
assert(stmt:param_count() == #types, 'wrong number of param types')
|
|
local bb = params_bind_buffer(types)
|
|
stcheckz(stmt, C.mysql_stmt_bind_param(stmt, bb.buffer))
|
|
return bb
|
|
end
|
|
|
|
function stmt.bind_result(stmt, arg1, ...)
|
|
local types
|
|
if type(arg1) == 'string' then
|
|
types = {arg1, ...}
|
|
elseif type(arg1) == 'number' then
|
|
types = stmt:bind_result_types(arg1)
|
|
elseif arg1 then
|
|
types = arg1
|
|
else
|
|
types = stmt:bind_result_types()
|
|
end
|
|
assert(stmt:field_count() == #types, 'wrong number of field types')
|
|
local bb = fields_bind_buffer(types)
|
|
stcheckz(stmt, C.mysql_stmt_bind_result(stmt, bb.buffer))
|
|
return bb
|
|
end
|
|
|
|
--publish methods
|
|
|
|
if not rawget(_G, '__MYSQL__') then
|
|
_G.__MYSQL__ = true
|
|
ffi.metatype('MYSQL', {__index = conn})
|
|
ffi.metatype('MYSQL_RES', {__index = res})
|
|
ffi.metatype('MYSQL_STMT', {__index = stmt})
|
|
end
|
|
|
|
--publish classes (for introspection, not extending)
|
|
|
|
M.conn = conn
|
|
M.res = res
|
|
M.stmt = stmt
|
|
M.params = params
|
|
M.fields = fields
|
|
|
|
if not ... then require'mysql_test' end
|
|
|
|
return M
|