1
0
mirror of https://github.com/luapower/mysql.git synced 2025-07-01 16:00:33 +02:00

small fixes / documentation

This commit is contained in:
Cosmin Apreuetsei
2014-09-10 18:32:38 +03:00
parent 63beb9ab35
commit 439f6bc64f
3 changed files with 70 additions and 47 deletions

View File

@ -1,4 +1,10 @@
--mysql ffi binding (Cosmin Apreutesei, public domain). supports mysql Connector/C 6.1. based on mySQL 5.7 manual.
--mySQL client library ffi binding.
--Written by Cosmin Apreutesei. Public domain.
--Supports mysql Connector/C 6.1.
--Based on mySQL 5.7 manual.
local ffi = require'ffi'
local bit = require'bit'
require'mysql_h'
@ -22,24 +28,25 @@ end
--error reporting
local function myerror(mysql)
local function myerror(mysql, stacklevel)
local err = cstring(C.mysql_error(mysql))
if not err then return end
error(string.format('mysql error: %s', err))
error(string.format('mysql error: %s', err), stacklevel or 3)
end
local function checkz(mysql, ret)
if ret == 0 then return end
myerror(mysql)
myerror(mysql, 4)
end
local function checkh(mysql, ret)
if ret ~= NULL then return ret end
myerror(mysql)
myerror(mysql, 4)
end
local function enum(e, prefix)
return assert(type(e) == 'string' and (prefix and C[prefix..e] or C[e]) or e, 'invalid enum value')
local v = type(e) == 'string' and (prefix and C[prefix..e] or C[e]) or e
return assert(v, 'invalid enum value')
end
--client library info
@ -582,7 +589,7 @@ local function fetch_row(res, numeric, assoc, decode, field_count, fields, t)
local values = C.mysql_fetch_row(res)
if values == NULL then
if res.conn ~= NULL then --buffered read: check for errors
myerror(res.conn)
myerror(res.conn, 4)
end
return nil
end
@ -694,25 +701,25 @@ end
--prepared statements
local function sterror(stmt)
local function sterror(stmt, stacklevel)
local err = cstring(C.mysql_stmt_error(stmt))
if not err then return end
error(string.format('mysql error: %s', err))
error(string.format('mysql error: %s', err), stacklevel or 3)
end
local function stcheckz(stmt, ret)
if ret == 0 then return end
sterror(stmt)
sterror(stmt, 4)
end
local function stcheckbool(stmt, ret)
if ret == 1 then return end
sterror(stmt)
sterror(stmt, 4)
end
local function stcheckh(stmt, ret)
if ret ~= NULL then return ret end
sterror(stmt)
sterror(stmt, 4)
end
function conn.prepare(mysql, query)
@ -774,11 +781,12 @@ end
function stmt.result_metadata(stmt)
local res = stcheckh(stmt, C.mysql_stmt_result_metadata(stmt))
return ffi.gc(res, C.mysql_free_result)
return res and ffi.gc(res, C.mysql_free_result)
end
function stmt.fields(stmt)
local res = stmt:result_metadata()
if not res then return nil end
local fields = res:fields()
return function()
local i, info = fields()
@ -1183,6 +1191,7 @@ function stmt.bind_result_types(stmt, maxsize)
local types = {}
local field_count = stmt:field_count()
local res = stmt:result_metadata()
if not res then return nil end
for i=1,field_count do
local ftype, size, unsigned, decimals = res:field_type(i)
if ftype == 'decimal' then
@ -1199,7 +1208,7 @@ function stmt.bind_result_types(stmt, maxsize)
end
function stmt.bind_params(stmt, ...)
local types = type(...) == 'string' and {...} or ... or {}
local types = type((...)) == 'string' and {...} or ... or {}
assert(stmt:param_count() == #types, 'wrong number of param types')
local bb = params_bind_buffer(types)
stcheckz(stmt, C.mysql_stmt_bind_param(stmt, bb.buffer))