diff --git a/bitsandbytes/backends/triton/kernels_4bit.py b/bitsandbytes/backends/triton/kernels_4bit.py
index 0e94f49e8..bdd59fad2 100644
--- a/bitsandbytes/backends/triton/kernels_4bit.py
+++ b/bitsandbytes/backends/triton/kernels_4bit.py
@@ -66,7 +66,8 @@ def quantize_fp4_blockwise_kernel(
     packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
     out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
-    out_mask = out_offsets < n_elements // 2
+    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
+    out_mask = out_offsets < (n_elements - n_elements // 2)
     tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
@@ -148,7 +149,8 @@ def quantize_nf4_blockwise_kernel(
     packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
     out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
-    out_mask = out_offsets < n_elements // 2
+    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
+    out_mask = out_offsets < (n_elements - n_elements // 2)
     tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
@@ -330,7 +332,14 @@ def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.const
 # )
 @triton.jit
 def dequant_4bit_kernel(
-    a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
+    a_ptr,
+    c_ptr,
+    quant_ptr,
+    absmax_ptr,
+    num_paired_elements,
+    num_output_elements,
+    QUANT_BLOCK: tl.constexpr,
+    SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
     block_start = pid * SPLIT_SIZE
@@ -350,7 +359,7 @@ def dequant_4bit_kernel(
     out_block_start = pid * SPLIT_SIZE * 2
     offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
-    mask = offs < num_paired_elements * 2
+    mask = offs < num_output_elements
     tl.store(c_ptr + offs, out_dq, mask)
@@ -367,7 +376,13 @@ def dequant_4bit_kernel(
 # )
 @triton.jit
 def dequant_fp4_kernel(
-    a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
+    a_ptr,
+    c_ptr,
+    absmax_ptr,
+    num_paired_elements,
+    num_output_elements,
+    QUANT_BLOCK: tl.constexpr,
+    SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
     block_start = pid * SPLIT_SIZE
@@ -386,7 +401,7 @@ def dequant_fp4_kernel(
     out_block_start = pid * SPLIT_SIZE * 2
     offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
-    mask = offs < num_paired_elements * 2
+    mask = offs < num_output_elements
     tl.store(c_ptr + offs, out_dq, mask)
@@ -403,7 +418,13 @@ def dequant_fp4_kernel(
 # )
 @triton.jit
 def dequant_nf4_kernel(
-    a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
+    a_ptr,
+    c_ptr,
+    absmax_ptr,
+    num_paired_elements,
+    num_output_elements,
+    QUANT_BLOCK: tl.constexpr,
+    SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
     block_start = pid * SPLIT_SIZE
@@ -422,7 +443,7 @@ def dequant_nf4_kernel(
     out_block_start = pid * SPLIT_SIZE * 2
     offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
-    mask = offs < num_paired_elements * 2
+    mask = offs < num_output_elements
     tl.store(c_ptr + offs, out_dq, mask)
@@ -439,15 +460,16 @@ def dequantize_4bit_impl(
     # Elements are in uint8 format, so interleaved
     # so total amount of data is 2 * elem_count
     number_of_paired_elements = A.numel()
+    num_output_elements = out.numel()
     # we assume that split_size > quant_blocksize
     SPLIT_SIZE = 256
     # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
     grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
     if quant_type == "fp4":
-        dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
+        dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
     else:
-        dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
+        dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
 def dequantize_4bit_impl_passing_code(
@@ -459,12 +481,15 @@ def dequantize_4bit_impl_passing_code(
     out: torch.Tensor,
 ) -> None:
     number_of_paired_elements = A.numel()
+    num_output_elements = out.numel()
     # we assume that split_size > quant_blocksize
     SPLIT_SIZE = 256
     # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
     grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
-    dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
+    dequant_4bit_kernel[grid](
+        A, out, code, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE
+    )
 ######################### Fallback dequantization functions #########################
diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py
index 66bff3c94..3a16961fa 100644
--- a/bitsandbytes/backends/triton/ops.py
+++ b/bitsandbytes/backends/triton/ops.py
@@ -82,7 +82,8 @@ def quantize_4bit(
     blocks = -(n // -(blocksize * 2))
     absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
-    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
+    # Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
+    out = torch.empty((n - n // 2, 1), device=A.device, dtype=torch.uint8)
     with torch_accelerator_module.device(A.device):
         kernels_4bit.quantize_4bit_blockwise_triton(
diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp
index 8ee8add98..e9d4c0ccb 100644
--- a/csrc/xpu_kernels.cpp
+++ b/csrc/xpu_kernels.cpp
@@ -95,20 +95,21 @@ inline float dDequantizeNF4(unsigned char val) {
 template SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const {
-    const int base_idx = item.get_group(0) * TILE_SIZE;
-    size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
+    const int64_t base_idx = static_cast<int64_t>(item.get_group(0)) * TILE_SIZE;
+    int64_t local_idx = static_cast<int64_t>(item.get_local_id(0)) * NUM_PER_TH;
     float local_abs_max = -FLT_MAX;
-    int local_load_idx = 0;
-    int local_store_idx = 0;
+    int64_t local_load_idx = 0;
+    int64_t local_store_idx = 0;
     uint8_t qvals[NUM_PER_TH];
     T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
     if (DATA_TYPE > 0) {
-        local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
-        local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
+        // Cast n to int64_t to avoid overflow for large n (same as CUDA)
+        local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), (static_cast<int64_t>(n) + 1) / 2 - base_idx);
+        local_store_idx = sycl::min(static_cast<int64_t>(TILE_SIZE * 2), static_cast<int64_t>(n) - base_idx * 2);
     } else {
-        local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
+        local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), static_cast<int64_t>(n) - base_idx);
         local_store_idx = local_load_idx;
     }
diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp
index 48c986fc4..960201e2e 100644
--- a/csrc/xpu_ops.cpp
+++ b/csrc/xpu_ops.cpp
@@ -10,7 +10,8 @@ void dequantizeBlockwise(
     const int num_per_th = 4;
     const int tile_size = workgroup_size * num_per_th;
     if (DATA_TYPE > 0) {
-        const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2);
+        // Upcast to int64 to avoid overflow for large n (same as CUDA)
+        const int workgroup_num = (static_cast<int64_t>(n) + tile_size * 2 - 1) / (tile_size * 2);
         sycl::range<1> local_range{(size_t)workgroup_size};
         sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
         kDequantizeBlockwise kfn(code, A, absmax, out, blocksize / 2, n);
@@ -18,7 +19,8 @@ void dequantizeBlockwise(
             sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
         );
     } else {
-        const int workgroup_num = (n + tile_size - 1) / tile_size;
+        // Upcast to int64 to avoid overflow for large n (same as CUDA)
+        const int workgroup_num = (static_cast<int64_t>(n) + tile_size - 1) / tile_size;
         sycl::range<1> local_range{(size_t)workgroup_size};
         sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
         kDequantizeBlockwise kfn(code, A, absmax, out, blocksize, n);
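
As a quick illustration of the arithmetic these hunks rely on (a standalone sketch, not code from the repository): for n quantized values a packed 4-bit buffer holds ceil(n / 2) bytes; `n // 2` undercounts by one byte when n is odd, and the usual `(n + 1) // 2` form builds an `n + 1` intermediate that can overflow 32-bit index math, which is why the patch switches to `n - n // 2` and widens the XPU indexing to int64.

# Standalone sketch of the packed 4-bit sizing arithmetic (illustration only).
for n in (7, 8, 2**31 - 1):
    packed_bytes = n - n // 2            # ceil(n / 2) without forming n + 1
    assert packed_bytes == (n + 1) // 2  # same value; n + 1 is the part that can overflow int32
    if n % 2:
        # floor division is one short for odd n, so a bound of n // 2
        # would mask out the last, half-filled packed byte
        assert n // 2 == packed_bytes - 1

On the dequantization side the same concern appears in reverse: bounding the store mask by the real output length (`out.numel()`) instead of `num_paired_elements * 2` keeps the doubled index out of 32-bit overflow territory and, for odd element counts, avoids storing one element past the end of `out`.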