mysql Lua+ffi binding http://luapower.com/mysql
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1305 lines
37KB

  1. --MySQL client library ffi binding.
  2. --Written by Cosmin Apreutesei. Public domain.
  3. --Supports MySQL Connector/C 6.1.
  4. --Based on MySQL 5.7 manual.
  5. if not ... then require'mysql_test'; return end
  6. local ffi = require'ffi'
  7. local bit = require'bit'
  8. require'mysql_h'
  9. local C
  10. local M = {}
  11. --select a mysql client library implementation.
  --- Bind the module to a MySQL client library; only the first call has an effect.
  -- `lib` may be nil or 'mysql' (platform default library name), 'mariadb',
  -- any other name/path accepted by ffi.load(), or a non-string value which
  -- is used as the library namespace as-is.
  -- Returns the module table.
  local function bind(lib)
    if C then return M end --already bound
    local target = lib
    if not lib or lib == 'mysql' then
      target = ffi.abi'win' and 'libmysql' or 'mysqlclient'
    end
    if type(target) == 'string' then
      C = ffi.load(target) --covers 'mariadb' and custom names alike
    else
      C = target
    end
    M.C = C
    return M
  end
  M.bind = bind
  --NULL pointers are compared against an explicit NULL cdata instead of nil, for compatibility with luaffi.
  local NULL = ffi.cast('void*', nil)

  --- Map a NULL pointer to nil; any other value passes through untouched.
  local function ptr(p)
    if p ~= NULL then return p end
    return nil
  end

  --- Convert a null-terminated C string to a Lua string.
  -- Both NULL pointers and empty strings yield nil.
  local function cstring(data)
    local usable = data ~= NULL and data[0] ~= 0
    if not usable then return nil end
    return ffi.string(data)
  end
  --error reporting

  --- Raise the connection's current error message, if there is one.
  -- Returns silently when mysql_error() gives no message. `stacklevel` is
  -- passed to error() and defaults to 3.
  local function myerror(mysql, stacklevel)
    local msg = cstring(C.mysql_error(mysql))
    if msg then
      error(string.format('mysql error: %s', msg), stacklevel or 3)
    end
  end

  --- Check a return code where zero means success; anything else goes through myerror().
  local function checkz(mysql, ret)
    if ret ~= 0 then
      myerror(mysql, 4)
    end
  end

  --- Check a pointer result: non-NULL is returned as-is, NULL goes through myerror().
  -- Returns nothing when the pointer is NULL and no error message is pending.
  local function checkh(mysql, ret)
    if ret == NULL then
      myerror(mysql, 4)
      return
    end
    return ret
  end
  --- Resolve an enum given either as a name or as a value.
  -- Strings are looked up in the bound library, as `prefix..e` when a prefix
  -- is given and as `e` otherwise; non-strings are returned unchanged.
  -- Asserts that the result is not nil/false.
  local function enum(e, prefix)
    local value = e
    if type(e) == 'string' then
      value = (prefix and C[prefix..e] or C[e]) or e
    end
    return assert(value, 'invalid enum value')
  end
  --client library info (each call binds the default library first if none is bound yet)

  --- true when mysql_thread_safe() reports 1.
  function M.thread_safe()
    bind()
    local flag = C.mysql_thread_safe()
    return flag == 1
  end

  --- Client library version as a string (nil if empty).
  function M.client_info()
    bind()
    local info = C.mysql_get_client_info()
    return cstring(info)
  end

  --- Client library version as a Lua number.
  function M.client_version()
    bind()
    local version = C.mysql_get_client_version()
    return tonumber(version)
  end
  --connections

  --option value boxing: each encoder wraps a Lua value into a one-element C
  --array; M.connect() hands the result to mysql_options() as a pointer.

  local function bool_ptr(b)
    local flag = b or false
    return ffi.new('my_bool[1]', flag)
  end

  local function uint_bool_ptr(b) --boolean stored in a 32-bit unsigned slot
    local flag = b or false
    return ffi.new('uint32_t[1]', flag)
  end

  local function uint_ptr(i)
    return ffi.new('uint32_t[1]', i)
  end

  local function proto_ptr(proto) --proto is 'MYSQL_PROTOCOL_*' or mysql.C.MYSQL_PROTOCOL_*
    local value = enum(proto)
    return ffi.new('uint32_t[1]', value)
  end

  local function ignore_arg() --the option carries no value
    return nil
  end

  --option name -> encoder; options not listed here are passed through unencoded.
  local option_encoders = {
    MYSQL_ENABLE_CLEARTEXT_PLUGIN          = bool_ptr,
    MYSQL_OPT_LOCAL_INFILE                 = uint_bool_ptr,
    MYSQL_OPT_PROTOCOL                     = proto_ptr,
    MYSQL_OPT_READ_TIMEOUT                 = uint_ptr,
    MYSQL_OPT_WRITE_TIMEOUT                = uint_ptr,
    MYSQL_OPT_USE_REMOTE_CONNECTION        = ignore_arg,
    MYSQL_OPT_USE_EMBEDDED_CONNECTION      = ignore_arg,
    MYSQL_OPT_GUESS_CONNECTION             = ignore_arg,
    MYSQL_SECURE_AUTH                      = bool_ptr,
    MYSQL_REPORT_DATA_TRUNCATION           = bool_ptr,
    MYSQL_OPT_RECONNECT                    = bool_ptr,
    MYSQL_OPT_SSL_VERIFY_SERVER_CERT       = bool_ptr,
    MYSQL_OPT_CAN_HANDLE_EXPIRED_PASSWORDS = bool_ptr,
  }
  --- Open a connection to a MySQL server.
  -- Call forms:
  --   M.connect(host, user, pass, db, charset, port)
  --   M.connect{host=, user=, pass=, db=, charset=, port=, unix_socket=,
  --             flags=, options=, attrs=, key=, cert=, ca=, capath=, cipher=}
  -- flags:   a number, or a table {'CLIENT_*' name or enum value = true/false}.
  -- options: {'MYSQL_OPT_*' name or enum value = value}; values of options
  --          listed in option_encoders are boxed into the matching C type.
  -- attrs:   {name = value} connection attributes.
  -- Returns the connection handle, which is closed when garbage-collected.
  -- Raises on any failure.
  function M.connect(t, ...)
    bind()
    local host, user, pass, db, charset, port
    local unix_socket, flags, options, attrs
    local key, cert, ca, capath, cipher
    if type(t) == 'table' then
      host, user, pass, db, charset, port = t.host, t.user, t.pass, t.db, t.charset, t.port
      unix_socket, flags, options, attrs = t.unix_socket, t.flags, t.options, t.attrs
      key, cert, ca, capath, cipher = t.key, t.cert, t.ca, t.capath, t.cipher
    else
      --positional form; a nil host is accepted too instead of failing on t.host
      host, user, pass, db, charset, port = t, ...
    end
    port = port or 0

    local client_flag = 0
    if type(flags) == 'number' then
      client_flag = flags
    elseif flags then
      for k,v in pairs(flags) do
        local flag = enum(k, 'MYSQL_') --'CLIENT_*' or mysql.C.MYSQL_CLIENT_* enum
        client_flag = v and bit.bor(client_flag, flag) or bit.band(client_flag, bit.bnot(flag))
      end
    end

    --cdata is always truthy, even a NULL pointer, so NULL has to be mapped
    --to nil for assert() to catch an allocation failure.
    local mysql = assert(ptr(C.mysql_init(nil)), 'mysql_init failed')
    ffi.gc(mysql, C.mysql_close)

    if options then
      for k,v in pairs(options) do
        local opt = enum(k) --'MYSQL_OPT_*' or mysql.C.MYSQL_OPT_* enum
        local encoder = option_encoders[k]
        if encoder then v = encoder(v) end
        assert(C.mysql_options(mysql, opt, ffi.cast('const void*', v)) == 0, 'invalid option')
      end
    end

    if attrs then
      for k,v in pairs(attrs) do
        assert(C.mysql_options4(mysql, C.MYSQL_OPT_CONNECT_ATTR_ADD, k, v) == 0)
      end
    end

    --any SSL parameter enables SSL setup, not only a client key: e.g. a
    --CA-only configuration is valid; unused parameters are passed as NULL.
    if key or cert or ca or capath or cipher then
      checkz(mysql, C.mysql_ssl_set(mysql, key, cert, ca, capath, cipher))
    end

    checkh(mysql, C.mysql_real_connect(mysql, host, user, pass, db, port, unix_socket, client_flag))
    if charset then mysql:set_charset(charset) end
    return mysql
  end
  local conn = {} --connection methods

  --- Close the connection now and remove its finalizer so it is not closed again on collection.
  function conn.close(mysql)
    C.mysql_close(mysql)
    ffi.gc(mysql, nil)
  end

  --- Set the connection character set; raises on failure.
  function conn.set_charset(mysql, charset)
    local rc = C.mysql_set_character_set(mysql, charset)
    checkz(mysql, rc)
  end

  --- Make `db` the default database; raises on failure.
  function conn.select_db(mysql, db)
    local rc = C.mysql_select_db(mysql, db)
    checkz(mysql, rc)
  end

  --- Switch user (and optionally database) on the open connection; raises on failure.
  function conn.change_user(mysql, user, pass, db)
    local rc = C.mysql_change_user(mysql, user, pass, db)
    checkz(mysql, rc)
  end

  --- Turn multi-statement support on (truthy `yes`) or off.
  function conn.set_multiple_statements(mysql, yes)
    local option
    if yes then
      option = C.MYSQL_OPTION_MULTI_STATEMENTS_ON
    else
      option = C.MYSQL_OPTION_MULTI_STATEMENTS_OFF
    end
    checkz(mysql, C.mysql_set_server_option(mysql, option))
  end

  --connection info

  --- Name of the connection's current character set.
  function conn.charset(mysql)
    local name = C.mysql_character_set_name(mysql)
    return cstring(name)
  end
  --- Describe the connection's current character set.
  -- Returns a table with number, state, name, collation, comment, dir,
  -- mbminlen and mbmaxlen.
  function conn.charset_info(mysql)
    local info = ffi.new'MY_CHARSET_INFO'
    --mysql_get_character_set_info() returns void in the MySQL C API, so
    --there is no status code to check; the asserts below validate the result.
    C.mysql_get_character_set_info(mysql, info)
    assert(info.name ~= NULL)
    assert(info.csname ~= NULL)
    return {
      number = info.number,
      state = info.state,
      name = cstring(info.csname), --csname and name are inverted from the spec
      collation = cstring(info.name),
      comment = cstring(info.comment),
      dir = cstring(info.dir),
      mbminlen = info.mbminlen,
      mbmaxlen = info.mbmaxlen,
    }
  end
  --- Check that the connection to the server is alive.
  -- Returns true when the server answered, false when the error code is
  -- MYSQL_CR_SERVER_GONE_ERROR; any other failure is raised as an error.
  function conn.ping(mysql)
    if C.mysql_ping(mysql) == 0 then
      return true
    end
    --compare the numeric error code: mysql_error() returns the message
    --text (a char pointer), which can never equal an error code.
    if C.mysql_errno(mysql) == C.MYSQL_CR_SERVER_GONE_ERROR then
      return false
    end
    myerror(mysql)
  end
  --- Thread id of the connection.
  function conn.thread_id(mysql)
    local id = C.mysql_thread_id(mysql)
    return id --NOTE: result is cdata on x64!
  end

  --- Server status string; raises if the call returns NULL with an error pending.
  function conn.stat(mysql)
    local s = checkh(mysql, C.mysql_stat(mysql))
    return cstring(s)
  end

  --- Server version as a string.
  function conn.server_info(mysql)
    local s = checkh(mysql, C.mysql_get_server_info(mysql))
    return cstring(s)
  end

  --- Description of the connection (host and transport) as a string.
  function conn.host_info(mysql)
    local s = checkh(mysql, C.mysql_get_host_info(mysql))
    return cstring(s)
  end

  --- Server version as a Lua number.
  function conn.server_version(mysql)
    local version = C.mysql_get_server_version(mysql)
    return tonumber(version)
  end

  --- Protocol version used by the connection (thin pass-through).
  function conn.proto_info(...)
    return C.mysql_get_proto_info(...)
  end

  --- Name of the SSL cipher in use, or nil when the library reports none.
  function conn.ssl_cipher(mysql)
    local cipher = C.mysql_get_ssl_cipher(mysql)
    return cstring(cipher)
  end
  --transactions

  --- Commit the current transaction; raises on failure.
  function conn.commit(mysql)
    local rc = C.mysql_commit(mysql)
    checkz(mysql, rc)
  end

  --- Roll back the current transaction; raises on failure.
  function conn.rollback(mysql)
    local rc = C.mysql_rollback(mysql)
    checkz(mysql, rc)
  end

  --- Set autocommit mode; `yes` defaults to true when omitted.
  function conn.set_autocommit(mysql, yes)
    local mode = yes
    if mode == nil then mode = true end
    checkz(mysql, C.mysql_autocommit(mysql, mode))
  end
  --queries

  --- Escape `data` into a caller-supplied buffer.
  -- `size` defaults to #data; `sz` is the buffer capacity and must be at
  -- least size * 2 + 1. Returns the escaped length as a Lua number.
  function conn.escape_tobuffer(mysql, data, size, buf, sz)
    local len = size or #data
    assert(sz >= len * 2 + 1)
    local written = C.mysql_real_escape_string(mysql, buf, data, len)
    return tonumber(written)
  end

  --- Escape `data` and return the result as a new Lua string.
  function conn.escape(mysql, data, size)
    local len = size or #data
    local capacity = len * 2 + 1
    local buf = ffi.new('uint8_t[?]', capacity)
    local written = conn.escape_tobuffer(mysql, data, len, buf, capacity)
    return ffi.string(buf, written)
  end

  --- Send a query; `size` defaults to #data. Raises on failure.
  function conn.query(mysql, data, size)
    local len = size or #data
    checkz(mysql, C.mysql_real_query(mysql, data, len))
  end
  --query info

  --- Number of columns of the most recent query (thin pass-through).
  function conn.field_count(...)
    return C.mysql_field_count(...)
  end

  --(uint64_t)-1: the value the row-count calls are compared against before checking for an error
  local minus1_uint64 = ffi.cast('uint64_t', ffi.cast('int64_t', -1))

  --- Rows affected by the last statement, as a Lua number.
  -- Raises when the count is (uint64_t)-1 and an error message is pending.
  function conn.affected_rows(mysql)
    local count = C.mysql_affected_rows(mysql)
    if count == minus1_uint64 then
      myerror(mysql)
    end
    return tonumber(count)
  end

  --- Id generated by the last insert (thin pass-through).
  function conn.insert_id(...)
    return C.mysql_insert_id(...) --NOTE: result is cdata on x64!
  end

  --- Error code of the last call, or nothing when it is zero.
  function conn.errno(conn)
    local code = C.mysql_errno(conn)
    if code ~= 0 then
      return code
    end
  end

  --- SQLSTATE of the last call, as a string.
  function conn.sqlstate(mysql)
    local state = C.mysql_sqlstate(mysql)
    return cstring(state)
  end

  --- Warning count of the last statement (thin pass-through).
  function conn.warning_count(...)
    return C.mysql_warning_count(...)
  end

  --- Info string about the last statement, or nil when there is none.
  function conn.info(mysql)
    local text = C.mysql_info(mysql)
    return cstring(text)
  end
  --query results

  --- Advance to the next result of a multiple-statement query.
  -- Returns true when mysql_next_result() returns 0 and false when it
  -- returns -1; any other code goes through myerror().
  function conn.next_result(mysql) --multiple statement queries return multiple results
    local rc = C.mysql_next_result(mysql)
    if rc == 0 then
      return true
    elseif rc == -1 then
      return false
    end
    myerror(mysql)
  end

  --- true when mysql_more_results() returns 1.
  function conn.more_results(mysql)
    local pending = C.mysql_more_results(mysql)
    return pending == 1
  end
  --- Build a connection method that calls the result-producing C function
  -- named `func` and returns its result set with a finalizer attached.
  -- The generated method raises on error and returns nil when the call
  -- yields NULL without an error (e.g. the statement produced no result set).
  local function result_function(func)
    return function(mysql)
      local res = checkh(mysql, C[func](mysql))
      --checkh() returns nothing for NULL-without-error; ffi.gc(nil) would
      --throw, so guard it the same way stmt.result_metadata() does.
      return res and ffi.gc(res, C.mysql_free_result)
    end
  end
  conn.store_result = result_function'mysql_store_result'
  conn.use_result = result_function'mysql_use_result'
  local res = {} --result methods

  --- Free the result set now and remove its finalizer.
  function res.free(res)
    C.mysql_free_result(res)
    ffi.gc(res, nil)
  end

  --- Number of rows in the result set, as a Lua number.
  function res.row_count(res)
    local n = C.mysql_num_rows(res)
    return tonumber(n)
  end

  --- Number of columns in the result set (thin pass-through).
  function res.field_count(...)
    return C.mysql_num_fields(...)
  end

  --- true when mysql_eof() returns non-zero.
  function res.eof(res)
    local flag = C.mysql_eof(res)
    return flag ~= 0
  end
  --field info

  --MYSQL_TYPE_* code -> SQL type name reported by field_type_name()
  local field_type_names = {
    [ffi.C.MYSQL_TYPE_DECIMAL]     = 'decimal',    --DECIMAL or NUMERIC
    [ffi.C.MYSQL_TYPE_TINY]        = 'tinyint',
    [ffi.C.MYSQL_TYPE_SHORT]       = 'smallint',
    [ffi.C.MYSQL_TYPE_LONG]        = 'int',
    [ffi.C.MYSQL_TYPE_FLOAT]       = 'float',
    [ffi.C.MYSQL_TYPE_DOUBLE]      = 'double',     --DOUBLE or REAL
    [ffi.C.MYSQL_TYPE_NULL]        = 'null',
    [ffi.C.MYSQL_TYPE_TIMESTAMP]   = 'timestamp',
    [ffi.C.MYSQL_TYPE_LONGLONG]    = 'bigint',
    [ffi.C.MYSQL_TYPE_INT24]       = 'mediumint',
    [ffi.C.MYSQL_TYPE_DATE]        = 'date',       --pre mysql 5.0, storage = 4 bytes
    [ffi.C.MYSQL_TYPE_TIME]        = 'time',
    [ffi.C.MYSQL_TYPE_DATETIME]    = 'datetime',
    [ffi.C.MYSQL_TYPE_YEAR]        = 'year',
    [ffi.C.MYSQL_TYPE_NEWDATE]     = 'date',       --mysql 5.0+, storage = 3 bytes
    [ffi.C.MYSQL_TYPE_VARCHAR]     = 'varchar',
    [ffi.C.MYSQL_TYPE_BIT]         = 'bit',
    [ffi.C.MYSQL_TYPE_TIMESTAMP2]  = 'timestamp',  --mysql 5.6+, can store fractional seconds
    [ffi.C.MYSQL_TYPE_DATETIME2]   = 'datetime',   --mysql 5.6+, can store fractional seconds
    [ffi.C.MYSQL_TYPE_TIME2]       = 'time',       --mysql 5.6+, can store fractional seconds
    [ffi.C.MYSQL_TYPE_NEWDECIMAL]  = 'decimal',    --mysql 5.0+, Precision math DECIMAL or NUMERIC
    [ffi.C.MYSQL_TYPE_ENUM]        = 'enum',
    [ffi.C.MYSQL_TYPE_SET]         = 'set',
    [ffi.C.MYSQL_TYPE_TINY_BLOB]   = 'tinyblob',
    [ffi.C.MYSQL_TYPE_MEDIUM_BLOB] = 'mediumblob',
    [ffi.C.MYSQL_TYPE_LONG_BLOB]   = 'longblob',
    [ffi.C.MYSQL_TYPE_BLOB]        = 'text',       --TEXT or BLOB
    [ffi.C.MYSQL_TYPE_VAR_STRING]  = 'varchar',    --VARCHAR or VARBINARY
    [ffi.C.MYSQL_TYPE_STRING]      = 'char',       --CHAR or BINARY
    [ffi.C.MYSQL_TYPE_GEOMETRY]    = 'spatial',    --Spatial field
  }

  --names used instead of the ones above when the field's charsetnr is 63
  local binary_field_type_names = {
    [ffi.C.MYSQL_TYPE_BLOB]       = 'blob',
    [ffi.C.MYSQL_TYPE_VAR_STRING] = 'varbinary',
    [ffi.C.MYSQL_TYPE_STRING]     = 'binary',
  }

  --flag bit -> name of the boolean key set by res.field_info()
  local field_flag_names = {
    [ffi.C.MYSQL_NOT_NULL_FLAG]         = 'not_null',
    [ffi.C.MYSQL_PRI_KEY_FLAG]          = 'pri_key',
    [ffi.C.MYSQL_UNIQUE_KEY_FLAG]       = 'unique_key',
    [ffi.C.MYSQL_MULTIPLE_KEY_FLAG]     = 'key',
    [ffi.C.MYSQL_BLOB_FLAG]             = 'is_blob',
    [ffi.C.MYSQL_UNSIGNED_FLAG]         = 'unsigned',
    [ffi.C.MYSQL_ZEROFILL_FLAG]         = 'zerofill',
    [ffi.C.MYSQL_BINARY_FLAG]           = 'is_binary',
    [ffi.C.MYSQL_ENUM_FLAG]             = 'is_enum',
    [ffi.C.MYSQL_AUTO_INCREMENT_FLAG]   = 'autoincrement',
    [ffi.C.MYSQL_TIMESTAMP_FLAG]        = 'is_timestamp',
    [ffi.C.MYSQL_SET_FLAG]              = 'is_set',
    [ffi.C.MYSQL_NO_DEFAULT_VALUE_FLAG] = 'no_default',
    [ffi.C.MYSQL_ON_UPDATE_NOW_FLAG]    = 'on_update_now',
    [ffi.C.MYSQL_NUM_FLAG]              = 'is_number',
  }
  --- SQL type name for a field descriptor, or nil for an unlisted type code.
  local function field_type_name(info)
    local code = tonumber(info.type)
    --charsetnr 63 changes CHAR into BINARY, VARCHAR into VARBINARY, TEXT into BLOB
    if info.charsetnr == 63 then
      local binary_name = binary_field_type_names[code]
      if binary_name then return binary_name end
    end
    return field_type_names[code]
  end
  --convenience field type fetcher (less garbage)
  --- Type details of the i-th field (1-based).
  -- Returns: type name, length, unsigned (boolean), decimals.
  -- Raises 'index out of range' for an invalid index.
  function res.field_type(res, i)
    assert(i >= 1 and i <= res:field_count(), 'index out of range')
    local info = C.mysql_fetch_field_direct(res, i-1)
    --the flag must be tested with band: bor with a non-zero flag is never 0,
    --which made every field report as unsigned.
    local unsigned = bit.band(info.flags, C.MYSQL_UNSIGNED_FLAG) ~= 0
    return field_type_name(info), tonumber(info.length), unsigned, info.decimals
  end
  --- Full description of the i-th field (1-based) as a table.
  -- Besides the raw attributes, one boolean key is added for every flag
  -- listed in field_flag_names. Raises 'index out of range' for an invalid index.
  function res.field_info(res, i)
    assert(i >= 1 and i <= res:field_count(), 'index out of range')
    local field = C.mysql_fetch_field_direct(res, i-1)
    local flags = field.flags
    local desc = {
      name       = cstring(field.name, field.name_length),
      org_name   = cstring(field.org_name, field.org_name_length),
      table      = cstring(field.table, field.table_length),
      org_table  = cstring(field.org_table, field.org_table_length),
      db         = cstring(field.db, field.db_length),
      catalog    = cstring(field.catalog, field.catalog_length),
      def        = cstring(field.def, field.def_length),
      length     = tonumber(field.length),
      max_length = tonumber(field.max_length),
      decimals   = field.decimals,
      charsetnr  = field.charsetnr,
      type_flag  = tonumber(field.type),
      type       = field_type_name(field),
      flags      = flags,
      extension  = ptr(field.extension),
    }
    for mask, key in pairs(field_flag_names) do
      desc[key] = bit.band(mask, flags) ~= 0
    end
    return desc
  end
  --convenience field name fetcher (less garbage)
  --- Name of the i-th field (1-based); raises 'index out of range' for an invalid index.
  function res.field_name(res, i)
    assert(i >= 1 and i <= res:field_count(), 'index out of range')
    local field = C.mysql_fetch_field_direct(res, i-1)
    return cstring(field.name, field.name_length)
  end

  --convenience field iterator, shortcut for: for i=1,res:field_count() do local field = res:field_info(i) ... end
  --- Iterator yielding (index, field_info table) for every field.
  function res.fields(res)
    local total = res:field_count()
    local index = 0
    return function()
      if index ~= total then
        index = index + 1
        return index, res:field_info(index)
      end
    end
  end
  --row data fetching and parsing

  --text-to-number decoders; all share the (data, sz) decoder signature but
  --rely on the C parser stopping at the terminator, so `sz` is not used.

  ffi.cdef('double strtod(const char*, char**);')

  local function parse_int(data, sz) --using strtod to avoid string creation
    local n = ffi.C.strtod(data, nil)
    return n
  end

  local function parse_float(data, sz)
    local d = ffi.C.strtod(data, nil)
    return tonumber(ffi.cast('float', d)) --because windows is missing strtof()
  end

  local function parse_double(data, sz)
    local d = ffi.C.strtod(data, nil)
    return d
  end

  ffi.cdef('int64_t strtoll(const char*, char**, int) ' ..(ffi.os == 'Windows' and ' asm("_strtoi64")' or '') .. ';')

  local function parse_int64(data, sz) --returns int64_t cdata
    local n = ffi.C.strtoll(data, nil, 10)
    return n
  end

  ffi.cdef('uint64_t strtoull(const char*, char**, int) ' ..(ffi.os == 'Windows' and ' asm("_strtoui64")' or '') .. ';')

  local function parse_uint64(data, sz) --returns uint64_t cdata
    local n = ffi.C.strtoull(data, nil, 10)
    return n
  end
  --- Decode a BIT value of `sz` bytes into a number.
  -- Values longer than 6 bytes are accumulated as uint64_t cdata, shorter
  -- ones as plain Lua numbers.
  local function parse_bit(data, sz)
    local bytes = ffi.cast('uint8_t*', data) --force unsigned
    local acc = bytes[0] --this is the msb: bit fields always come in big endian byte order
    if sz > 6 then --we can cover up to 6 bytes with only Lua numbers
      acc = ffi.new('uint64_t', acc)
    end
    local i = 1
    while i < sz do
      acc = acc * 256 + bytes[i]
      i = i + 1
    end
    return acc
  end
  --- Parse 'YYYY-MM-DD' from a byte buffer; returns year, month, day.
  -- Only the digit positions are read; separators are not checked.
  local function parse_date_(data, sz)
    assert(sz >= 10)
    local zero = ('0'):byte()
    local year = (((data[0] - zero) * 10 + (data[1] - zero)) * 10 + (data[2] - zero)) * 10 + (data[3] - zero)
    local month = (data[5] - zero) * 10 + (data[6] - zero)
    local day = (data[8] - zero) * 10 + (data[9] - zero)
    return year, month, day
  end

  --- Parse 'HH:MM:SS[.ffffff]' from a byte buffer; returns hour, min, sec, frac.
  -- NOTE(review): frac is the fractional digits read as one integer, so
  -- leading zeros are lost ('.05' and '.5' both give 5); two-digit,
  -- unsigned hours are assumed.
  local function parse_time_(data, sz)
    assert(sz >= 8)
    local zero = ('0'):byte()
    local hour = (data[0] - zero) * 10 + (data[1] - zero)
    local min = (data[3] - zero) * 10 + (data[4] - zero)
    local sec = (data[6] - zero) * 10 + (data[7] - zero)
    local frac = 0
    local i = 9 --index 8 is the decimal point
    while i <= sz-1 do
      frac = frac * 10 + (data[i] - zero)
      i = i + 1
    end
    return hour, min, sec, frac
  end
  --- Format a date as 'YYYY-MM-DD'.
  local function format_date(year, month, day)
    local fmt = '%04d-%02d-%02d'
    return fmt:format(year, month, day)
  end

  --- Format a time as 'HH:MM:SS', with '.frac' appended when frac is non-zero.
  local function format_time(hour, min, sec, frac)
    local text = string.format('%02d:%02d:%02d', hour, min, sec)
    if not frac or frac == 0 then
      return text
    end
    return string.format('%02d:%02d:%02d.%d', hour, min, sec, frac)
  end
  --- tostring() for datetime tables.
  -- The date part is produced when t.year is set, the time part when t.sec
  -- is set; both are joined with a space. Asserts that at least one exists.
  local function datetime_tostring(t)
    local date_part, time_part
    if t.year then
      date_part = format_date(t.year, t.month, t.day)
    end
    if t.sec then
      time_part = format_time(t.hour, t.min, t.sec, t.frac)
    end
    if not (date_part and time_part) then
      return assert(date_part or time_part)
    end
    return date_part .. ' ' .. time_part
  end

  local datetime_meta = {__tostring = datetime_tostring}

  --- Tag table `t` as a datetime value (sets its metatable) and return it.
  local function datetime(t)
    setmetatable(t, datetime_meta)
    return t
  end
  --- Decode a DATE value into a datetime table {year, month, day}.
  local function parse_date(data, sz)
    local y, m, d = parse_date_(data, sz)
    return datetime{year = y, month = m, day = d}
  end

  --- Decode a TIME value into a datetime table {hour, min, sec, frac}.
  local function parse_time(data, sz)
    local h, mi, s, f = parse_time_(data, sz)
    return datetime{hour = h, min = mi, sec = s, frac = f}
  end

  --- Decode 'YYYY-MM-DD HH:MM:SS[.frac]' into a datetime table with all fields.
  local function parse_datetime(data, sz)
    local time_offset = 11 --the time part starts after the 10-char date and one separator
    local y, m, d = parse_date_(data, sz)
    local h, mi, s, f = parse_time_(data + time_offset, sz - time_offset)
    return datetime{year = y, month = m, day = d, hour = h, min = mi, sec = s, frac = f}
  end
  --MYSQL_TYPE_* code -> decoder(data, sz)
  local field_decoders = { --other field types not present here are returned as strings, unparsed
    [ffi.C.MYSQL_TYPE_TINY]       = parse_int,
    [ffi.C.MYSQL_TYPE_SHORT]      = parse_int,
    [ffi.C.MYSQL_TYPE_LONG]       = parse_int,
    [ffi.C.MYSQL_TYPE_FLOAT]      = parse_float,
    [ffi.C.MYSQL_TYPE_DOUBLE]     = parse_double,
    [ffi.C.MYSQL_TYPE_TIMESTAMP]  = parse_datetime,
    [ffi.C.MYSQL_TYPE_LONGLONG]   = parse_int64,
    [ffi.C.MYSQL_TYPE_INT24]      = parse_int,
    [ffi.C.MYSQL_TYPE_DATE]       = parse_date,
    [ffi.C.MYSQL_TYPE_TIME]       = parse_time,
    [ffi.C.MYSQL_TYPE_DATETIME]   = parse_datetime,
    [ffi.C.MYSQL_TYPE_NEWDATE]    = parse_date,
    [ffi.C.MYSQL_TYPE_TIMESTAMP2] = parse_datetime,
    [ffi.C.MYSQL_TYPE_DATETIME2]  = parse_datetime,
    [ffi.C.MYSQL_TYPE_TIME2]      = parse_time,
    [ffi.C.MYSQL_TYPE_YEAR]       = parse_int,
    [ffi.C.MYSQL_TYPE_BIT]        = parse_bit,
  }

  --decoders that take precedence over field_decoders for unsigned fields
  local unsigned_decoders = {
    [ffi.C.MYSQL_TYPE_LONGLONG] = parse_uint64,
  }
  --- Split a fetch mode string into flags.
  -- Letters: 'a' = associative keys, 'n' = numeric keys, 's' = leave values
  -- as strings. Returns numeric, assoc, decode, packed, fetch_fields; the
  -- values are truthy/falsy (string positions or booleans), not strict booleans.
  local function mode_flags(mode)
    local numeric, assoc, decode, packed
    if mode then
      assoc = mode:find'a'
      numeric = not assoc or mode:find'n'
      decode = not mode:find's'
      packed = mode:find'[an]'
    else
      --no mode: numeric keys, decoded values, row returned unpacked
      assoc, packed = mode, mode
      numeric, decode = true, true
    end
    local fetch_fields = assoc or decode --if assoc we need field_name, if decode we need field_type
    return numeric, assoc, decode, packed, fetch_fields
  end
  --- Fetch the next row of `res` into table `t`.
  -- numeric/assoc select the key style, decode enables per-type decoding;
  -- `fields` is the field descriptor array (needed when assoc or decode is
  -- set) and `field_count` the number of columns.
  -- Returns `t`, or nil when there are no more rows. NULL columns are
  -- stored as nil, so a reused `t` never keeps values from a previous row.
  local function fetch_row(res, numeric, assoc, decode, field_count, fields, t)
    local values = C.mysql_fetch_row(res)
    if values == NULL then
      --NOTE(review): the original comment calls this the "buffered read"
      --case; what a non-NULL res.conn means is defined outside this chunk.
      if res.conn ~= NULL then --check for errors
        myerror(res.conn, 4)
      end
      return nil
    end
    local sizes = C.mysql_fetch_lengths(res)
    for i=0,field_count-1 do
      local v = values[i]
      if v ~= NULL then
        local decoder = ffi.string
        if decode then
          local ftype = tonumber(fields[i].type)
          --band, not bor: bor with a non-zero flag is never 0, which sent
          --every BIGINT (signed ones included) to the unsigned decoder.
          local unsigned = bit.band(fields[i].flags, C.MYSQL_UNSIGNED_FLAG) ~= 0
          decoder = unsigned and unsigned_decoders[ftype] or field_decoders[ftype] or ffi.string
        end
        v = decoder(v, tonumber(sizes[i]))
      else
        v = nil --SQL NULL
      end
      if numeric then
        t[i+1] = v
      end
      if assoc then
        local k = ffi.string(fields[i].name, fields[i].name_length)
        t[k] = v
      end
    end
    return t
  end
  --- Fetch the next row.
  -- `mode` is a string of 'a' (associative), 'n' (numeric), 's' (no decoding);
  -- `t` is an optional table to fill instead of a new one.
  -- With 'a' or 'n' in mode the row table is returned; otherwise the result
  -- is `true` followed by the column values. Returns nil after the last row.
  function res.fetch(res, mode, t)
    local numeric, assoc, decode, packed, fetch_fields = mode_flags(mode)
    local field_count = C.mysql_num_fields(res)
    local fields = fetch_fields and C.mysql_fetch_fields(res)
    local row = fetch_row(res, numeric, assoc, decode, field_count, fields, t or {})
    if not row then return nil end
    if packed then
      return row
    end
    --explicit bounds: NULL columns leave holes in `row`, and unpack()
    --without a count relies on #row, which is undefined for such tables.
    return true, unpack(row, 1, field_count)
  end
  --- Row iterator; seeks to the first row, then yields one row per step.
  -- `mode` and `t` work as in res.fetch(). With 'a' or 'n' in mode each
  -- step yields (index, row table); otherwise (index, value1, value2, ...).
  -- When `t` is given the same table is refilled on every step.
  function res.rows(res, mode, t)
    local numeric, assoc, decode, packed, fetch_fields = mode_flags(mode)
    local field_count = C.mysql_num_fields(res)
    local fields = fetch_fields and C.mysql_fetch_fields(res)
    local i = 0
    res:seek(1)
    return function()
      local row = fetch_row(res, numeric, assoc, decode, field_count, fields, t or {})
      if not row then return nil end
      i = i + 1
      if packed then
        return i, row
      end
      --explicit bounds: NULL columns leave holes in `row`, and unpack()
      --without a count relies on #row, which is undefined for such tables.
      return i, unpack(row, 1, field_count)
    end
  end
  --- Current row cursor position (thin pass-through); the value can be given to res.seek().
  function res.tell(...)
    return C.mysql_row_tell(...)
  end

  --- Move the row cursor.
  -- A number is a 1-based row index; anything else is passed to
  -- mysql_row_seek() as a cursor offset.
  function res.seek(res, where) --use in conjunction with res:row_count()
    if type(where) ~= 'number' then
      C.mysql_row_seek(res, where)
      return
    end
    C.mysql_data_seek(res, where-1)
  end
  --reflection

  --- Build a connection method around the listing C function named `func`,
  -- which takes (mysql, wild). The generated method returns the result set
  -- with a finalizer attached, raises on error, and returns nil when the
  -- call yields NULL without an error.
  local function list_function(func)
    return function(mysql, wild)
      local res = checkh(mysql, C[func](mysql, wild))
      --checkh() returns nothing for NULL-without-error; ffi.gc(nil) would
      --throw, so guard it the same way stmt.result_metadata() does.
      return res and ffi.gc(res, C.mysql_free_result)
    end
  end
  conn.list_dbs = list_function'mysql_list_dbs'
  conn.list_tables = list_function'mysql_list_tables'
  conn.list_processes = result_function'mysql_list_processes'
  --remote control

  --- Ask the server to kill the thread `pid`; raises on failure.
  function conn.kill(mysql, pid)
    local rc = C.mysql_kill(mysql, pid)
    checkz(mysql, rc)
  end

  --- Ask the server to shut down; `level` defaults to MYSQL_SHUTDOWN_DEFAULT.
  function conn.shutdown(mysql, level)
    local how = enum(level or C.MYSQL_SHUTDOWN_DEFAULT, 'MYSQL_')
    checkz(mysql, C.mysql_shutdown(mysql, how))
  end

  --- Flush/reset server state.
  -- `t` is either a numeric bit mask or a table {option = truthy} whose
  -- enabled options are or-ed together.
  function conn.refresh(mysql, t) --options are 'REFRESH_*' or mysql.C.MYSQL_REFRESH_* enums
    local mask = t
    if type(t) ~= 'number' then
      mask = 0
      for name, enabled in pairs(t) do
        if enabled then
          mask = bit.bor(mask, enum(name, 'MYSQL_'))
        end
      end
    end
    checkz(mysql, C.mysql_refresh(mysql, mask))
  end

  --- Ask the server to write debug information to its log; raises on failure.
  function conn.dump_debug_info(mysql)
    local rc = C.mysql_dump_debug_info(mysql)
    checkz(mysql, rc)
  end
  --prepared statements

  --- Raise the statement's current error message, if there is one.
  -- Returns silently when mysql_stmt_error() gives no message. `stacklevel`
  -- is passed to error() and defaults to 3.
  local function sterror(stmt, stacklevel)
    local msg = cstring(C.mysql_stmt_error(stmt))
    if msg then
      error(string.format('mysql error: %s', msg), stacklevel or 3)
    end
  end

  --- Check a return code where zero means success; anything else goes through sterror().
  local function stcheckz(stmt, ret)
    if ret ~= 0 then
      sterror(stmt, 4)
    end
  end
  --- Check a my_bool result of a statement call.
  -- The my_bool-returning statement functions of the MySQL C API (e.g.
  -- mysql_stmt_close, mysql_stmt_free_result) return 0 on success and
  -- non-zero on error, so only non-zero results are routed to sterror().
  -- (Treating 1 as success swallowed failures and queried the error state
  -- after every successful call.)
  local function stcheckbool(stmt, ret)
    if ret == 0 then return end
    sterror(stmt, 4)
  end
  --- Check a pointer result: non-NULL is returned as-is, NULL goes through sterror().
  -- Returns nothing when the pointer is NULL and no error message is pending.
  local function stcheckh(stmt, ret)
    if ret == NULL then
      sterror(stmt, 4)
      return
    end
    return ret
  end
  --- Create a prepared statement from `query`.
  -- The returned handle is closed when garbage-collected. Raises on failure.
  function conn.prepare(mysql, query)
    local handle = checkh(mysql, C.mysql_stmt_init(mysql))
    ffi.gc(handle, C.mysql_stmt_close)
    local rc = C.mysql_stmt_prepare(handle, query, #query)
    stcheckz(handle, rc)
    return handle
  end
  local stmt = {} --statement methods

  --- Close the statement now and remove its finalizer.
  -- Raises if mysql_stmt_close() reports failure (non-zero).
  function stmt.close(stmt)
    local rc = C.mysql_stmt_close(stmt)
    --mysql_stmt_close() deallocates the handle, so the finalizer must go and
    --the handle must not be used again, not even to read its error message.
    ffi.gc(stmt, nil)
    if rc ~= 0 then
      error('mysql error: failed to close statement', 2)
    end
  end
  --- Execute the prepared statement; raises on failure.
  function stmt.exec(stmt)
    local rc = C.mysql_stmt_execute(stmt)
    stcheckz(stmt, rc)
  end

  --- Advance to the statement's next result.
  -- Returns true when mysql_stmt_next_result() returns 0 and false when it
  -- returns -1; any other code goes through sterror().
  function stmt.next_result(stmt)
    local rc = C.mysql_stmt_next_result(stmt)
    if rc == 0 then
      return true
    elseif rc == -1 then
      return false
    end
    sterror(stmt)
  end

  --- Buffer the complete result set on the client; raises on failure.
  function stmt.store_result(stmt)
    local rc = C.mysql_stmt_store_result(stmt)
    stcheckz(stmt, rc)
  end

  --- Release the statement's result set; the return code is checked by stcheckbool().
  function stmt.free_result(stmt)
    local rc = C.mysql_stmt_free_result(stmt)
    stcheckbool(stmt, rc)
  end
  --- Number of rows in the statement's result set, as a Lua number.
  function stmt.row_count(stmt)
    local n = C.mysql_stmt_num_rows(stmt)
    return tonumber(n)
  end

  --- Rows affected by the last execution, as a Lua number.
  -- Raises when the count is (uint64_t)-1 and an error message is pending.
  function stmt.affected_rows(stmt)
    local count = C.mysql_stmt_affected_rows(stmt)
    if count == minus1_uint64 then
      sterror(stmt)
    end
    return tonumber(count)
  end

  --- Id generated by the last insert (thin pass-through).
  function stmt.insert_id(...)
    return C.mysql_stmt_insert_id(...)
  end

  --- Number of result columns, as a Lua number.
  function stmt.field_count(stmt)
    local n = C.mysql_stmt_field_count(stmt)
    return tonumber(n)
  end

  --- Number of parameter markers, as a Lua number.
  function stmt.param_count(stmt)
    local n = C.mysql_stmt_param_count(stmt)
    return tonumber(n)
  end

  --- Error code of the last statement call, or nothing when it is zero.
  function stmt.errno(stmt)
    local code = C.mysql_stmt_errno(stmt)
    if code ~= 0 then
      return code
    end
  end

  --- SQLSTATE of the last statement call, as a string.
  function stmt.sqlstate(stmt)
    local state = C.mysql_stmt_sqlstate(stmt)
    return cstring(state)
  end
  --- Result metadata of the statement as a result object with a finalizer,
  -- or nil/nothing when the call returns NULL without an error.
  function stmt.result_metadata(stmt)
    local meta = stcheckh(stmt, C.mysql_stmt_result_metadata(stmt))
    if not meta then return meta end
    return ffi.gc(meta, C.mysql_free_result)
  end

  --- Iterator over (index, field_info) of the statement's result fields.
  -- Returns nil when there is no metadata. The metadata is freed once the
  -- iterator has run to its end.
  function stmt.fields(stmt)
    local meta = stmt:result_metadata()
    if not meta then return nil end
    local next_field = meta:fields()
    return function()
      local index, info = next_field()
      if not index then
        meta:free() --exhausted: release the metadata right away
      end
      return index, info
    end
  end
  --- Fetch the next row into the bound buffers.
  -- Returns true on success, false on MYSQL_NO_DATA and (true, 'truncated')
  -- on MYSQL_DATA_TRUNCATED; any other code goes through sterror().
  function stmt.fetch(stmt)
    local rc = C.mysql_stmt_fetch(stmt)
    if rc == 0 then
      return true
    elseif rc == C.MYSQL_NO_DATA then
      return false
    elseif rc == C.MYSQL_DATA_TRUNCATED then
      return true, 'truncated'
    end
    sterror(stmt)
  end
  --- Reset the statement; raises on failure.
  function stmt.reset(stmt)
    local rc = C.mysql_stmt_reset(stmt)
    stcheckz(stmt, rc)
  end

  --- Current row cursor position (thin pass-through); the value can be given to stmt.seek().
  function stmt.tell(...)
    return C.mysql_stmt_row_tell(...)
  end

  --- Move the row cursor.
  -- A number is a 1-based row index; anything else is passed to
  -- mysql_stmt_row_seek() as a cursor offset.
  function stmt.seek(stmt, where) --use in conjunction with stmt:row_count()
    if type(where) ~= 'number' then
      C.mysql_stmt_row_seek(stmt, where)
      return
    end
    C.mysql_stmt_data_seek(stmt, where-1)
  end

  --- Send a chunk of data for parameter `param_number`; `size` defaults to #data.
  function stmt.write(stmt, param_number, data, size)
    local len = size or #data
    stcheckz(stmt, C.mysql_stmt_send_long_data(stmt, param_number, data, len))
  end
  --- Read the STMT_ATTR_UPDATE_MAX_LENGTH attribute as a boolean.
  function stmt.update_max_length(stmt)
    local box = ffi.new'my_bool[1]'
    local rc = C.mysql_stmt_attr_get(stmt, C.STMT_ATTR_UPDATE_MAX_LENGTH, box)
    stcheckz(stmt, rc)
    return box[0] == 1
  end
  --- Set the STMT_ATTR_UPDATE_MAX_LENGTH attribute; `yes` defaults to true.
  -- (This setter used to write STMT_ATTR_CURSOR_TYPE, changing the cursor
  -- type instead of the attribute that stmt.update_max_length() reads.)
  function stmt.set_update_max_length(stmt, yes)
    local attr = ffi.new('my_bool[1]', yes == nil or yes)
    stcheckz(stmt, C.mysql_stmt_attr_set(stmt, C.STMT_ATTR_UPDATE_MAX_LENGTH, attr))
  end
  --- Read the STMT_ATTR_CURSOR_TYPE attribute.
  function stmt.cursor_type(stmt)
    local box = ffi.new'uint32_t[1]'
    local rc = C.mysql_stmt_attr_get(stmt, C.STMT_ATTR_CURSOR_TYPE, box)
    stcheckz(stmt, rc)
    return box[0]
  end

  --- Set the cursor type from a name (looked up with the 'MYSQL_' prefix) or an enum value.
  function stmt.set_cursor_type(stmt, cursor_type)
    local value = enum(cursor_type, 'MYSQL_')
    local box = ffi.new('uint32_t[1]', value)
    stcheckz(stmt, C.mysql_stmt_attr_set(stmt, C.STMT_ATTR_CURSOR_TYPE, box))
  end

  --- Read the STMT_ATTR_PREFETCH_ROWS attribute.
  function stmt.prefetch_rows(stmt)
    local box = ffi.new'uint32_t[1]'
    local rc = C.mysql_stmt_attr_get(stmt, C.STMT_ATTR_PREFETCH_ROWS, box)
    stcheckz(stmt, rc)
    return box[0]
  end

  --- Set the STMT_ATTR_PREFETCH_ROWS attribute to `n`.
  function stmt.set_prefetch_rows(stmt, n)
    local box = ffi.new('uint32_t[1]', n)
    stcheckz(stmt, C.mysql_stmt_attr_set(stmt, C.STMT_ATTR_PREFETCH_ROWS, box))
  end
  --prepared statements / bind buffers

  --see http://dev.mysql.com/doc/refman/5.7/en/c-api-prepared-statement-type-codes.html

  --SQL type name -> MYSQL_TYPE_* code used for input (parameter) buffers
  local bb_types_input = {
    --conversion-free types
    tinyint    = ffi.C.MYSQL_TYPE_TINY,
    smallint   = ffi.C.MYSQL_TYPE_SHORT,
    int        = ffi.C.MYSQL_TYPE_LONG,
    integer    = ffi.C.MYSQL_TYPE_LONG, --alias of int
    bigint     = ffi.C.MYSQL_TYPE_LONGLONG,
    float      = ffi.C.MYSQL_TYPE_FLOAT,
    double     = ffi.C.MYSQL_TYPE_DOUBLE,
    time       = ffi.C.MYSQL_TYPE_TIME,
    date       = ffi.C.MYSQL_TYPE_DATE,
    datetime   = ffi.C.MYSQL_TYPE_DATETIME,
    timestamp  = ffi.C.MYSQL_TYPE_TIMESTAMP,
    text       = ffi.C.MYSQL_TYPE_STRING,
    char       = ffi.C.MYSQL_TYPE_STRING,
    varchar    = ffi.C.MYSQL_TYPE_STRING,
    blob       = ffi.C.MYSQL_TYPE_BLOB,
    binary     = ffi.C.MYSQL_TYPE_BLOB,
    varbinary  = ffi.C.MYSQL_TYPE_BLOB,
    null       = ffi.C.MYSQL_TYPE_NULL,
    --conversion types (can only use one of the above C types)
    mediumint  = ffi.C.MYSQL_TYPE_LONG,
    real       = ffi.C.MYSQL_TYPE_DOUBLE,
    decimal    = ffi.C.MYSQL_TYPE_BLOB,
    numeric    = ffi.C.MYSQL_TYPE_BLOB,
    year       = ffi.C.MYSQL_TYPE_SHORT,
    tinyblob   = ffi.C.MYSQL_TYPE_BLOB,
    tinytext   = ffi.C.MYSQL_TYPE_BLOB,
    mediumblob = ffi.C.MYSQL_TYPE_BLOB,
    mediumtext = ffi.C.MYSQL_TYPE_BLOB,
    longblob   = ffi.C.MYSQL_TYPE_BLOB,
    longtext   = ffi.C.MYSQL_TYPE_BLOB,
    bit        = ffi.C.MYSQL_TYPE_LONGLONG, --MYSQL_TYPE_BIT is not available for input params
    set        = ffi.C.MYSQL_TYPE_BLOB,
    enum       = ffi.C.MYSQL_TYPE_BLOB,
  }
  --SQL type name -> MYSQL_TYPE_* code used for output (result field) buffers
  local bb_types_output = {
    --conversion-free types
    tinyint    = ffi.C.MYSQL_TYPE_TINY,
    smallint   = ffi.C.MYSQL_TYPE_SHORT,
    mediumint  = ffi.C.MYSQL_TYPE_INT24, --int32
    int        = ffi.C.MYSQL_TYPE_LONG,
    integer    = ffi.C.MYSQL_TYPE_LONG, --alias of int
    bigint     = ffi.C.MYSQL_TYPE_LONGLONG,
    float      = ffi.C.MYSQL_TYPE_FLOAT,
    double     = ffi.C.MYSQL_TYPE_DOUBLE,
    real       = ffi.C.MYSQL_TYPE_DOUBLE,
    decimal    = ffi.C.MYSQL_TYPE_NEWDECIMAL, --char[]
    numeric    = ffi.C.MYSQL_TYPE_NEWDECIMAL, --char[]
    year       = ffi.C.MYSQL_TYPE_SHORT,
    time       = ffi.C.MYSQL_TYPE_TIME,
    date       = ffi.C.MYSQL_TYPE_DATE,
    datetime   = ffi.C.MYSQL_TYPE_DATETIME,
    timestamp  = ffi.C.MYSQL_TYPE_TIMESTAMP,
    char       = ffi.C.MYSQL_TYPE_STRING,
    binary     = ffi.C.MYSQL_TYPE_STRING,
    varchar    = ffi.C.MYSQL_TYPE_VAR_STRING,
    varbinary  = ffi.C.MYSQL_TYPE_VAR_STRING,
    tinyblob   = ffi.C.MYSQL_TYPE_TINY_BLOB,
    tinytext   = ffi.C.MYSQL_TYPE_TINY_BLOB,
    blob       = ffi.C.MYSQL_TYPE_BLOB,
    text       = ffi.C.MYSQL_TYPE_BLOB,
    mediumblob = ffi.C.MYSQL_TYPE_MEDIUM_BLOB,
    mediumtext = ffi.C.MYSQL_TYPE_MEDIUM_BLOB,
    longblob   = ffi.C.MYSQL_TYPE_LONG_BLOB,
    longtext   = ffi.C.MYSQL_TYPE_LONG_BLOB,
    bit        = ffi.C.MYSQL_TYPE_BIT,
    --conversion types (can only use one of the above C types)
    null       = ffi.C.MYSQL_TYPE_TINY,
    set        = ffi.C.MYSQL_TYPE_BLOB,
    enum       = ffi.C.MYSQL_TYPE_BLOB,
  }
  --Numeric buffer types -> ctype of the one-element array that holds the
  --(signed or floating point) value.
  local number_types = {}
  number_types[ffi.C.MYSQL_TYPE_TINY]     = 'int8_t[1]'
  number_types[ffi.C.MYSQL_TYPE_SHORT]    = 'int16_t[1]'
  number_types[ffi.C.MYSQL_TYPE_LONG]     = 'int32_t[1]'
  number_types[ffi.C.MYSQL_TYPE_INT24]    = 'int32_t[1]'
  number_types[ffi.C.MYSQL_TYPE_LONGLONG] = 'int64_t[1]'
  number_types[ffi.C.MYSQL_TYPE_FLOAT]    = 'float[1]'
  number_types[ffi.C.MYSQL_TYPE_DOUBLE]   = 'double[1]'
  --Integer buffer types -> ctype of the one-element array used when the
  --column is declared unsigned. Floating point types have no entry here.
  local uint_types
  do
    local T = ffi.C
    uint_types = {
      [T.MYSQL_TYPE_TINY]     = 'uint8_t[1]',
      [T.MYSQL_TYPE_SHORT]    = 'uint16_t[1]',
      [T.MYSQL_TYPE_LONG]     = 'uint32_t[1]',
      [T.MYSQL_TYPE_INT24]    = 'uint32_t[1]',
      [T.MYSQL_TYPE_LONGLONG] = 'uint64_t[1]',
    }
  end
  --Set of buffer types whose data buffer is a MYSQL_TIME struct.
  local time_types = {}
  for _, btype in ipairs{
    ffi.C.MYSQL_TYPE_TIME,
    ffi.C.MYSQL_TYPE_DATE,
    ffi.C.MYSQL_TYPE_DATETIME,
    ffi.C.MYSQL_TYPE_TIMESTAMP,
  } do
    time_types[btype] = true
  end
  --Date/time buffer types -> value stored in MYSQL_TIME.time_type for them.
  --Note that TIMESTAMP shares the DATETIME struct type.
  local time_struct_types
  do
    local T = ffi.C
    time_struct_types = {
      [T.MYSQL_TYPE_TIME]      = T.MYSQL_TIMESTAMP_TIME,
      [T.MYSQL_TYPE_DATE]      = T.MYSQL_TIMESTAMP_DATE,
      [T.MYSQL_TYPE_DATETIME]  = T.MYSQL_TIMESTAMP_DATETIME,
      [T.MYSQL_TYPE_TIMESTAMP] = T.MYSQL_TIMESTAMP_DATETIME,
    }
  end
  --Method tables for the two kinds of bind buffers, and the metatables
  --through which bind buffer objects reach them.
  local params, fields = {}, {} --params / fields bind buffer methods
  local params_meta = {__index = params}
  local fields_meta = {__index = fields}
  --Parse a SQL type declaration into its base name, size and signedness.
  --  "varchar(200)"  -> "varchar", 200, false
  --  "decimal(10,4)" -> "decimal", 12, false   (first number + 2)
  --  "int unsigned"  -> "int", nil, true
  --Matching is case-insensitive and tolerates blanks around the declaration,
  --before the opening paren and before the `unsigned` suffix. Only the first
  --number inside the parens is used as the size.
  local function parse_type(s)
    --trim first, so stray blanks never end up inside the returned type name
    s = s:lower():match'^%s*(.-)%s*$'
    local unsigned = false
    local base = s:match'^(.-)%s+unsigned$'
    if base then s, unsigned = base, true end
    local name, sz = s:match'^([^%(]+)%(%s*(%d+)[^%)]*%)$'
    if name then
      --[^%(]+ is greedy and also swallows blanks preceding the paren: drop them
      s = name:match'^(.-)%s*$'
      sz = assert(tonumber(sz), 'invalid type')
      if s == 'decimal' or s == 'numeric' then --make room for the dot and the minus sign
        sz = sz + 2
      end
    end
    return s, sz, unsigned
  end
  --Create a bind buffer object: a MYSQL_BIND array together with the data,
  --length, null-flag and error-flag storage that its entries point to.
  --  bb_types: type name -> MYSQL_TYPE_* map to resolve type names with
  --  meta:     metatable to give the resulting object
  --  types:    array of type definitions, as understood by parse_type()
  --Every field starts out flagged as NULL with a length of 0.
  local function bind_buffer(bb_types, meta, types)
    local n = #types
    local bb = setmetatable({
      count = n,
      buffer = ffi.new('MYSQL_BIND[?]', n),
      data = {}, --data buffers, one for each field
      lengths = ffi.new('unsigned long[?]', n), --length buffers, one for each field
      null_flags = ffi.new('my_bool[?]', n), --null flag buffers, one for each field
      error_flags = ffi.new('my_bool[?]', n), --error (truncation) flag buffers, one for each field
    }, meta)

    for i, typedef in ipairs(types) do
      local slot = i - 1 --the C arrays are 0-based
      local stype, size, unsigned = parse_type(typedef)
      local btype = assert(bb_types[stype], 'invalid type')
      local data

      if stype == 'bit' then
        if btype == C.MYSQL_TYPE_LONGLONG then
          --for input: use unsigned int64 and ignore size
          data = ffi.new'uint64_t[1]'
          bb.buffer[slot].is_unsigned = 1
          size = 0
        elseif btype == C.MYSQL_TYPE_BIT then
          --for output: raw bytes; a missing size means the maximum of 64 bits
          local nbytes = math.ceil((size or 64) / 8)
          assert(nbytes >= 1 and nbytes <= 8, 'invalid size')
          data = ffi.new('uint8_t[?]', nbytes)
          size = nbytes
        end
      elseif number_types[btype] then
        assert(not size, 'fixed size type')
        local ctype = unsigned and uint_types[btype] or number_types[btype]
        data = ffi.new(ctype)
        bb.buffer[slot].is_unsigned = unsigned
        size = ffi.sizeof(data)
      elseif time_types[btype] then
        assert(not size, 'fixed size type')
        data = ffi.new'MYSQL_TIME'
        data.time_type = time_struct_types[btype]
        size = 0
      elseif btype == C.MYSQL_TYPE_NULL then
        assert(not size, 'fixed size type')
        size = 0
      else
        --everything else is a byte buffer of the declared size
        assert(size, 'missing size')
        if size > 0 then
          data = ffi.new('uint8_t[?]', size)
        end
      end

      --keep the buffer referenced from Lua and wire up the MYSQL_BIND entry
      bb.data[i] = data
      bb.null_flags[slot] = true
      bb.lengths[slot] = 0
      bb.buffer[slot].buffer_type = btype
      bb.buffer[slot].buffer = data
      bb.buffer[slot].buffer_length = size
      bb.buffer[slot].is_null = bb.null_flags + slot
      bb.buffer[slot].error = bb.error_flags + slot
      bb.buffer[slot].length = bb.lengths + slot
    end

    return bb
  end
  --Create a bind buffer for statement parameters: type names are resolved
  --through bb_types_input and the object gets the params methods.
  local function params_bind_buffer(types)
    local bb = bind_buffer(bb_types_input, params_meta, types)
    return bb
  end
  --Create a bind buffer for result fields: type names are resolved
  --through bb_types_output and the object gets the fields methods.
  local function fields_bind_buffer(types)
    local bb = bind_buffer(bb_types_output, fields_meta, types)
    return bb
  end
  --Raise 'index out of bounds' unless i is within 1..self.count.
  local function bind_check_range(self, i)
    local in_bounds = 1 <= i and i <= self.count
    assert(in_bounds, 'index out of bounds')
  end
  --Replace the data buffer of field i with a fresh one of `size` bytes
  --(no buffer at all when size is 0). Only for variable-sized fields: numeric,
  --date/time and null fields raise an error. The field is reset to NULL with
  --length 0; the contents of the previous buffer are not carried over.
  --NOTE(review): presumably the buffer must be bound to the statement again
  --for the new storage to take effect -- confirm against the caller.
  function params:realloc(i, size)
    bind_check_range(self, i)
    --decide from the field's buffer type: a var-sized field may legitimately
    --have no data buffer at all (size 0), so the buffer itself can't tell
    local btype = tonumber(self.buffer[i-1].buffer_type)
    local fixed = number_types[btype] or time_types[btype] or btype == C.MYSQL_TYPE_NULL
    assert(not fixed, 'attempt to realloc a fixed size field')
    local data = size > 0 and ffi.new('uint8_t[?]', size) or nil
    self.null_flags[i-1] = true
    self.data[i] = data --keep the new buffer referenced from Lua
    self.lengths[i-1] = 0
    self.buffer[i-1].buffer = data
    self.buffer[i-1].buffer_length = size
  end
  fields.realloc = params.realloc
  --Read a date/time field as separate components.
  --Returns year, month, day, hour, min, sec, frac; the components that the
  --field type does not carry (the date part of a TIME, the time part of a
  --DATE) are nil. Returns a single nil if the field is NULL.
  --Raises an error if the field is not a date/time type.
  function fields:get_date(i)
    bind_check_range(self, i)
    local btype = tonumber(self.buffer[i-1].buffer_type)
    local both = btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
    local has_date = both or btype == C.MYSQL_TYPE_DATE
    local has_time = both or btype == C.MYSQL_TYPE_TIME
    assert(has_date or has_time, 'not a date/time type')
    if self.null_flags[i-1] == 1 then
      return nil
    end
    local tm = self.data[i]
    local year, month, day, hour, min, sec, frac
    if has_date then
      year, month, day = tm.year, tm.month, tm.day
    end
    if has_time then
      hour, min, sec = tm.hour, tm.minute, tm.second
      frac = tonumber(tm.second_part)
    end
    return year, month, day, hour, min, sec, frac
  end
  --Set a date/time param from separate components and clear its NULL flag.
  --Missing components count as 0 and every component is clamped to its valid
  --range; components that the param type does not carry (the date part of a
  --TIME, the time part of a DATE) are stored as 0.
  --Raises an error if the param is not a date/time type.
  function params:set_date(i, year, month, day, hour, min, sec, frac)
    bind_check_range(self, i)
    local btype = tonumber(self.buffer[i-1].buffer_type)
    local date = btype == C.MYSQL_TYPE_DATE or btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
    local time = btype == C.MYSQL_TYPE_TIME or btype == C.MYSQL_TYPE_DATETIME or btype == C.MYSQL_TYPE_TIMESTAMP
    assert(date or time, 'not a date/time type')
    local function clamp(v, lo, hi)
      return math.max(lo, math.min(v or 0, hi))
    end
    --the hour of a DATETIME/TIMESTAMP is a time of day and can't exceed 23.
    --NOTE(review): a pure TIME is an elapsed time that may exceed one day;
    --its previous cap of 59 is kept as is -- confirm the intended limit.
    local max_hour = date and 23 or 59
    local tm = self.data[i]
    tm.year = date and clamp(year, 0, 9999) or 0
    tm.month = date and clamp(month, 1, 12) or 0
    tm.day = date and clamp(day, 1, 31) or 0
    tm.hour = time and clamp(hour, 0, max_hour) or 0
    tm.minute = time and clamp(min, 0, 59) or 0
    tm.second = time and clamp(sec, 0, 59) or 0
    tm.second_part = time and clamp(frac, 0, 999999) or 0
    self.null_flags[i-1] = false
  end
  --Set the value of param i.
  --  nil (or a NULL pointer) flags the param as NULL;
  --  numeric params store v directly into their one-element buffer;
  --  date/time params take a table with the fields year, month, day, hour,
  --  min, sec, frac;
  --  any other param copies `size` bytes (default: #v) from v into its buffer
  --  and raises 'string too long' if the buffer is smaller than that.
  --Setting a param of the null type raises an error.
  function params:set(i, v, size)
    bind_check_range(self, i)
    local slot = i - 1
    v = ptr(v)
    if v == nil then
      self.null_flags[slot] = true
      return
    end
    local btype = tonumber(self.buffer[slot].buffer_type)
    if btype == C.MYSQL_TYPE_NULL then
      error('attempt to set a null type param')
    end
    if number_types[btype] then --this includes bit type which is LONGLONG
      self.data[i][0] = v
      self.null_flags[slot] = false
      return
    end
    if time_types[btype] then
      self:set_date(i, v.year, v.month, v.day, v.hour, v.min, v.sec, v.frac)
      return
    end
    --var-sized types and raw bit blobs
    local len = size or #v
    local capacity = tonumber(self.buffer[slot].buffer_length)
    assert(capacity >= len, 'string too long')
    ffi.copy(self.data[i], v, len)
    self.lengths[slot] = len
    self.null_flags[slot] = false
  end
  --Get the value of field i.
  --  NULL fields and fields of the null type give nil;
  --  numeric fields give the buffer element as ffi converts it (a Lua number
  --  or a 64-bit integer cdata);
  --  date/time fields give the result of datetime{...}, filled with only the
  --  components that the field's time_type carries;
  --  bit fields give parse_bit() of the raw bytes;
  --  any other field gives a Lua string of the received bytes, cut off at
  --  the buffer size.
  function fields:get(i)
    bind_check_range(self, i)
    local slot = i - 1
    local btype = tonumber(self.buffer[slot].buffer_type)
    if btype == C.MYSQL_TYPE_NULL or self.null_flags[slot] == 1 then
      return nil
    end

    if number_types[btype] then
      return self.data[i][0]
    end

    if time_types[btype] then
      local tm = self.data[i]
      local kind = tm.time_type
      if kind == C.MYSQL_TIMESTAMP_TIME then
        return datetime{hour = tm.hour, min = tm.minute, sec = tm.second,
          frac = tonumber(tm.second_part)}
      elseif kind == C.MYSQL_TIMESTAMP_DATE then
        return datetime{year = tm.year, month = tm.month, day = tm.day}
      elseif kind == C.MYSQL_TIMESTAMP_DATETIME then
        return datetime{year = tm.year, month = tm.month, day = tm.day,
          hour = tm.hour, min = tm.minute, sec = tm.second,
          frac = tonumber(tm.second_part)}
      end
      error'invalid time'
    end

    --var-sized data: never read more than what the buffer can hold
    local capacity = tonumber(self.buffer[slot].buffer_length)
    local received = tonumber(self.lengths[slot])
    local sz = math.min(capacity, received)
    if btype == C.MYSQL_TYPE_BIT then
      return parse_bit(self.data[i], sz)
    end
    return ffi.string(self.data[i], sz)
  end
  --Tell whether field i is NULL: true for fields of the null type and for
  --fields whose null flag is set.
  function fields:is_null(i)
    bind_check_range(self, i)
    if self.buffer[i-1].buffer_type == C.MYSQL_TYPE_NULL then
      return true
    end
    return self.null_flags[i-1] == 1
  end
  --Tell whether the value of field i was truncated, ie. whether the error
  --flag of the field is set.
  function fields:is_truncated(i)
    bind_check_range(self, i)
    local flag = self.error_flags[i-1]
    return flag == 1
  end
  --Set of type names that get an explicit size when turned into a bind
  --type definition.
  local varsize_types = {}
  for _, name in ipairs{
    'char', 'binary',
    'varchar', 'varbinary',
    'tinyblob', 'tinytext',
    'blob', 'text',
    'mediumblob', 'mediumtext',
    'longblob', 'longtext',
    'bit', 'set', 'enum',
  } do
    varsize_types[name] = true
  end
  --Build the list of bind type definitions for the result fields of a
  --statement, from the statement's result metadata.
  --  maxsize: upper limit for the size of var-sized fields (default 65535)
  --Returns an array of type definition strings, or nil if the statement
  --gives no result metadata.
  function stmt.bind_result_types(stmt, maxsize)
    local field_count = stmt:field_count()
    local res = stmt:result_metadata()
    if not res then return nil end
    local limit = maxsize or 65535
    local types = {}
    for i = 1, field_count do
      local ftype, size, unsigned, decimals = res:field_type(i)
      local typedef = ftype
      if ftype == 'decimal' then
        --parse_type() adds 2 to the size of decimals: compensate for that here
        typedef = string.format('%s(%d,%d)', ftype, size-2, decimals)
      elseif varsize_types[ftype] then
        typedef = string.format('%s(%d)', ftype, math.min(size, limit))
      end
      if unsigned then
        typedef = typedef..' unsigned'
      end
      types[i] = typedef
    end
    res:free()
    return types
  end
  --Create a params bind buffer and bind it to the statement.
  --The param types can be given either as separate string args or as a
  --single array of strings; they must match the statement's param count.
  --Returns the bind buffer object.
  function stmt.bind_params(stmt, ...)
    local types
    if type((...)) == 'string' then
      types = {...}
    else
      types = ... or {}
    end
    assert(stmt:param_count() == #types, 'wrong number of param types')
    local bb = params_bind_buffer(types)
    stcheckz(stmt, C.mysql_stmt_bind_param(stmt, bb.buffer))
    return bb
  end
  --Create a fields bind buffer and bind it to the statement.
  --The field types can be given as:
  --  separate string args, or a single array of strings: explicit types;
  --  a number: types are taken from the result metadata, with that number
  --  as the maximum size of var-sized fields;
  --  nothing: types are taken from the result metadata.
  --Raises an error if no result metadata can be had for the statement, or
  --if the number of types doesn't match the statement's field count.
  --Returns the bind buffer object.
  function stmt.bind_result(stmt, arg1, ...)
    local types
    if type(arg1) == 'string' then
      types = {arg1, ...}
    elseif not arg1 or type(arg1) == 'number' then
      --bind_result_types() gives nil when there's no result metadata: fail
      --with a clear message instead of a nil length error further down
      types = assert(stmt:bind_result_types(arg1 or nil),
        'no result metadata available for statement')
    else
      types = arg1
    end
    assert(stmt:field_count() == #types, 'wrong number of field types')
    local bb = fields_bind_buffer(types)
    stcheckz(stmt, C.mysql_stmt_bind_result(stmt, bb.buffer))
    return bb
  end
  --publish methods: make each method table the __index of its C struct type
  for _, def in ipairs{
    {'MYSQL', conn},
    {'MYSQL_RES', res},
    {'MYSQL_STMT', stmt},
  } do
    ffi.metatype(def[1], {__index = def[2]})
  end

  --publish classes (for introspection, not extending)
  M.conn, M.res, M.stmt = conn, res, stmt
  M.params, M.fields = params, fields

  return M