diff --git a/CLAUDE.md b/CLAUDE.md index 2281558..629d976 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -84,6 +84,22 @@ The `_compat` module provides automatic feature detection and optimized primitiv This ensures optimal performance on modern Lua while maintaining compatibility with older versions. +### Raw Operations (bit32 and bit64) + +The bit32 and bit64 modules provide `raw_*` variants for performance-critical code: +- `raw_band`, `raw_bor`, `raw_bxor`, `raw_bnot` +- `raw_lshift`, `raw_rshift`, `raw_arshift` +- `raw_rol`, `raw_ror` +- `raw_add` + +These bypass the `to_unsigned()` wrapper used on LuaJIT, returning signed +integers when the high bit is set. On other platforms they behave identically +to regular operations. Use for crypto code and tight loops where the sign +interpretation doesn't matter. + +Note: Shift amounts >= 32 (or >= 64 for bit64) have platform-specific behavior +in raw functions. Callers should keep shift amounts in valid range. + ## Testing Tests use Lua table-based vectors for easy maintenance: diff --git a/README.md b/README.md index c94ee6a..9b0f5d6 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,41 @@ local xored = bit64.bxor( Example: `0x123456789ABCDEF0` is represented as `{0x12345678, 0x9ABCDEF0}` +### Raw Operations (Performance-Critical Code) + +The `bit32` and `bit64` modules provide `raw_*` variants for performance-critical +code paths like cryptographic operations. These bypass the unsigned conversion +wrapper used on LuaJIT, providing direct access to native bit library functions. + +**Available functions (bit32 and bit64):** +- `raw_band`, `raw_bor`, `raw_bxor`, `raw_bnot` +- `raw_lshift`, `raw_rshift`, `raw_arshift` +- `raw_rol`, `raw_ror` +- `raw_add` + +**Important:** On LuaJIT, raw_* functions may return **signed** 32-bit integers: + +```lua +local bit32 = require("bitn").bit32 + +-- Regular function (always unsigned) +bit32.bxor(0x80000000, 1) --> 2147483649 + +-- Raw function (signed on LuaJIT) +bit32.raw_bxor(0x80000000, 1) --> -2147483647 (same bit pattern!) +``` + +**When to use raw_* functions:** +- Chained bitwise operations (XOR, AND, OR, rotate) where sign doesn't matter +- Crypto algorithms (ChaCha20, etc.) that only care about bit patterns +- Tight loops where the `to_unsigned()` overhead is measurable + +**When NOT to use raw_* functions:** +- When comparing results (`<`, `>`, `==`) +- When doing arithmetic on the results +- When formatting/displaying values +- When you need guaranteed unsigned semantics + ## Development ### Setup diff --git a/src/bitn/_compat.lua b/src/bitn/_compat.lua index 917b6df..9b2afe7 100644 --- a/src/bitn/_compat.lua +++ b/src/bitn/_compat.lua @@ -97,6 +97,33 @@ if ok and result then return native_band(r, MASK32) end + -- Raw operations provide direct access to native bit functions without the + -- to_unsigned() wrapper. On Lua 5.3+, these are identical to wrapped versions + -- since native operators already return unsigned values. + -- Shifts must mask to 32 bits since native operators work on 64-bit values. + _compat.raw_band = native_band + _compat.raw_bor = native_bor + _compat.raw_bxor = native_bxor + _compat.raw_bnot = function(a) + return native_band(native_bnot(a), MASK32) + end + _compat.raw_lshift = function(a, n) + if n >= 32 then + return 0 + end + return native_band(native_lshift(a, n), MASK32) + end + _compat.raw_rshift = function(a, n) + if n >= 32 then + return 0 + end + return native_rshift(native_band(a, MASK32), n) + end + _compat.raw_arshift = _compat.arshift + -- No native rol/ror on Lua 5.3+ + _compat.raw_rol = nil + _compat.raw_ror = nil + return _compat end end @@ -218,6 +245,25 @@ if bit_lib then end end + -- Raw operations provide direct access to native bit functions without the + -- to_unsigned() wrapper. On LuaJIT, these return signed 32-bit integers. + -- On Lua 5.2 (bit32 library), these are identical to wrapped versions. + _compat.raw_band = bit_band + _compat.raw_bor = bit_bor + _compat.raw_bxor = bit_bxor + _compat.raw_bnot = bit_bnot + _compat.raw_lshift = bit_lshift + _compat.raw_rshift = bit_rshift + _compat.raw_arshift = bit_arshift + -- rol/ror only available on LuaJIT (bit library), not Lua 5.2 (bit32 library) + if bit_lib.rol then + _compat.raw_rol = bit_lib.rol + _compat.raw_ror = bit_lib.ror + else + _compat.raw_rol = nil + _compat.raw_ror = nil + end + return _compat end @@ -317,4 +363,16 @@ function _compat.arshift(a, n) return r end +-- Raw operations for pure Lua fallback are identical to wrapped versions +-- since there's no native library to bypass. +_compat.raw_band = _compat.band +_compat.raw_bor = _compat.bor +_compat.raw_bxor = _compat.bxor +_compat.raw_bnot = _compat.bnot +_compat.raw_lshift = _compat.lshift +_compat.raw_rshift = _compat.rshift +_compat.raw_arshift = _compat.arshift +_compat.raw_rol = nil +_compat.raw_ror = nil + return _compat diff --git a/src/bitn/bit16.lua b/src/bitn/bit16.lua index d912a15..ebb7f47 100644 --- a/src/bitn/bit16.lua +++ b/src/bitn/bit16.lua @@ -10,18 +10,17 @@ local _compat = require("bitn._compat") -- Cache methods as locals for faster access local compat_band = _compat.band +local compat_bnot = _compat.bnot local compat_bor = _compat.bor local compat_bxor = _compat.bxor -local compat_bnot = _compat.bnot local compat_lshift = _compat.lshift local compat_rshift = _compat.rshift local impl_name = _compat.impl_name +local math_floor = math.floor -- 16-bit mask constant local MASK16 = 0xFFFF -local math_floor = math.floor - -------------------------------------------------------------------------------- -- Core operations -------------------------------------------------------------------------------- diff --git a/src/bitn/bit32.lua b/src/bitn/bit32.lua index 7e04187..ffa8469 100644 --- a/src/bitn/bit32.lua +++ b/src/bitn/bit32.lua @@ -9,24 +9,42 @@ local bit32 = {} local _compat = require("bitn._compat") -- Cache methods as locals for faster access +local compat_arshift = _compat.arshift local compat_band = _compat.band +local compat_bnot = _compat.bnot local compat_bor = _compat.bor local compat_bxor = _compat.bxor -local compat_bnot = _compat.bnot local compat_lshift = _compat.lshift +local compat_raw_arshift = _compat.raw_arshift +local compat_raw_band = _compat.raw_band +local compat_raw_bnot = _compat.raw_bnot +local compat_raw_bor = _compat.raw_bor +local compat_raw_bxor = _compat.raw_bxor +local compat_raw_lshift = _compat.raw_lshift +local compat_raw_rol = _compat.raw_rol +local compat_raw_ror = _compat.raw_ror +local compat_raw_rshift = _compat.raw_rshift local compat_rshift = _compat.rshift -local compat_arshift = _compat.arshift +local compat_to_unsigned = _compat.to_unsigned local impl_name = _compat.impl_name +local math_floor = math.floor -- 32-bit mask constant local MASK32 = 0xFFFFFFFF -local math_floor = math.floor - -------------------------------------------------------------------------------- -- Core operations -------------------------------------------------------------------------------- +--- Convert signed 32-bit value to unsigned. +--- On LuaJIT, bit operations return signed 32-bit integers. This function +--- converts them to unsigned by adding 2^32 to negative values. +--- @param n number Potentially signed 32-bit value +--- @return integer result Unsigned 32-bit value (0 to 0xFFFFFFFF) +function bit32.to_unsigned(n) + return compat_to_unsigned(n) +end + --- Ensure value fits in 32-bit unsigned integer. --- @param n number Input value --- @return integer result 32-bit unsigned integer (0 to 0xFFFFFFFF) @@ -126,6 +144,68 @@ function bit32.add(a, b) return compat_band(compat_band(a, MASK32) + compat_band(b, MASK32), MASK32) end +-------------------------------------------------------------------------------- +-- Raw (zero-overhead) operations +-------------------------------------------------------------------------------- +-- These functions provide direct access to the underlying bit library without +-- unsigned conversion. On LuaJIT, results may be negative when the high bit +-- is set. The bit pattern is identical to the regular function. +-- Use for performance-critical code where sign interpretation doesn't matter. + +--- Raw bitwise AND (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.band For guaranteed unsigned results +bit32.raw_band = compat_raw_band + +--- Raw bitwise OR (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.bor For guaranteed unsigned results +bit32.raw_bor = compat_raw_bor + +--- Raw bitwise XOR (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.bxor For guaranteed unsigned results +bit32.raw_bxor = compat_raw_bxor + +--- Raw bitwise NOT (may return signed on LuaJIT). +--- @type fun(a: integer): integer +--- @see bit32.bnot For guaranteed unsigned results +bit32.raw_bnot = compat_raw_bnot + +--- Raw left shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.lshift For guaranteed unsigned results +bit32.raw_lshift = compat_raw_lshift + +--- Raw logical right shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.rshift For guaranteed unsigned results +bit32.raw_rshift = compat_raw_rshift + +--- Raw arithmetic right shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.arshift For guaranteed unsigned results +bit32.raw_arshift = compat_raw_arshift + +--- Raw left rotate (uses native bit.rol on LuaJIT, falls back to computed otherwise). +--- @type fun(x: integer, n: integer): integer +--- @see bit32.rol For guaranteed unsigned results +bit32.raw_rol = compat_raw_rol or bit32.rol + +--- Raw right rotate (uses native bit.ror on LuaJIT, falls back to computed otherwise). +--- @type fun(x: integer, n: integer): integer +--- @see bit32.ror For guaranteed unsigned results +bit32.raw_ror = compat_raw_ror or bit32.ror + +--- Raw 32-bit addition with overflow handling. +--- @param a integer First operand (32-bit) +--- @param b integer Second operand (32-bit) +--- @return integer result Result of (a + b) mod 2^32 (signed on LuaJIT, unsigned elsewhere) +--- @see bit32.add For guaranteed unsigned results +function bit32.raw_add(a, b) + return compat_raw_band(a + b, MASK32) +end + -------------------------------------------------------------------------------- -- Byte conversion functions -------------------------------------------------------------------------------- @@ -138,7 +218,12 @@ local string_byte = string.byte --- @return string bytes 4-byte string in big-endian order function bit32.u32_to_be_bytes(n) n = compat_band(n, MASK32) - return string_char(math_floor(n / 16777216) % 256, math_floor(n / 65536) % 256, math_floor(n / 256) % 256, n % 256) + return string_char( + math_floor(n / 16777216) % 256, + math_floor(n / 65536) % 256, + math_floor(n / 256) % 256, + math_floor(n % 256) + ) end --- Convert 32-bit unsigned integer to 4 bytes (little-endian). @@ -146,7 +231,12 @@ end --- @return string bytes 4-byte string in little-endian order function bit32.u32_to_le_bytes(n) n = compat_band(n, MASK32) - return string_char(n % 256, math_floor(n / 256) % 256, math_floor(n / 65536) % 256, math_floor(n / 16777216) % 256) + return string_char( + math_floor(n % 256), + math_floor(n / 256) % 256, + math_floor(n / 65536) % 256, + math_floor(n / 16777216) % 256 + ) end --- Convert 4 bytes to 32-bit unsigned integer (big-endian). @@ -196,6 +286,14 @@ function bit32.selftest() { name = "mask(-1)", fn = bit32.mask, inputs = { -1 }, expected = 0xFFFFFFFF }, { name = "mask(-256)", fn = bit32.mask, inputs = { -256 }, expected = 0xFFFFFF00 }, + -- to_unsigned tests + { name = "to_unsigned(0)", fn = bit32.to_unsigned, inputs = { 0 }, expected = 0 }, + { name = "to_unsigned(1)", fn = bit32.to_unsigned, inputs = { 1 }, expected = 1 }, + { name = "to_unsigned(0x7FFFFFFF)", fn = bit32.to_unsigned, inputs = { 0x7FFFFFFF }, expected = 0x7FFFFFFF }, + { name = "to_unsigned(-1)", fn = bit32.to_unsigned, inputs = { -1 }, expected = 0xFFFFFFFF }, + { name = "to_unsigned(-2147483648)", fn = bit32.to_unsigned, inputs = { -2147483648 }, expected = 0x80000000 }, + { name = "to_unsigned(-2147483647)", fn = bit32.to_unsigned, inputs = { -2147483647 }, expected = 0x80000001 }, + -- band tests { name = "band(0xFF00FF00, 0x00FF00FF)", fn = bit32.band, inputs = { 0xFF00FF00, 0x00FF00FF }, expected = 0 }, { @@ -309,25 +407,25 @@ function bit32.selftest() name = "u32_to_be_bytes(0)", fn = bit32.u32_to_be_bytes, inputs = { 0 }, - expected = string.char(0x00, 0x00, 0x00, 0x00), + expected = string_char(0x00, 0x00, 0x00, 0x00), }, { name = "u32_to_be_bytes(1)", fn = bit32.u32_to_be_bytes, inputs = { 1 }, - expected = string.char(0x00, 0x00, 0x00, 0x01), + expected = string_char(0x00, 0x00, 0x00, 0x01), }, { name = "u32_to_be_bytes(0x12345678)", fn = bit32.u32_to_be_bytes, inputs = { 0x12345678 }, - expected = string.char(0x12, 0x34, 0x56, 0x78), + expected = string_char(0x12, 0x34, 0x56, 0x78), }, { name = "u32_to_be_bytes(0xFFFFFFFF)", fn = bit32.u32_to_be_bytes, inputs = { 0xFFFFFFFF }, - expected = string.char(0xFF, 0xFF, 0xFF, 0xFF), + expected = string_char(0xFF, 0xFF, 0xFF, 0xFF), }, -- u32_to_le_bytes tests @@ -335,50 +433,50 @@ function bit32.selftest() name = "u32_to_le_bytes(0)", fn = bit32.u32_to_le_bytes, inputs = { 0 }, - expected = string.char(0x00, 0x00, 0x00, 0x00), + expected = string_char(0x00, 0x00, 0x00, 0x00), }, { name = "u32_to_le_bytes(1)", fn = bit32.u32_to_le_bytes, inputs = { 1 }, - expected = string.char(0x01, 0x00, 0x00, 0x00), + expected = string_char(0x01, 0x00, 0x00, 0x00), }, { name = "u32_to_le_bytes(0x12345678)", fn = bit32.u32_to_le_bytes, inputs = { 0x12345678 }, - expected = string.char(0x78, 0x56, 0x34, 0x12), + expected = string_char(0x78, 0x56, 0x34, 0x12), }, { name = "u32_to_le_bytes(0xFFFFFFFF)", fn = bit32.u32_to_le_bytes, inputs = { 0xFFFFFFFF }, - expected = string.char(0xFF, 0xFF, 0xFF, 0xFF), + expected = string_char(0xFF, 0xFF, 0xFF, 0xFF), }, -- be_bytes_to_u32 tests { name = "be_bytes_to_u32(0x00000000)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x00) }, expected = 0, }, { name = "be_bytes_to_u32(0x00000001)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x01) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x01) }, expected = 1, }, { name = "be_bytes_to_u32(0x12345678)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x12, 0x34, 0x56, 0x78) }, + inputs = { string_char(0x12, 0x34, 0x56, 0x78) }, expected = 0x12345678, }, { name = "be_bytes_to_u32(0xFFFFFFFF)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0xFF, 0xFF, 0xFF, 0xFF) }, + inputs = { string_char(0xFF, 0xFF, 0xFF, 0xFF) }, expected = 0xFFFFFFFF, }, @@ -386,25 +484,25 @@ function bit32.selftest() { name = "le_bytes_to_u32(0x00000000)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x00) }, expected = 0, }, { name = "le_bytes_to_u32(0x00000001)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x01, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x01, 0x00, 0x00, 0x00) }, expected = 1, }, { name = "le_bytes_to_u32(0x12345678)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x78, 0x56, 0x34, 0x12) }, + inputs = { string_char(0x78, 0x56, 0x34, 0x12) }, expected = 0x12345678, }, { name = "le_bytes_to_u32(0xFFFFFFFF)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0xFF, 0xFF, 0xFF, 0xFF) }, + inputs = { string_char(0xFF, 0xFF, 0xFF, 0xFF) }, expected = 0xFFFFFFFF, }, } @@ -420,10 +518,10 @@ function bit32.selftest() if type(test.expected) == "string" then local exp_hex, got_hex = "", "" for i = 1, #test.expected do - exp_hex = exp_hex .. string.format("%02X", string.byte(test.expected, i)) + exp_hex = exp_hex .. string.format("%02X", string_byte(test.expected, i)) end for i = 1, #result do - got_hex = got_hex .. string.format("%02X", string.byte(result, i)) + got_hex = got_hex .. string.format("%02X", string_byte(result, i)) end print(" Expected: " .. exp_hex) print(" Got: " .. got_hex) @@ -434,6 +532,191 @@ function bit32.selftest() end end + -- Test raw_* operations + print("\n Testing raw_* operations...") + + local raw_tests = { + -- Core bitwise (test high-bit cases where sign matters) + { + name = "raw_band(0xFFFFFFFF, 0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_band(0xFFFFFFFF, 0x80000000)) + end, + expected = bit32.band(0xFFFFFFFF, 0x80000000), + }, + { + name = "raw_bor(0x80000000, 0x00000001)", + fn = function() + return bit32.to_unsigned(bit32.raw_bor(0x80000000, 0x00000001)) + end, + expected = bit32.bor(0x80000000, 0x00000001), + }, + { + name = "raw_bxor(0xAAAAAAAA, 0x55555555)", + fn = function() + return bit32.to_unsigned(bit32.raw_bxor(0xAAAAAAAA, 0x55555555)) + end, + expected = bit32.bxor(0xAAAAAAAA, 0x55555555), + }, + { + name = "raw_bnot(0)", + fn = function() + return bit32.to_unsigned(bit32.raw_bnot(0)) + end, + expected = bit32.bnot(0), + }, + { + name = "raw_bnot(0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_bnot(0x80000000)) + end, + expected = bit32.bnot(0x80000000), + }, + + -- Shifts + { + name = "raw_lshift(1, 31)", + fn = function() + return bit32.to_unsigned(bit32.raw_lshift(1, 31)) + end, + expected = bit32.lshift(1, 31), + }, + { + name = "raw_rshift(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_rshift(0x80000000, 1)) + end, + expected = bit32.rshift(0x80000000, 1), + }, + { + name = "raw_arshift(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_arshift(0x80000000, 1)) + end, + expected = bit32.arshift(0x80000000, 1), + }, + + -- Shift masking (ensure 32-bit semantics on all platforms) + -- Note: n >= 32 behavior is platform-specific for raw shifts; callers should use n in 0-31 + { + name = "raw_lshift(0x12345678, 16) masks to 32 bits", + fn = function() + return bit32.to_unsigned(bit32.raw_lshift(0x12345678, 16)) + end, + expected = 0x56780000, + }, + { + name = "raw_rshift(0xFFFFFFFF, 16) masks to 32 bits", + fn = function() + return bit32.to_unsigned(bit32.raw_rshift(0xFFFFFFFF, 16)) + end, + expected = 0x0000FFFF, + }, + + -- Addition overflow + { + name = "raw_add(0xFFFFFFFF, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_add(0xFFFFFFFF, 1)) + end, + expected = bit32.add(0xFFFFFFFF, 1), + }, + { + name = "raw_add(0x80000000, 0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_add(0x80000000, 0x80000000)) + end, + expected = bit32.add(0x80000000, 0x80000000), + }, + } + + for _, test in ipairs(raw_tests) do + total = total + 1 + local result = test.fn() + if result == test.expected then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name) + print(string.format(" Expected: 0x%08X", test.expected)) + print(string.format(" Got: 0x%08X", result)) + end + end + + -- Test raw_rol/raw_ror (always available - falls back to computed if no native) + print("\n Testing raw_rol/raw_ror...") + local rol_ror_tests = { + { + name = "raw_rol(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_rol(0x80000000, 1)) + end, + expected = bit32.rol(0x80000000, 1), + }, + { + name = "raw_rol(0x12345678, 8)", + fn = function() + return bit32.to_unsigned(bit32.raw_rol(0x12345678, 8)) + end, + expected = bit32.rol(0x12345678, 8), + }, + { + name = "raw_ror(1, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_ror(1, 1)) + end, + expected = bit32.ror(1, 1), + }, + { + name = "raw_ror(0x12345678, 8)", + fn = function() + return bit32.to_unsigned(bit32.raw_ror(0x12345678, 8)) + end, + expected = bit32.ror(0x12345678, 8), + }, + } + + for _, test in ipairs(rol_ror_tests) do + total = total + 1 + local result = test.fn() + if result == test.expected then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name) + print(string.format(" Expected: 0x%08X", test.expected)) + print(string.format(" Got: 0x%08X", result)) + end + end + + -- Test zero-overhead on LuaJIT (identity check) + if _compat.is_luajit then + print("\n Testing zero-overhead (LuaJIT function identity)...") + local bit = require("bit") + + local identity_tests = { + { name = "raw_band == bit.band", got = bit32.raw_band, expected = bit.band }, + { name = "raw_bor == bit.bor", got = bit32.raw_bor, expected = bit.bor }, + { name = "raw_bxor == bit.bxor", got = bit32.raw_bxor, expected = bit.bxor }, + { name = "raw_bnot == bit.bnot", got = bit32.raw_bnot, expected = bit.bnot }, + { name = "raw_lshift == bit.lshift", got = bit32.raw_lshift, expected = bit.lshift }, + { name = "raw_rshift == bit.rshift", got = bit32.raw_rshift, expected = bit.rshift }, + { name = "raw_arshift == bit.arshift", got = bit32.raw_arshift, expected = bit.arshift }, + { name = "raw_rol == bit.rol", got = bit32.raw_rol, expected = bit.rol }, + { name = "raw_ror == bit.ror", got = bit32.raw_ror, expected = bit.ror }, + } + + for _, test in ipairs(identity_tests) do + total = total + 1 + if rawequal(test.got, test.expected) then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name .. " (not identical function reference)") + end + end + end + print(string.format("\n32-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end diff --git a/src/bitn/bit64.lua b/src/bitn/bit64.lua index c7021ad..eb59e31 100644 --- a/src/bitn/bit64.lua +++ b/src/bitn/bit64.lua @@ -9,20 +9,27 @@ local bit64 = {} local bit32 = require("bitn.bit32") local _compat = require("bitn._compat") -local impl_name = _compat.impl_name --- Cache bit32 methods as locals for faster access +-- Cache methods as locals for faster access +local bit32_arshift = bit32.arshift local bit32_band = bit32.band +local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 +local bit32_bnot = bit32.bnot local bit32_bor = bit32.bor local bit32_bxor = bit32.bxor -local bit32_bnot = bit32.bnot +local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 local bit32_lshift = bit32.lshift +local bit32_raw_arshift = bit32.raw_arshift +local bit32_raw_band = bit32.raw_band +local bit32_raw_bnot = bit32.raw_bnot +local bit32_raw_bor = bit32.raw_bor +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_lshift = bit32.raw_lshift +local bit32_raw_rshift = bit32.raw_rshift local bit32_rshift = bit32.rshift -local bit32_arshift = bit32.arshift local bit32_u32_to_be_bytes = bit32.u32_to_be_bytes local bit32_u32_to_le_bytes = bit32.u32_to_le_bytes -local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 -local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 +local impl_name = _compat.impl_name -- Private metatable for Int64 type identification local Int64Meta = { __name = "Int64" } @@ -35,11 +42,21 @@ local Int64Meta = { __name = "Int64" } -------------------------------------------------------------------------------- --- Create a new Int64 value with metatable marker. +--- Normalizes signed 32-bit values to unsigned (for LuaJIT raw_* compatibility). --- @param high? integer Upper 32 bits (default: 0) --- @param low? integer Lower 32 bits (default: 0) --- @return Int64HighLow value Int64 value with metatable marker function bit64.new(high, low) - return setmetatable({ high or 0, low or 0 }, Int64Meta) + high = high or 0 + low = low or 0 + -- Normalize signed to unsigned (handles LuaJIT raw_* results) + if high < 0 then + high = high + 0x100000000 + end + if low < 0 then + low = low + 0x100000000 + end + return setmetatable({ high, low }, Int64Meta) end --- Check if a value is an Int64 (created by bit64 functions). @@ -373,6 +390,178 @@ bit64.asr = bit64.arshift --- Alias for is_int64 (compatibility with older API). bit64.isInt64 = bit64.is_int64 +-------------------------------------------------------------------------------- +-- Raw (zero-overhead) operations +-------------------------------------------------------------------------------- +-- These functions use bit32.raw_* internally for performance-critical code. +-- On LuaJIT, the internal 32-bit values may be signed, but bit patterns are correct. +-- Use for crypto code and tight loops where sign interpretation doesn't matter. + +--- Raw bitwise AND (uses bit32.raw_band internally). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} AND result +function bit64.raw_band(a, b) + return bit64.new(bit32_raw_band(a[1], b[1]), bit32_raw_band(a[2], b[2])) +end + +--- Raw bitwise OR (uses bit32.raw_bor internally). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} OR result +function bit64.raw_bor(a, b) + return bit64.new(bit32_raw_bor(a[1], b[1]), bit32_raw_bor(a[2], b[2])) +end + +--- Raw bitwise XOR (uses bit32.raw_bxor internally). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} XOR result +function bit64.raw_bxor(a, b) + return bit64.new(bit32_raw_bxor(a[1], b[1]), bit32_raw_bxor(a[2], b[2])) +end + +--- Raw bitwise NOT (uses bit32.raw_bnot internally). +--- @param a Int64HighLow Operand {high, low} +--- @return Int64HighLow result {high, low} NOT result +function bit64.raw_bnot(a) + return bit64.new(bit32_raw_bnot(a[1]), bit32_raw_bnot(a[2])) +end + +--- Raw left shift (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to shift {high, low} +--- @param n integer Number of positions to shift (must be >= 0) +--- @return Int64HighLow result {high, low} shifted value +function bit64.raw_lshift(x, n) + if n == 0 then + return bit64.new(x[1], x[2]) + elseif n >= 64 then + return bit64.new(0, 0) + elseif n >= 32 then + return bit64.new(bit32_raw_lshift(x[2], n - 32), 0) + else + local new_high = bit32_raw_bor(bit32_raw_lshift(x[1], n), bit32_raw_rshift(x[2], 32 - n)) + local new_low = bit32_raw_lshift(x[2], n) + return bit64.new(new_high, new_low) + end +end + +--- Raw logical right shift (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to shift {high, low} +--- @param n integer Number of positions to shift (must be >= 0) +--- @return Int64HighLow result {high, low} shifted value +function bit64.raw_rshift(x, n) + if n == 0 then + return bit64.new(x[1], x[2]) + elseif n >= 64 then + return bit64.new(0, 0) + elseif n >= 32 then + return bit64.new(0, bit32_raw_rshift(x[1], n - 32)) + else + local new_low = bit32_raw_bor(bit32_raw_rshift(x[2], n), bit32_raw_lshift(x[1], 32 - n)) + local new_high = bit32_raw_rshift(x[1], n) + return bit64.new(new_high, new_low) + end +end + +--- Raw arithmetic right shift (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to shift {high, low} +--- @param n integer Number of positions to shift (must be >= 0) +--- @return Int64HighLow result {high, low} shifted value +function bit64.raw_arshift(x, n) + if n == 0 then + return bit64.new(x[1], x[2]) + end + + local is_negative = bit32_raw_band(x[1], 0x80000000) ~= 0 + + if n >= 64 then + if is_negative then + return bit64.new(0xFFFFFFFF, 0xFFFFFFFF) + else + return bit64.new(0, 0) + end + elseif n >= 32 then + local new_low = bit32_raw_arshift(x[1], n - 32) + local new_high = is_negative and 0xFFFFFFFF or 0 + return bit64.new(new_high, new_low) + else + local new_low = bit32_raw_bor(bit32_raw_rshift(x[2], n), bit32_raw_lshift(x[1], 32 - n)) + local new_high = bit32_raw_arshift(x[1], n) + return bit64.new(new_high, new_low) + end +end + +--- Raw left rotate (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to rotate {high, low} +--- @param n integer Number of positions to rotate +--- @return Int64HighLow result {high, low} rotated value +function bit64.raw_rol(x, n) + n = n % 64 + if n == 0 then + return bit64.new(x[1], x[2]) + end + + local high, low = x[1], x[2] + + if n == 32 then + return bit64.new(low, high) + elseif n < 32 then + local new_high = bit32_raw_bor(bit32_raw_lshift(high, n), bit32_raw_rshift(low, 32 - n)) + local new_low = bit32_raw_bor(bit32_raw_lshift(low, n), bit32_raw_rshift(high, 32 - n)) + return bit64.new(new_high, new_low) + else + n = n - 32 + local new_high = bit32_raw_bor(bit32_raw_lshift(low, n), bit32_raw_rshift(high, 32 - n)) + local new_low = bit32_raw_bor(bit32_raw_lshift(high, n), bit32_raw_rshift(low, 32 - n)) + return bit64.new(new_high, new_low) + end +end + +--- Raw right rotate (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to rotate {high, low} +--- @param n integer Number of positions to rotate +--- @return Int64HighLow result {high, low} rotated value +function bit64.raw_ror(x, n) + n = n % 64 + if n == 0 then + return bit64.new(x[1], x[2]) + end + + local high, low = x[1], x[2] + + if n == 32 then + return bit64.new(low, high) + elseif n < 32 then + local new_low = bit32_raw_bor(bit32_raw_rshift(low, n), bit32_raw_lshift(high, 32 - n)) + local new_high = bit32_raw_bor(bit32_raw_rshift(high, n), bit32_raw_lshift(low, 32 - n)) + return bit64.new(new_high, new_low) + else + n = n - 32 + local new_low = bit32_raw_bor(bit32_raw_rshift(high, n), bit32_raw_lshift(low, 32 - n)) + local new_high = bit32_raw_bor(bit32_raw_rshift(low, n), bit32_raw_lshift(high, 32 - n)) + return bit64.new(new_high, new_low) + end +end + +--- Raw 64-bit addition (uses bit32.raw_band for masking). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} sum +function bit64.raw_add(a, b) + local low = a[2] + b[2] + local high = a[1] + b[1] + + if low >= 0x100000000 then + high = high + 1 + low = low % 0x100000000 + end + + high = high % 0x100000000 + + return bit64.new(high, low) +end + -------------------------------------------------------------------------------- -- Self-test -------------------------------------------------------------------------------- @@ -821,6 +1010,18 @@ function bit64.selftest() print(" FAIL: new() with no args creates {0, 0}") end + -- Test bit64.new() normalizes negative values (LuaJIT raw_* compatibility) + total = total + 1 + local neg_val = bit64.new(-1, -2147483648) -- -1 -> 0xFFFFFFFF, -2147483648 -> 0x80000000 + if bit64.is_int64(neg_val) and neg_val[1] == 0xFFFFFFFF and neg_val[2] == 0x80000000 then + print(" PASS: new() normalizes negative values to unsigned") + passed = passed + 1 + else + print(" FAIL: new() normalizes negative values to unsigned") + print(string.format(" Expected: {0x%08X, 0x%08X}", 0xFFFFFFFF, 0x80000000)) + print(string.format(" Got: {0x%08X, 0x%08X}", neg_val[1], neg_val[2])) + end + -- Test is_int64() returns false for regular tables total = total + 1 local plain_table = { 0x12345678, 0x9ABCDEF0 } @@ -845,61 +1046,61 @@ function bit64.selftest() { name = "band", fn = function() - return bit64.band({ 1, 2 }, { 3, 4 }) + return bit64.band(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bor", fn = function() - return bit64.bor({ 1, 2 }, { 3, 4 }) + return bit64.bor(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bxor", fn = function() - return bit64.bxor({ 1, 2 }, { 3, 4 }) + return bit64.bxor(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bnot", fn = function() - return bit64.bnot({ 1, 2 }) + return bit64.bnot(bit64.new(1, 2)) end, }, { name = "lshift", fn = function() - return bit64.lshift({ 1, 2 }, 1) + return bit64.lshift(bit64.new(1, 2), 1) end, }, { name = "rshift", fn = function() - return bit64.rshift({ 1, 2 }, 1) + return bit64.rshift(bit64.new(1, 2), 1) end, }, { name = "arshift", fn = function() - return bit64.arshift({ 1, 2 }, 1) + return bit64.arshift(bit64.new(1, 2), 1) end, }, { name = "rol", fn = function() - return bit64.rol({ 1, 2 }, 1) + return bit64.rol(bit64.new(1, 2), 1) end, }, { name = "ror", fn = function() - return bit64.ror({ 1, 2 }, 1) + return bit64.ror(bit64.new(1, 2), 1) end, }, { name = "add", fn = function() - return bit64.add({ 1, 2 }, { 3, 4 }) + return bit64.add(bit64.new(1, 2), bit64.new(3, 4)) end, }, { @@ -928,7 +1129,7 @@ function bit64.selftest() end -- Test to_number strict mode error case - print("\nRunning to_number strict mode tests...") + print("\nRunning to_number/from_number edge case tests...") total = total + 1 local ok, err = pcall(function() bit64.to_number(bit64.new(0x00200000, 0x00000000), true) -- 2^53, exceeds 53-bit @@ -945,6 +1146,220 @@ function bit64.selftest() end end + -- Test to_number pass-through for number input + total = total + 1 + local num_input = 12345 + local num_result = bit64.to_number(num_input) + if num_result == num_input then + print(" PASS: to_number passes through number input unchanged") + passed = passed + 1 + else + print(" FAIL: to_number passes through number input unchanged") + print(" Expected: " .. tostring(num_input)) + print(" Got: " .. tostring(num_result)) + end + + -- Test to_number errors on plain table (non-Int64) + total = total + 1 + ok, err = pcall(function() + bit64.to_number({ 1, 2 }) -- plain table, not Int64 + end) + if not ok and type(err) == "string" and string.find(err, "not a valid Int64") then + print(" PASS: to_number errors on plain table (non-Int64)") + passed = passed + 1 + else + print(" FAIL: to_number errors on plain table (non-Int64)") + if ok then + print(" Expected error but got success") + else + print(" Expected 'not a valid Int64' error but got: " .. tostring(err)) + end + end + + -- Test from_number pass-through for Int64 input + total = total + 1 + local int64_input = bit64.new(0x12345678, 0x9ABCDEF0) + local int64_result = bit64.from_number(int64_input) + if rawequal(int64_result, int64_input) then + print(" PASS: from_number passes through Int64 input unchanged") + passed = passed + 1 + else + print(" FAIL: from_number passes through Int64 input unchanged") + print(" Expected same reference, got different object") + end + + -- Test raw_* operations + print("\n Testing raw_* operations...") + + local raw_tests = { + -- Core bitwise (test high-bit cases where sign matters) + { + name = "raw_band(new(0xFFFFFFFF, 0x80000000), new(0x80000000, 0xFFFFFFFF))", + fn = function() + return bit64.raw_band(bit64.new(0xFFFFFFFF, 0x80000000), bit64.new(0x80000000, 0xFFFFFFFF)) + end, + expected = bit64.new(0x80000000, 0x80000000), + }, + { + name = "raw_bor(new(0x80000000, 0), new(0, 0x80000000))", + fn = function() + return bit64.raw_bor(bit64.new(0x80000000, 0), bit64.new(0, 0x80000000)) + end, + expected = bit64.new(0x80000000, 0x80000000), + }, + { + name = "raw_bxor(new(0xAAAAAAAA, 0x55555555), new(0x55555555, 0xAAAAAAAA))", + fn = function() + return bit64.raw_bxor(bit64.new(0xAAAAAAAA, 0x55555555), bit64.new(0x55555555, 0xAAAAAAAA)) + end, + expected = bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + }, + { + name = "raw_bnot(new(0, 0))", + fn = function() + return bit64.raw_bnot(bit64.new(0, 0)) + end, + expected = bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + }, + + -- Shifts + { + name = "raw_lshift(new(0, 1), 63)", + fn = function() + return bit64.raw_lshift(bit64.new(0, 1), 63) + end, + expected = bit64.new(0x80000000, 0), + }, + { + name = "raw_rshift(new(0x80000000, 0), 63)", + fn = function() + return bit64.raw_rshift(bit64.new(0x80000000, 0), 63) + end, + expected = bit64.new(0, 1), + }, + { + name = "raw_arshift(new(0x80000000, 0), 32)", + fn = function() + return bit64.raw_arshift(bit64.new(0x80000000, 0), 32) + end, + expected = bit64.new(0xFFFFFFFF, 0x80000000), + }, + + -- Rotates + { + name = "raw_rol(new(0x12345678, 0x9ABCDEF0), 16)", + fn = function() + return bit64.raw_rol(bit64.new(0x12345678, 0x9ABCDEF0), 16) + end, + expected = bit64.new(0x56789ABC, 0xDEF01234), + }, + { + name = "raw_ror(new(0x12345678, 0x9ABCDEF0), 16)", + fn = function() + return bit64.raw_ror(bit64.new(0x12345678, 0x9ABCDEF0), 16) + end, + expected = bit64.new(0xDEF01234, 0x56789ABC), + }, + + -- Addition + { + name = "raw_add(new(0xFFFFFFFF, 0xFFFFFFFF), new(0, 1))", + fn = function() + return bit64.raw_add(bit64.new(0xFFFFFFFF, 0xFFFFFFFF), bit64.new(0, 1)) + end, + expected = bit64.new(0, 0), + }, + } + + for _, test in ipairs(raw_tests) do + total = total + 1 + local result = test.fn() + if eq64(result, test.expected) then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name) + print(" Expected: " .. fmt64(test.expected)) + print(" Got: " .. fmt64(result)) + end + end + + -- Test that raw_* operations return Int64 + print("\n Testing raw_* operations return Int64...") + local raw_ops_returning_int64 = { + { + name = "raw_band", + fn = function() + return bit64.raw_band(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bor", + fn = function() + return bit64.raw_bor(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bxor", + fn = function() + return bit64.raw_bxor(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bnot", + fn = function() + return bit64.raw_bnot(bit64.new(1, 2)) + end, + }, + { + name = "raw_lshift", + fn = function() + return bit64.raw_lshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_rshift", + fn = function() + return bit64.raw_rshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_arshift", + fn = function() + return bit64.raw_arshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_rol", + fn = function() + return bit64.raw_rol(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_ror", + fn = function() + return bit64.raw_ror(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_add", + fn = function() + return bit64.raw_add(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + } + + for _, op in ipairs(raw_ops_returning_int64) do + total = total + 1 + local result = op.fn() + if bit64.is_int64(result) then + print(" PASS: " .. op.name .. "() returns Int64") + passed = passed + 1 + else + print(" FAIL: " .. op.name .. "() returns Int64") + end + end + print(string.format("\n64-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end