diff --git a/mysql_client.lua b/mysql_client.lua index 9a9fb1c..eb349ab 100644 --- a/mysql_client.lua +++ b/mysql_client.lua @@ -21,6 +21,8 @@ local shl = bit.lshift local shr = bit.rshift local tohex = bit.tohex local concat = table.concat +local floor = math.floor +local tonumber = tonumber local buffer = glue.buffer local dynarray = glue.dynarray @@ -599,9 +601,9 @@ local function get_time(buf, time_format) end local sign = get_u8(buf) == 1 and -1 or 1 local days = get_u4(buf) * sign - local H = get_u8(buf) - local M = get_u8(buf) - local S = get_u8(buf) + local H = get_u8(buf) + local M = get_u8(buf) + local S = get_u8(buf) local ms = len == 8 and 0 or get_u32(buf) return time_format == '*t' and {days = days, hour = H, min = M, sec = S + ms / 10^6} @@ -612,11 +614,59 @@ local function get_time(buf, time_format) end local function set_datetime(buf, t) - + local y, m, d, H, M, S + if type(t) == 'string' then + y, m, d, t = t:match'^(%d+)-(%d+)-(%d+)(.*)' + H, M, S = t:match' (%d+):(%d+):([%d.]+)' + y = tonumber(y) + m = tonumber(m) + d = tonumber(d) + H = tonumber(H) or 0 + M = tonumber(M) or 0 + S = tonumber(S) or 0 + else + y = t.year + m = t.month + d = t.day + H = t.hour + M = t.min + S = t.sec + end + local ms = (S - floor(S)) * 10^6 + set_u8 (buf, 11) + set_u16(buf, y) + set_u8 (buf, m) + set_u8 (buf, d) + set_u8 (buf, H) + set_u8 (buf, M) + set_u8 (buf, S) + set_u32(buf, ms) end local function set_time(buf, t) - + local days, H, M, S + if type(t) == 'string' then + local d, rest = t:match'^([+%-%d]+)d (.*)' + days = d and tonumber(d) or 0 + t = d and rest or t + H, M, S = t:match'^(%d+):(%d+):([%d.]+)$' + H = tonumber(H) + M = tonumber(M) + S = tonumber(S) + else + days = (t.days or 0) + floor(t.hour / 24) + H = t.hour % 24 + M = t.min + S = t.sec + end + local ms = (S - floor(S)) * 10^6 + set_u8 (buf, 12) + set_u8 (buf, days < 0 and 1 or 0) + set_u32(buf, math.abs(days)) + set_u8 (buf, H) + set_u8 (buf, M) + set_u8 (buf, S) + set_u32(buf, ms) end local function set_u8(buf, x) @@ -1006,7 +1056,7 @@ function conn:read_result(opt) if self.state == 'read_binary' then check(get_u8(buf) == 0, 'invalid row packet') - local nulls_len = math.floor((#cols + 7 + 2) / 8) + local nulls_len = floor((#cols + 7 + 2) / 8) local nulls, nulls_offset = buf(nulls_len) for i, col in ipairs(cols) do local null_byte = shr(i-1+2, 3) + nulls_offset @@ -1141,7 +1191,7 @@ function stmt:exec(...) set_u8(buf, stmt.cursor) set_u32(buf, 1) --iteration-count, must be 1 if #stmt.params > 0 then - local nulls_len = math.floor((#stmt.params + 7) / 8) + local nulls_len = floor((#stmt.params + 7) / 8) local nulls = ffi.new('uint8_t[?]', nulls_len) for i = 1, #stmt.params do local val = select(i, ...)