From 4fa60a055b1ab789c3ffc9bc11890c6bdf1a2e4e Mon Sep 17 00:00:00 2001 From: Derek Miller Date: Wed, 28 Jan 2026 18:04:50 -0600 Subject: [PATCH] Add raw_* methods to bit32 and bit64 modules for zero-overhead bit operations This adds raw_* variants of all bit operations that bypass the to_unsigned() wrapper on LuaJIT, providing direct access to native bit library functions. For XOR, AND, OR, and rotate operations, the signedness doesn't matter since bit patterns are identical. New functions in bit32: - raw_band, raw_bor, raw_bxor, raw_bnot - raw_lshift, raw_rshift, raw_arshift - raw_rol, raw_ror - raw_add New functions in bit64: - raw_band, raw_bor, raw_bxor, raw_bnot - raw_lshift, raw_rshift, raw_arshift - raw_rol, raw_ror - raw_add Key implementation details: - On LuaJIT: raw_* are direct references to bit.* functions (zero overhead) - On Lua 5.3+: raw_* use native operators with 32-bit masking for shifts - On Lua 5.2: raw_* are direct references to bit32.* (already unsigned) - On Pure Lua: raw_* fall back to safe implementations (no native library) - raw_rol/raw_ror fall back to computed versions on non-LuaJIT (not nil) Note: Shift amounts >= 32 (or >= 64 for bit64) have platform-specific behavior. LuaJIT wraps (n % 32), Lua 5.3+ returns 0. Callers should keep n in valid range. --- CLAUDE.md | 16 ++ README.md | 35 ++++ src/bitn/_compat.lua | 58 ++++++ src/bitn/bit16.lua | 5 +- src/bitn/bit32.lua | 331 ++++++++++++++++++++++++++++--- src/bitn/bit64.lua | 451 +++++++++++++++++++++++++++++++++++++++++-- 6 files changed, 851 insertions(+), 45 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 2281558..629d976 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -84,6 +84,22 @@ The `_compat` module provides automatic feature detection and optimized primitiv This ensures optimal performance on modern Lua while maintaining compatibility with older versions. 
+### Raw Operations (bit32 and bit64) + +The bit32 and bit64 modules provide `raw_*` variants for performance-critical code: +- `raw_band`, `raw_bor`, `raw_bxor`, `raw_bnot` +- `raw_lshift`, `raw_rshift`, `raw_arshift` +- `raw_rol`, `raw_ror` +- `raw_add` + +These bypass the `to_unsigned()` wrapper used on LuaJIT, returning signed +integers when the high bit is set. On other platforms they behave identically +to regular operations. Use for crypto code and tight loops where the sign +interpretation doesn't matter. + +Note: Shift amounts >= 32 (or >= 64 for bit64) have platform-specific behavior +in raw functions. Callers should keep shift amounts in valid range. + ## Testing Tests use Lua table-based vectors for easy maintenance: diff --git a/README.md b/README.md index c94ee6a..9b0f5d6 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,41 @@ local xored = bit64.bxor( Example: `0x123456789ABCDEF0` is represented as `{0x12345678, 0x9ABCDEF0}` +### Raw Operations (Performance-Critical Code) + +The `bit32` and `bit64` modules provide `raw_*` variants for performance-critical +code paths like cryptographic operations. These bypass the unsigned conversion +wrapper used on LuaJIT, providing direct access to native bit library functions. + +**Available functions (bit32 and bit64):** +- `raw_band`, `raw_bor`, `raw_bxor`, `raw_bnot` +- `raw_lshift`, `raw_rshift`, `raw_arshift` +- `raw_rol`, `raw_ror` +- `raw_add` + +**Important:** On LuaJIT, raw_* functions may return **signed** 32-bit integers: + +```lua +local bit32 = require("bitn").bit32 + +-- Regular function (always unsigned) +bit32.bxor(0x80000000, 1) --> 2147483649 + +-- Raw function (signed on LuaJIT) +bit32.raw_bxor(0x80000000, 1) --> -2147483647 (same bit pattern!) +``` + +**When to use raw_* functions:** +- Chained bitwise operations (XOR, AND, OR, rotate) where sign doesn't matter +- Crypto algorithms (ChaCha20, etc.) 
that only care about bit patterns +- Tight loops where the `to_unsigned()` overhead is measurable + +**When NOT to use raw_* functions:** +- When comparing results (`<`, `>`, `==`) +- When doing arithmetic on the results +- When formatting/displaying values +- When you need guaranteed unsigned semantics + ## Development ### Setup diff --git a/src/bitn/_compat.lua b/src/bitn/_compat.lua index 917b6df..9b2afe7 100644 --- a/src/bitn/_compat.lua +++ b/src/bitn/_compat.lua @@ -97,6 +97,33 @@ if ok and result then return native_band(r, MASK32) end + -- Raw operations provide direct access to native bit functions without the + -- to_unsigned() wrapper. On Lua 5.3+, these are identical to wrapped versions + -- since native operators already return unsigned values. + -- Shifts must mask to 32 bits since native operators work on 64-bit values. + _compat.raw_band = native_band + _compat.raw_bor = native_bor + _compat.raw_bxor = native_bxor + _compat.raw_bnot = function(a) + return native_band(native_bnot(a), MASK32) + end + _compat.raw_lshift = function(a, n) + if n >= 32 then + return 0 + end + return native_band(native_lshift(a, n), MASK32) + end + _compat.raw_rshift = function(a, n) + if n >= 32 then + return 0 + end + return native_rshift(native_band(a, MASK32), n) + end + _compat.raw_arshift = _compat.arshift + -- No native rol/ror on Lua 5.3+ + _compat.raw_rol = nil + _compat.raw_ror = nil + return _compat end end @@ -218,6 +245,25 @@ if bit_lib then end end + -- Raw operations provide direct access to native bit functions without the + -- to_unsigned() wrapper. On LuaJIT, these return signed 32-bit integers. + -- On Lua 5.2 (bit32 library), these are identical to wrapped versions. 
+ _compat.raw_band = bit_band + _compat.raw_bor = bit_bor + _compat.raw_bxor = bit_bxor + _compat.raw_bnot = bit_bnot + _compat.raw_lshift = bit_lshift + _compat.raw_rshift = bit_rshift + _compat.raw_arshift = bit_arshift + -- rol/ror only available on LuaJIT (bit library), not Lua 5.2 (bit32 library) + if bit_lib.rol then + _compat.raw_rol = bit_lib.rol + _compat.raw_ror = bit_lib.ror + else + _compat.raw_rol = nil + _compat.raw_ror = nil + end + return _compat end @@ -317,4 +363,16 @@ function _compat.arshift(a, n) return r end +-- Raw operations for pure Lua fallback are identical to wrapped versions +-- since there's no native library to bypass. +_compat.raw_band = _compat.band +_compat.raw_bor = _compat.bor +_compat.raw_bxor = _compat.bxor +_compat.raw_bnot = _compat.bnot +_compat.raw_lshift = _compat.lshift +_compat.raw_rshift = _compat.rshift +_compat.raw_arshift = _compat.arshift +_compat.raw_rol = nil +_compat.raw_ror = nil + return _compat diff --git a/src/bitn/bit16.lua b/src/bitn/bit16.lua index d912a15..ebb7f47 100644 --- a/src/bitn/bit16.lua +++ b/src/bitn/bit16.lua @@ -10,18 +10,17 @@ local _compat = require("bitn._compat") -- Cache methods as locals for faster access local compat_band = _compat.band +local compat_bnot = _compat.bnot local compat_bor = _compat.bor local compat_bxor = _compat.bxor -local compat_bnot = _compat.bnot local compat_lshift = _compat.lshift local compat_rshift = _compat.rshift local impl_name = _compat.impl_name +local math_floor = math.floor -- 16-bit mask constant local MASK16 = 0xFFFF -local math_floor = math.floor - -------------------------------------------------------------------------------- -- Core operations -------------------------------------------------------------------------------- diff --git a/src/bitn/bit32.lua b/src/bitn/bit32.lua index 7e04187..ffa8469 100644 --- a/src/bitn/bit32.lua +++ b/src/bitn/bit32.lua @@ -9,24 +9,42 @@ local bit32 = {} local _compat = require("bitn._compat") -- Cache methods as 
locals for faster access +local compat_arshift = _compat.arshift local compat_band = _compat.band +local compat_bnot = _compat.bnot local compat_bor = _compat.bor local compat_bxor = _compat.bxor -local compat_bnot = _compat.bnot local compat_lshift = _compat.lshift +local compat_raw_arshift = _compat.raw_arshift +local compat_raw_band = _compat.raw_band +local compat_raw_bnot = _compat.raw_bnot +local compat_raw_bor = _compat.raw_bor +local compat_raw_bxor = _compat.raw_bxor +local compat_raw_lshift = _compat.raw_lshift +local compat_raw_rol = _compat.raw_rol +local compat_raw_ror = _compat.raw_ror +local compat_raw_rshift = _compat.raw_rshift local compat_rshift = _compat.rshift -local compat_arshift = _compat.arshift +local compat_to_unsigned = _compat.to_unsigned local impl_name = _compat.impl_name +local math_floor = math.floor -- 32-bit mask constant local MASK32 = 0xFFFFFFFF -local math_floor = math.floor - -------------------------------------------------------------------------------- -- Core operations -------------------------------------------------------------------------------- +--- Convert signed 32-bit value to unsigned. +--- On LuaJIT, bit operations return signed 32-bit integers. This function +--- converts them to unsigned by adding 2^32 to negative values. +--- @param n number Potentially signed 32-bit value +--- @return integer result Unsigned 32-bit value (0 to 0xFFFFFFFF) +function bit32.to_unsigned(n) + return compat_to_unsigned(n) +end + --- Ensure value fits in 32-bit unsigned integer. 
--- @param n number Input value --- @return integer result 32-bit unsigned integer (0 to 0xFFFFFFFF) @@ -126,6 +144,68 @@ function bit32.add(a, b) return compat_band(compat_band(a, MASK32) + compat_band(b, MASK32), MASK32) end +-------------------------------------------------------------------------------- +-- Raw (zero-overhead) operations +-------------------------------------------------------------------------------- +-- These functions provide direct access to the underlying bit library without +-- unsigned conversion. On LuaJIT, results may be negative when the high bit +-- is set. The bit pattern is identical to the regular function. +-- Use for performance-critical code where sign interpretation doesn't matter. + +--- Raw bitwise AND (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.band For guaranteed unsigned results +bit32.raw_band = compat_raw_band + +--- Raw bitwise OR (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.bor For guaranteed unsigned results +bit32.raw_bor = compat_raw_bor + +--- Raw bitwise XOR (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.bxor For guaranteed unsigned results +bit32.raw_bxor = compat_raw_bxor + +--- Raw bitwise NOT (may return signed on LuaJIT). +--- @type fun(a: integer): integer +--- @see bit32.bnot For guaranteed unsigned results +bit32.raw_bnot = compat_raw_bnot + +--- Raw left shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.lshift For guaranteed unsigned results +bit32.raw_lshift = compat_raw_lshift + +--- Raw logical right shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.rshift For guaranteed unsigned results +bit32.raw_rshift = compat_raw_rshift + +--- Raw arithmetic right shift (may return signed on LuaJIT). 
+--- @type fun(a: integer, n: integer): integer +--- @see bit32.arshift For guaranteed unsigned results +bit32.raw_arshift = compat_raw_arshift + +--- Raw left rotate (uses native bit.rol on LuaJIT, falls back to computed otherwise). +--- @type fun(x: integer, n: integer): integer +--- @see bit32.rol For guaranteed unsigned results +bit32.raw_rol = compat_raw_rol or bit32.rol + +--- Raw right rotate (uses native bit.ror on LuaJIT, falls back to computed otherwise). +--- @type fun(x: integer, n: integer): integer +--- @see bit32.ror For guaranteed unsigned results +bit32.raw_ror = compat_raw_ror or bit32.ror + +--- Raw 32-bit addition with overflow handling. +--- @param a integer First operand (32-bit) +--- @param b integer Second operand (32-bit) +--- @return integer result Result of (a + b) mod 2^32 (signed on LuaJIT, unsigned elsewhere) +--- @see bit32.add For guaranteed unsigned results +function bit32.raw_add(a, b) + return compat_raw_band(a + b, MASK32) +end + -------------------------------------------------------------------------------- -- Byte conversion functions -------------------------------------------------------------------------------- @@ -138,7 +218,12 @@ local string_byte = string.byte --- @return string bytes 4-byte string in big-endian order function bit32.u32_to_be_bytes(n) n = compat_band(n, MASK32) - return string_char(math_floor(n / 16777216) % 256, math_floor(n / 65536) % 256, math_floor(n / 256) % 256, n % 256) + return string_char( + math_floor(n / 16777216) % 256, + math_floor(n / 65536) % 256, + math_floor(n / 256) % 256, + math_floor(n % 256) + ) end --- Convert 32-bit unsigned integer to 4 bytes (little-endian). 
@@ -146,7 +231,12 @@ end --- @return string bytes 4-byte string in little-endian order function bit32.u32_to_le_bytes(n) n = compat_band(n, MASK32) - return string_char(n % 256, math_floor(n / 256) % 256, math_floor(n / 65536) % 256, math_floor(n / 16777216) % 256) + return string_char( + math_floor(n % 256), + math_floor(n / 256) % 256, + math_floor(n / 65536) % 256, + math_floor(n / 16777216) % 256 + ) end --- Convert 4 bytes to 32-bit unsigned integer (big-endian). @@ -196,6 +286,14 @@ function bit32.selftest() { name = "mask(-1)", fn = bit32.mask, inputs = { -1 }, expected = 0xFFFFFFFF }, { name = "mask(-256)", fn = bit32.mask, inputs = { -256 }, expected = 0xFFFFFF00 }, + -- to_unsigned tests + { name = "to_unsigned(0)", fn = bit32.to_unsigned, inputs = { 0 }, expected = 0 }, + { name = "to_unsigned(1)", fn = bit32.to_unsigned, inputs = { 1 }, expected = 1 }, + { name = "to_unsigned(0x7FFFFFFF)", fn = bit32.to_unsigned, inputs = { 0x7FFFFFFF }, expected = 0x7FFFFFFF }, + { name = "to_unsigned(-1)", fn = bit32.to_unsigned, inputs = { -1 }, expected = 0xFFFFFFFF }, + { name = "to_unsigned(-2147483648)", fn = bit32.to_unsigned, inputs = { -2147483648 }, expected = 0x80000000 }, + { name = "to_unsigned(-2147483647)", fn = bit32.to_unsigned, inputs = { -2147483647 }, expected = 0x80000001 }, + -- band tests { name = "band(0xFF00FF00, 0x00FF00FF)", fn = bit32.band, inputs = { 0xFF00FF00, 0x00FF00FF }, expected = 0 }, { @@ -309,25 +407,25 @@ function bit32.selftest() name = "u32_to_be_bytes(0)", fn = bit32.u32_to_be_bytes, inputs = { 0 }, - expected = string.char(0x00, 0x00, 0x00, 0x00), + expected = string_char(0x00, 0x00, 0x00, 0x00), }, { name = "u32_to_be_bytes(1)", fn = bit32.u32_to_be_bytes, inputs = { 1 }, - expected = string.char(0x00, 0x00, 0x00, 0x01), + expected = string_char(0x00, 0x00, 0x00, 0x01), }, { name = "u32_to_be_bytes(0x12345678)", fn = bit32.u32_to_be_bytes, inputs = { 0x12345678 }, - expected = string.char(0x12, 0x34, 0x56, 0x78), + expected = 
string_char(0x12, 0x34, 0x56, 0x78), }, { name = "u32_to_be_bytes(0xFFFFFFFF)", fn = bit32.u32_to_be_bytes, inputs = { 0xFFFFFFFF }, - expected = string.char(0xFF, 0xFF, 0xFF, 0xFF), + expected = string_char(0xFF, 0xFF, 0xFF, 0xFF), }, -- u32_to_le_bytes tests @@ -335,50 +433,50 @@ function bit32.selftest() name = "u32_to_le_bytes(0)", fn = bit32.u32_to_le_bytes, inputs = { 0 }, - expected = string.char(0x00, 0x00, 0x00, 0x00), + expected = string_char(0x00, 0x00, 0x00, 0x00), }, { name = "u32_to_le_bytes(1)", fn = bit32.u32_to_le_bytes, inputs = { 1 }, - expected = string.char(0x01, 0x00, 0x00, 0x00), + expected = string_char(0x01, 0x00, 0x00, 0x00), }, { name = "u32_to_le_bytes(0x12345678)", fn = bit32.u32_to_le_bytes, inputs = { 0x12345678 }, - expected = string.char(0x78, 0x56, 0x34, 0x12), + expected = string_char(0x78, 0x56, 0x34, 0x12), }, { name = "u32_to_le_bytes(0xFFFFFFFF)", fn = bit32.u32_to_le_bytes, inputs = { 0xFFFFFFFF }, - expected = string.char(0xFF, 0xFF, 0xFF, 0xFF), + expected = string_char(0xFF, 0xFF, 0xFF, 0xFF), }, -- be_bytes_to_u32 tests { name = "be_bytes_to_u32(0x00000000)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x00) }, expected = 0, }, { name = "be_bytes_to_u32(0x00000001)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x01) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x01) }, expected = 1, }, { name = "be_bytes_to_u32(0x12345678)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x12, 0x34, 0x56, 0x78) }, + inputs = { string_char(0x12, 0x34, 0x56, 0x78) }, expected = 0x12345678, }, { name = "be_bytes_to_u32(0xFFFFFFFF)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0xFF, 0xFF, 0xFF, 0xFF) }, + inputs = { string_char(0xFF, 0xFF, 0xFF, 0xFF) }, expected = 0xFFFFFFFF, }, @@ -386,25 +484,25 @@ function bit32.selftest() { name = "le_bytes_to_u32(0x00000000)", fn = bit32.le_bytes_to_u32, - inputs = { 
string.char(0x00, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x00) }, expected = 0, }, { name = "le_bytes_to_u32(0x00000001)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x01, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x01, 0x00, 0x00, 0x00) }, expected = 1, }, { name = "le_bytes_to_u32(0x12345678)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x78, 0x56, 0x34, 0x12) }, + inputs = { string_char(0x78, 0x56, 0x34, 0x12) }, expected = 0x12345678, }, { name = "le_bytes_to_u32(0xFFFFFFFF)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0xFF, 0xFF, 0xFF, 0xFF) }, + inputs = { string_char(0xFF, 0xFF, 0xFF, 0xFF) }, expected = 0xFFFFFFFF, }, } @@ -420,10 +518,10 @@ function bit32.selftest() if type(test.expected) == "string" then local exp_hex, got_hex = "", "" for i = 1, #test.expected do - exp_hex = exp_hex .. string.format("%02X", string.byte(test.expected, i)) + exp_hex = exp_hex .. string.format("%02X", string_byte(test.expected, i)) end for i = 1, #result do - got_hex = got_hex .. string.format("%02X", string.byte(result, i)) + got_hex = got_hex .. string.format("%02X", string_byte(result, i)) end print(" Expected: " .. exp_hex) print(" Got: " .. 
got_hex) @@ -434,6 +532,191 @@ function bit32.selftest() end end + -- Test raw_* operations + print("\n Testing raw_* operations...") + + local raw_tests = { + -- Core bitwise (test high-bit cases where sign matters) + { + name = "raw_band(0xFFFFFFFF, 0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_band(0xFFFFFFFF, 0x80000000)) + end, + expected = bit32.band(0xFFFFFFFF, 0x80000000), + }, + { + name = "raw_bor(0x80000000, 0x00000001)", + fn = function() + return bit32.to_unsigned(bit32.raw_bor(0x80000000, 0x00000001)) + end, + expected = bit32.bor(0x80000000, 0x00000001), + }, + { + name = "raw_bxor(0xAAAAAAAA, 0x55555555)", + fn = function() + return bit32.to_unsigned(bit32.raw_bxor(0xAAAAAAAA, 0x55555555)) + end, + expected = bit32.bxor(0xAAAAAAAA, 0x55555555), + }, + { + name = "raw_bnot(0)", + fn = function() + return bit32.to_unsigned(bit32.raw_bnot(0)) + end, + expected = bit32.bnot(0), + }, + { + name = "raw_bnot(0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_bnot(0x80000000)) + end, + expected = bit32.bnot(0x80000000), + }, + + -- Shifts + { + name = "raw_lshift(1, 31)", + fn = function() + return bit32.to_unsigned(bit32.raw_lshift(1, 31)) + end, + expected = bit32.lshift(1, 31), + }, + { + name = "raw_rshift(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_rshift(0x80000000, 1)) + end, + expected = bit32.rshift(0x80000000, 1), + }, + { + name = "raw_arshift(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_arshift(0x80000000, 1)) + end, + expected = bit32.arshift(0x80000000, 1), + }, + + -- Shift masking (ensure 32-bit semantics on all platforms) + -- Note: n >= 32 behavior is platform-specific for raw shifts; callers should use n in 0-31 + { + name = "raw_lshift(0x12345678, 16) masks to 32 bits", + fn = function() + return bit32.to_unsigned(bit32.raw_lshift(0x12345678, 16)) + end, + expected = 0x56780000, + }, + { + name = "raw_rshift(0xFFFFFFFF, 16) masks to 32 bits", 
+ fn = function() + return bit32.to_unsigned(bit32.raw_rshift(0xFFFFFFFF, 16)) + end, + expected = 0x0000FFFF, + }, + + -- Addition overflow + { + name = "raw_add(0xFFFFFFFF, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_add(0xFFFFFFFF, 1)) + end, + expected = bit32.add(0xFFFFFFFF, 1), + }, + { + name = "raw_add(0x80000000, 0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_add(0x80000000, 0x80000000)) + end, + expected = bit32.add(0x80000000, 0x80000000), + }, + } + + for _, test in ipairs(raw_tests) do + total = total + 1 + local result = test.fn() + if result == test.expected then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name) + print(string.format(" Expected: 0x%08X", test.expected)) + print(string.format(" Got: 0x%08X", result)) + end + end + + -- Test raw_rol/raw_ror (always available - falls back to computed if no native) + print("\n Testing raw_rol/raw_ror...") + local rol_ror_tests = { + { + name = "raw_rol(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_rol(0x80000000, 1)) + end, + expected = bit32.rol(0x80000000, 1), + }, + { + name = "raw_rol(0x12345678, 8)", + fn = function() + return bit32.to_unsigned(bit32.raw_rol(0x12345678, 8)) + end, + expected = bit32.rol(0x12345678, 8), + }, + { + name = "raw_ror(1, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_ror(1, 1)) + end, + expected = bit32.ror(1, 1), + }, + { + name = "raw_ror(0x12345678, 8)", + fn = function() + return bit32.to_unsigned(bit32.raw_ror(0x12345678, 8)) + end, + expected = bit32.ror(0x12345678, 8), + }, + } + + for _, test in ipairs(rol_ror_tests) do + total = total + 1 + local result = test.fn() + if result == test.expected then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. 
test.name) + print(string.format(" Expected: 0x%08X", test.expected)) + print(string.format(" Got: 0x%08X", result)) + end + end + + -- Test zero-overhead on LuaJIT (identity check) + if _compat.is_luajit then + print("\n Testing zero-overhead (LuaJIT function identity)...") + local bit = require("bit") + + local identity_tests = { + { name = "raw_band == bit.band", got = bit32.raw_band, expected = bit.band }, + { name = "raw_bor == bit.bor", got = bit32.raw_bor, expected = bit.bor }, + { name = "raw_bxor == bit.bxor", got = bit32.raw_bxor, expected = bit.bxor }, + { name = "raw_bnot == bit.bnot", got = bit32.raw_bnot, expected = bit.bnot }, + { name = "raw_lshift == bit.lshift", got = bit32.raw_lshift, expected = bit.lshift }, + { name = "raw_rshift == bit.rshift", got = bit32.raw_rshift, expected = bit.rshift }, + { name = "raw_arshift == bit.arshift", got = bit32.raw_arshift, expected = bit.arshift }, + { name = "raw_rol == bit.rol", got = bit32.raw_rol, expected = bit.rol }, + { name = "raw_ror == bit.ror", got = bit32.raw_ror, expected = bit.ror }, + } + + for _, test in ipairs(identity_tests) do + total = total + 1 + if rawequal(test.got, test.expected) then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name .. 
" (not identical function reference)") + end + end + end + print(string.format("\n32-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end diff --git a/src/bitn/bit64.lua b/src/bitn/bit64.lua index c7021ad..eb59e31 100644 --- a/src/bitn/bit64.lua +++ b/src/bitn/bit64.lua @@ -9,20 +9,27 @@ local bit64 = {} local bit32 = require("bitn.bit32") local _compat = require("bitn._compat") -local impl_name = _compat.impl_name --- Cache bit32 methods as locals for faster access +-- Cache methods as locals for faster access +local bit32_arshift = bit32.arshift local bit32_band = bit32.band +local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 +local bit32_bnot = bit32.bnot local bit32_bor = bit32.bor local bit32_bxor = bit32.bxor -local bit32_bnot = bit32.bnot +local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 local bit32_lshift = bit32.lshift +local bit32_raw_arshift = bit32.raw_arshift +local bit32_raw_band = bit32.raw_band +local bit32_raw_bnot = bit32.raw_bnot +local bit32_raw_bor = bit32.raw_bor +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_lshift = bit32.raw_lshift +local bit32_raw_rshift = bit32.raw_rshift local bit32_rshift = bit32.rshift -local bit32_arshift = bit32.arshift local bit32_u32_to_be_bytes = bit32.u32_to_be_bytes local bit32_u32_to_le_bytes = bit32.u32_to_le_bytes -local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 -local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 +local impl_name = _compat.impl_name -- Private metatable for Int64 type identification local Int64Meta = { __name = "Int64" } @@ -35,11 +42,21 @@ local Int64Meta = { __name = "Int64" } -------------------------------------------------------------------------------- --- Create a new Int64 value with metatable marker. +--- Normalizes signed 32-bit values to unsigned (for LuaJIT raw_* compatibility). --- @param high? integer Upper 32 bits (default: 0) --- @param low? 
integer Lower 32 bits (default: 0) --- @return Int64HighLow value Int64 value with metatable marker function bit64.new(high, low) - return setmetatable({ high or 0, low or 0 }, Int64Meta) + high = high or 0 + low = low or 0 + -- Normalize signed to unsigned (handles LuaJIT raw_* results) + if high < 0 then + high = high + 0x100000000 + end + if low < 0 then + low = low + 0x100000000 + end + return setmetatable({ high, low }, Int64Meta) end --- Check if a value is an Int64 (created by bit64 functions). @@ -373,6 +390,178 @@ bit64.asr = bit64.arshift --- Alias for is_int64 (compatibility with older API). bit64.isInt64 = bit64.is_int64 +-------------------------------------------------------------------------------- +-- Raw (zero-overhead) operations +-------------------------------------------------------------------------------- +-- These functions use bit32.raw_* internally for performance-critical code. +-- On LuaJIT, the internal 32-bit values may be signed, but bit patterns are correct. +-- Use for crypto code and tight loops where sign interpretation doesn't matter. + +--- Raw bitwise AND (uses bit32.raw_band internally). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} AND result +function bit64.raw_band(a, b) + return bit64.new(bit32_raw_band(a[1], b[1]), bit32_raw_band(a[2], b[2])) +end + +--- Raw bitwise OR (uses bit32.raw_bor internally). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} OR result +function bit64.raw_bor(a, b) + return bit64.new(bit32_raw_bor(a[1], b[1]), bit32_raw_bor(a[2], b[2])) +end + +--- Raw bitwise XOR (uses bit32.raw_bxor internally). 
+--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} XOR result +function bit64.raw_bxor(a, b) + return bit64.new(bit32_raw_bxor(a[1], b[1]), bit32_raw_bxor(a[2], b[2])) +end + +--- Raw bitwise NOT (uses bit32.raw_bnot internally). +--- @param a Int64HighLow Operand {high, low} +--- @return Int64HighLow result {high, low} NOT result +function bit64.raw_bnot(a) + return bit64.new(bit32_raw_bnot(a[1]), bit32_raw_bnot(a[2])) +end + +--- Raw left shift (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to shift {high, low} +--- @param n integer Number of positions to shift (must be >= 0) +--- @return Int64HighLow result {high, low} shifted value +function bit64.raw_lshift(x, n) + if n == 0 then + return bit64.new(x[1], x[2]) + elseif n >= 64 then + return bit64.new(0, 0) + elseif n >= 32 then + return bit64.new(bit32_raw_lshift(x[2], n - 32), 0) + else + local new_high = bit32_raw_bor(bit32_raw_lshift(x[1], n), bit32_raw_rshift(x[2], 32 - n)) + local new_low = bit32_raw_lshift(x[2], n) + return bit64.new(new_high, new_low) + end +end + +--- Raw logical right shift (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to shift {high, low} +--- @param n integer Number of positions to shift (must be >= 0) +--- @return Int64HighLow result {high, low} shifted value +function bit64.raw_rshift(x, n) + if n == 0 then + return bit64.new(x[1], x[2]) + elseif n >= 64 then + return bit64.new(0, 0) + elseif n >= 32 then + return bit64.new(0, bit32_raw_rshift(x[1], n - 32)) + else + local new_low = bit32_raw_bor(bit32_raw_rshift(x[2], n), bit32_raw_lshift(x[1], 32 - n)) + local new_high = bit32_raw_rshift(x[1], n) + return bit64.new(new_high, new_low) + end +end + +--- Raw arithmetic right shift (uses bit32.raw_* internally). 
+--- @param x Int64HighLow Value to shift {high, low} +--- @param n integer Number of positions to shift (must be >= 0) +--- @return Int64HighLow result {high, low} shifted value +function bit64.raw_arshift(x, n) + if n == 0 then + return bit64.new(x[1], x[2]) + end + + local is_negative = bit32_raw_band(x[1], 0x80000000) ~= 0 + + if n >= 64 then + if is_negative then + return bit64.new(0xFFFFFFFF, 0xFFFFFFFF) + else + return bit64.new(0, 0) + end + elseif n >= 32 then + local new_low = bit32_raw_arshift(x[1], n - 32) + local new_high = is_negative and 0xFFFFFFFF or 0 + return bit64.new(new_high, new_low) + else + local new_low = bit32_raw_bor(bit32_raw_rshift(x[2], n), bit32_raw_lshift(x[1], 32 - n)) + local new_high = bit32_raw_arshift(x[1], n) + return bit64.new(new_high, new_low) + end +end + +--- Raw left rotate (uses bit32.raw_* internally). +--- @param x Int64HighLow Value to rotate {high, low} +--- @param n integer Number of positions to rotate +--- @return Int64HighLow result {high, low} rotated value +function bit64.raw_rol(x, n) + n = n % 64 + if n == 0 then + return bit64.new(x[1], x[2]) + end + + local high, low = x[1], x[2] + + if n == 32 then + return bit64.new(low, high) + elseif n < 32 then + local new_high = bit32_raw_bor(bit32_raw_lshift(high, n), bit32_raw_rshift(low, 32 - n)) + local new_low = bit32_raw_bor(bit32_raw_lshift(low, n), bit32_raw_rshift(high, 32 - n)) + return bit64.new(new_high, new_low) + else + n = n - 32 + local new_high = bit32_raw_bor(bit32_raw_lshift(low, n), bit32_raw_rshift(high, 32 - n)) + local new_low = bit32_raw_bor(bit32_raw_lshift(high, n), bit32_raw_rshift(low, 32 - n)) + return bit64.new(new_high, new_low) + end +end + +--- Raw right rotate (uses bit32.raw_* internally). 
+--- @param x Int64HighLow Value to rotate {high, low} +--- @param n integer Number of positions to rotate +--- @return Int64HighLow result {high, low} rotated value +function bit64.raw_ror(x, n) + n = n % 64 + if n == 0 then + return bit64.new(x[1], x[2]) + end + + local high, low = x[1], x[2] + + if n == 32 then + return bit64.new(low, high) + elseif n < 32 then + local new_low = bit32_raw_bor(bit32_raw_rshift(low, n), bit32_raw_lshift(high, 32 - n)) + local new_high = bit32_raw_bor(bit32_raw_rshift(high, n), bit32_raw_lshift(low, 32 - n)) + return bit64.new(new_high, new_low) + else + n = n - 32 + local new_low = bit32_raw_bor(bit32_raw_rshift(high, n), bit32_raw_lshift(low, 32 - n)) + local new_high = bit32_raw_bor(bit32_raw_rshift(low, n), bit32_raw_lshift(high, 32 - n)) + return bit64.new(new_high, new_low) + end +end + +--- Raw 64-bit addition (plain arithmetic with carry; wraps each half modulo 2^32, no bit32 calls). +--- @param a Int64HighLow First operand {high, low} +--- @param b Int64HighLow Second operand {high, low} +--- @return Int64HighLow result {high, low} sum +function bit64.raw_add(a, b) + local low = a[2] + b[2] + local high = a[1] + b[1] + + if low >= 0x100000000 then + high = high + 1 + low = low % 0x100000000 + end + + high = high % 0x100000000 + + return bit64.new(high, low) +end + -------------------------------------------------------------------------------- -- Self-test -------------------------------------------------------------------------------- @@ -821,6 +1010,18 @@ function bit64.selftest() print(" FAIL: new() with no args creates {0, 0}") end + -- Test bit64.new() normalizes negative values (LuaJIT raw_* compatibility) + total = total + 1 + local neg_val = bit64.new(-1, -2147483648) -- -1 -> 0xFFFFFFFF, -2147483648 -> 0x80000000 + if bit64.is_int64(neg_val) and neg_val[1] == 0xFFFFFFFF and neg_val[2] == 0x80000000 then + print(" PASS: new() normalizes negative values to unsigned") + passed = passed + 1 + else + print(" FAIL: new() normalizes negative values to 
unsigned") + print(string.format(" Expected: {0x%08X, 0x%08X}", 0xFFFFFFFF, 0x80000000)) + print(string.format(" Got: {0x%08X, 0x%08X}", neg_val[1], neg_val[2])) + end + -- Test is_int64() returns false for regular tables total = total + 1 local plain_table = { 0x12345678, 0x9ABCDEF0 } @@ -845,61 +1046,61 @@ function bit64.selftest() { name = "band", fn = function() - return bit64.band({ 1, 2 }, { 3, 4 }) + return bit64.band(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bor", fn = function() - return bit64.bor({ 1, 2 }, { 3, 4 }) + return bit64.bor(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bxor", fn = function() - return bit64.bxor({ 1, 2 }, { 3, 4 }) + return bit64.bxor(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bnot", fn = function() - return bit64.bnot({ 1, 2 }) + return bit64.bnot(bit64.new(1, 2)) end, }, { name = "lshift", fn = function() - return bit64.lshift({ 1, 2 }, 1) + return bit64.lshift(bit64.new(1, 2), 1) end, }, { name = "rshift", fn = function() - return bit64.rshift({ 1, 2 }, 1) + return bit64.rshift(bit64.new(1, 2), 1) end, }, { name = "arshift", fn = function() - return bit64.arshift({ 1, 2 }, 1) + return bit64.arshift(bit64.new(1, 2), 1) end, }, { name = "rol", fn = function() - return bit64.rol({ 1, 2 }, 1) + return bit64.rol(bit64.new(1, 2), 1) end, }, { name = "ror", fn = function() - return bit64.ror({ 1, 2 }, 1) + return bit64.ror(bit64.new(1, 2), 1) end, }, { name = "add", fn = function() - return bit64.add({ 1, 2 }, { 3, 4 }) + return bit64.add(bit64.new(1, 2), bit64.new(3, 4)) end, }, { @@ -928,7 +1129,7 @@ function bit64.selftest() end -- Test to_number strict mode error case - print("\nRunning to_number strict mode tests...") + print("\nRunning to_number/from_number edge case tests...") total = total + 1 local ok, err = pcall(function() bit64.to_number(bit64.new(0x00200000, 0x00000000), true) -- 2^53, exceeds 53-bit @@ -945,6 +1146,220 @@ function bit64.selftest() end end + -- Test to_number pass-through for 
number input + total = total + 1 + local num_input = 12345 + local num_result = bit64.to_number(num_input) + if num_result == num_input then + print(" PASS: to_number passes through number input unchanged") + passed = passed + 1 + else + print(" FAIL: to_number passes through number input unchanged") + print(" Expected: " .. tostring(num_input)) + print(" Got: " .. tostring(num_result)) + end + + -- Test to_number errors on plain table (non-Int64) + total = total + 1 + ok, err = pcall(function() + bit64.to_number({ 1, 2 }) -- plain table, not Int64 + end) + if not ok and type(err) == "string" and string.find(err, "not a valid Int64") then + print(" PASS: to_number errors on plain table (non-Int64)") + passed = passed + 1 + else + print(" FAIL: to_number errors on plain table (non-Int64)") + if ok then + print(" Expected error but got success") + else + print(" Expected 'not a valid Int64' error but got: " .. tostring(err)) + end + end + + -- Test from_number pass-through for Int64 input + total = total + 1 + local int64_input = bit64.new(0x12345678, 0x9ABCDEF0) + local int64_result = bit64.from_number(int64_input) + if rawequal(int64_result, int64_input) then + print(" PASS: from_number passes through Int64 input unchanged") + passed = passed + 1 + else + print(" FAIL: from_number passes through Int64 input unchanged") + print(" Expected same reference, got different object") + end + + -- Test raw_* operations + print("\n Testing raw_* operations...") + + local raw_tests = { + -- Core bitwise (test high-bit cases where sign matters) + { + name = "raw_band(new(0xFFFFFFFF, 0x80000000), new(0x80000000, 0xFFFFFFFF))", + fn = function() + return bit64.raw_band(bit64.new(0xFFFFFFFF, 0x80000000), bit64.new(0x80000000, 0xFFFFFFFF)) + end, + expected = bit64.new(0x80000000, 0x80000000), + }, + { + name = "raw_bor(new(0x80000000, 0), new(0, 0x80000000))", + fn = function() + return bit64.raw_bor(bit64.new(0x80000000, 0), bit64.new(0, 0x80000000)) + end, + expected = 
bit64.new(0x80000000, 0x80000000), + }, + { + name = "raw_bxor(new(0xAAAAAAAA, 0x55555555), new(0x55555555, 0xAAAAAAAA))", + fn = function() + return bit64.raw_bxor(bit64.new(0xAAAAAAAA, 0x55555555), bit64.new(0x55555555, 0xAAAAAAAA)) + end, + expected = bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + }, + { + name = "raw_bnot(new(0, 0))", + fn = function() + return bit64.raw_bnot(bit64.new(0, 0)) + end, + expected = bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + }, + + -- Shifts + { + name = "raw_lshift(new(0, 1), 63)", + fn = function() + return bit64.raw_lshift(bit64.new(0, 1), 63) + end, + expected = bit64.new(0x80000000, 0), + }, + { + name = "raw_rshift(new(0x80000000, 0), 63)", + fn = function() + return bit64.raw_rshift(bit64.new(0x80000000, 0), 63) + end, + expected = bit64.new(0, 1), + }, + { + name = "raw_arshift(new(0x80000000, 0), 32)", + fn = function() + return bit64.raw_arshift(bit64.new(0x80000000, 0), 32) + end, + expected = bit64.new(0xFFFFFFFF, 0x80000000), + }, + + -- Rotates + { + name = "raw_rol(new(0x12345678, 0x9ABCDEF0), 16)", + fn = function() + return bit64.raw_rol(bit64.new(0x12345678, 0x9ABCDEF0), 16) + end, + expected = bit64.new(0x56789ABC, 0xDEF01234), + }, + { + name = "raw_ror(new(0x12345678, 0x9ABCDEF0), 16)", + fn = function() + return bit64.raw_ror(bit64.new(0x12345678, 0x9ABCDEF0), 16) + end, + expected = bit64.new(0xDEF01234, 0x56789ABC), + }, + + -- Addition + { + name = "raw_add(new(0xFFFFFFFF, 0xFFFFFFFF), new(0, 1))", + fn = function() + return bit64.raw_add(bit64.new(0xFFFFFFFF, 0xFFFFFFFF), bit64.new(0, 1)) + end, + expected = bit64.new(0, 0), + }, + } + + for _, test in ipairs(raw_tests) do + total = total + 1 + local result = test.fn() + if eq64(result, test.expected) then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name) + print(" Expected: " .. fmt64(test.expected)) + print(" Got: " .. 
fmt64(result)) + end + end + + -- Test that raw_* operations return Int64 + print("\n Testing raw_* operations return Int64...") + local raw_ops_returning_int64 = { + { + name = "raw_band", + fn = function() + return bit64.raw_band(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bor", + fn = function() + return bit64.raw_bor(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bxor", + fn = function() + return bit64.raw_bxor(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bnot", + fn = function() + return bit64.raw_bnot(bit64.new(1, 2)) + end, + }, + { + name = "raw_lshift", + fn = function() + return bit64.raw_lshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_rshift", + fn = function() + return bit64.raw_rshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_arshift", + fn = function() + return bit64.raw_arshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_rol", + fn = function() + return bit64.raw_rol(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_ror", + fn = function() + return bit64.raw_ror(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_add", + fn = function() + return bit64.raw_add(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + } + + for _, op in ipairs(raw_ops_returning_int64) do + total = total + 1 + local result = op.fn() + if bit64.is_int64(result) then + print(" PASS: " .. op.name .. "() returns Int64") + passed = passed + 1 + else + print(" FAIL: " .. op.name .. "() returns Int64") + end + end + print(string.format("\n64-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end