diff --git a/md5.lua b/md5.lua index 5f75fef..9a6381d 100644 --- a/md5.lua +++ b/md5.lua @@ -35,177 +35,181 @@ local char, byte, format, rep, sub = local bit_or, bit_and, bit_not, bit_xor, bit_rshift, bit_lshift local ok, bit = pcall(require, 'bit') -if not ok then ok, bit = pcall(require, 'bit32') end - if ok then - - bit_not = bit.bnot - - local tobit = function(n) - return n <= 0x7fffffff and n or -(bit_not(n) + 1) - end - - local normalize = function(f) - return function(a,b) return tobit(f(tobit(a), tobit(b))) end - end - - bit_or, bit_and, bit_xor = normalize(bit.bor), normalize(bit.band), normalize(bit.bxor) - bit_rshift, bit_lshift = normalize(bit.rshift), normalize(bit.lshift) - + bit_or, bit_and, bit_not, bit_xor, bit_rshift, bit_lshift = bit.bor, bit.band, bit.bnot, bit.xor, bit.rshift, bit.lshift else + ok, bit = pcall(require, 'bit32') - local function tbl2number(tbl) - local result = 0 - local power = 1 - for i = 1, #tbl do - result = result + tbl[i] * power - power = power * 2 + if ok then + + bit_not = bit.bnot + + local tobit = function(n) + return n <= 0x7fffffff and n or -(bit_not(n) + 1) end - return result - end - local function expand(t1, t2) - local big, small = t1, t2 - if(#big < #small) then - big, small = small, big + local normalize = function(f) + return function(a,b) return tobit(f(tobit(a), tobit(b))) end end - -- expand small - for i = #small + 1, #big do - small[i] = 0 - end - end - local to_bits -- needs to be declared before bit_not + bit_or, bit_and, bit_xor = normalize(bit.bor), normalize(bit.band), normalize(bit.bxor) + bit_rshift, bit_lshift = normalize(bit.rshift), normalize(bit.lshift) - bit_not = function(n) - local tbl = to_bits(n) - local size = math.max(#tbl, 32) - for i = 1, size do - if(tbl[i] == 1) then - tbl[i] = 0 - else - tbl[i] = 1 + else + + local function tbl2number(tbl) + local result = 0 + local power = 1 + for i = 1, #tbl do + result = result + tbl[i] * power + power = power * 2 end - end - return tbl2number(tbl) - end - - -- defined as local above - to_bits = function (n) - if(n < 0) then - -- negative - return to_bits(bit_not(math.abs(n)) + 1) - end - -- to bits table - local tbl = {} - local cnt = 1 - local last - while n > 0 do - last = n % 2 - tbl[cnt] = last - n = (n-last)/2 - cnt = cnt + 1 + return result end - return tbl - end - - bit_or = function(m, n) - local tbl_m = to_bits(m) - local tbl_n = to_bits(n) - expand(tbl_m, tbl_n) - - local tbl = {} - for i = 1, #tbl_m do - if(tbl_m[i]== 0 and tbl_n[i] == 0) then - tbl[i] = 0 - else - tbl[i] = 1 + local function expand(t1, t2) + local big, small = t1, t2 + if(#big < #small) then + big, small = small, big + end + -- expand small + for i = #small + 1, #big do + small[i] = 0 end end - return tbl2number(tbl) - end + local to_bits -- needs to be declared before bit_not - bit_and = function(m, n) - local tbl_m = to_bits(m) - local tbl_n = to_bits(n) - expand(tbl_m, tbl_n) - - local tbl = {} - for i = 1, #tbl_m do - if(tbl_m[i]== 0 or tbl_n[i] == 0) then - tbl[i] = 0 - else - tbl[i] = 1 + bit_not = function(n) + local tbl = to_bits(n) + local size = math.max(#tbl, 32) + for i = 1, size do + if(tbl[i] == 1) then + tbl[i] = 0 + else + tbl[i] = 1 + end end + return tbl2number(tbl) end - return tbl2number(tbl) - end - - bit_xor = function(m, n) - local tbl_m = to_bits(m) - local tbl_n = to_bits(n) - expand(tbl_m, tbl_n) - - local tbl = {} - for i = 1, #tbl_m do - if(tbl_m[i] ~= tbl_n[i]) then - tbl[i] = 1 - else - tbl[i] = 0 + -- defined as local above + to_bits = function (n) + if(n < 0) then + -- negative + return to_bits(bit_not(math.abs(n)) + 1) end + -- to bits table + local tbl = {} + local cnt = 1 + local last + while n > 0 do + last = n % 2 + tbl[cnt] = last + n = (n-last)/2 + cnt = cnt + 1 + end + + return tbl end - return tbl2number(tbl) + bit_or = function(m, n) + local tbl_m = to_bits(m) + local tbl_n = to_bits(n) + expand(tbl_m, tbl_n) + + local tbl = {} + for i = 1, #tbl_m do + if(tbl_m[i]== 0 and tbl_n[i] == 0) then + tbl[i] = 0 + else + tbl[i] = 1 + end + end + + return tbl2number(tbl) + end + + bit_and = function(m, n) + local tbl_m = to_bits(m) + local tbl_n = to_bits(n) + expand(tbl_m, tbl_n) + + local tbl = {} + for i = 1, #tbl_m do + if(tbl_m[i]== 0 or tbl_n[i] == 0) then + tbl[i] = 0 + else + tbl[i] = 1 + end + end + + return tbl2number(tbl) + end + + bit_xor = function(m, n) + local tbl_m = to_bits(m) + local tbl_n = to_bits(n) + expand(tbl_m, tbl_n) + + local tbl = {} + for i = 1, #tbl_m do + if(tbl_m[i] ~= tbl_n[i]) then + tbl[i] = 1 + else + tbl[i] = 0 + end + end + + return tbl2number(tbl) + end + + bit_rshift = function(n, bits) + local high_bit = 0 + if(n < 0) then + -- negative + n = bit_not(math.abs(n)) + 1 + high_bit = 0x80000000 + end + + local floor = math.floor + + for i=1, bits do + n = n/2 + n = bit_or(floor(n), high_bit) + end + return floor(n) + end + + bit_lshift = function(n, bits) + if(n < 0) then + -- negative + n = bit_not(math.abs(n)) + 1 + end + + for i=1, bits do + n = n*2 + end + return bit_and(n, 0xFFFFFFFF) + end end - bit_rshift = function(n, bits) - local high_bit = 0 - if(n < 0) then - -- negative - n = bit_not(math.abs(n)) + 1 - high_bit = 0x80000000 - end - - local floor = math.floor - - for i=1, bits do - n = n/2 - n = bit_or(floor(n), high_bit) - end - return floor(n) + -- convert little-endian 32-bit int to a 4-char string + local function lei2str(i) + local f=function (s) return char( bit_and( bit_rshift(i, s), 255)) end + return f(0)..f(8)..f(16)..f(24) end - bit_lshift = function(n, bits) - if(n < 0) then - -- negative - n = bit_not(math.abs(n)) + 1 + -- convert raw string to big-endian int + local function str2bei(s) + local v=0 + for i=1, #s do + v = v * 256 + byte(s, i) end - - for i=1, bits do - n = n*2 - end - return bit_and(n, 0xFFFFFFFF) + return v end end --- convert little-endian 32-bit int to a 4-char string -local function lei2str(i) - local f=function (s) return char( bit_and( bit_rshift(i, s), 255)) end - return f(0)..f(8)..f(16)..f(24) -end - --- convert raw string to big-endian int -local function str2bei(s) - local v=0 - for i=1, #s do - v = v * 256 + byte(s, i) - end - return v -end - -- convert raw string to little-endian int local function str2lei(s) local v=0