diff --git a/md5.lua b/md5.lua index b99792b..f9a0e9a 100644 --- a/md5.lua +++ b/md5.lua @@ -75,24 +75,27 @@ else end end + local bits_not, bits_or, bits_and, bits_not, bits_xor local to_bits -- needs to be declared before bit_not - function bit_not(n) - local tbl = to_bits(n) - local size = math.max(#tbl, 32) - for i = 1, size do - if(tbl[i] == 1) then + bits_not = function(tbl) + for i=1, math.max(#tbl, 32) do + if tbl[i] == 1 then tbl[i] = 0 else tbl[i] = 1 end end - return tbl2number(tbl) + return tbl + end + + bit_not = function(n) + return tbl2number(bits_not(to_bits(n))) end -- defined as local above to_bits = function (n) - if(n < 0) then + if n < 0 then -- negative return to_bits(bit_not(math.abs(n)) + 1) end @@ -110,61 +113,66 @@ else return tbl end - function bit_or(m, n) - local tbl_m = to_bits(m) - local tbl_n = to_bits(n) + bits_or = function(tbl_m, tbl_n) expand(tbl_m, tbl_n) local tbl = {} for i = 1, #tbl_m do - if(tbl_m[i]== 0 and tbl_n[i] == 0) then + if(tbl_m[i] == 0 and tbl_n[i] == 0) then tbl[i] = 0 else tbl[i] = 1 end end - return tbl2number(tbl) + return tbl end - function bit_and(m, n) - local tbl_m = to_bits(m) - local tbl_n = to_bits(n) + bit_or = function(m, n) + return tbl2number(bits_or(tobits(m), tobits(n))) + end + + bits_and = function(tbl_m, tbl_n) expand(tbl_m, tbl_n) local tbl = {} for i = 1, #tbl_m do - if(tbl_m[i]== 0 or tbl_n[i] == 0) then + if tbl_m[i] == 0 or tbl_n[i] == 0 then tbl[i] = 0 else tbl[i] = 1 end end - return tbl2number(tbl) + return tbl end - function bit_xor(m, n) - local tbl_m = to_bits(m) - local tbl_n = to_bits(n) + bit_and = function(m, n) + return tbl2number(bits_and(tobits(m), tobits(n))) + end + + bits_xor = function(tbl_m, tbl_n) expand(tbl_m, tbl_n) local tbl = {} for i = 1, #tbl_m do - if(tbl_m[i] ~= tbl_n[i]) then + if tbl_m[i] ~= tbl_n[i] then tbl[i] = 1 else tbl[i] = 0 end end - return tbl2number(tbl) + return tbl end - function bit_rshift(n, bits) + bit_xor = function(m, n) + return tbl2number(bits_xor(tobits(m), tobits(n))) + end + + bit_rshift = function(n, bits) local high_bit = 0 - if(n < 0) then - -- negative + if n < 0 then n = bit_not(math.abs(n)) + 1 high_bit = 2147483648 -- 0x80000000 end @@ -178,9 +186,8 @@ else return floor(n) end - function bit_lshift(n, bits) - if(n < 0) then - -- negative + bit_lshift = function(n, bits) + if n < 0 then n = bit_not(math.abs(n)) + 1 end