From 05e5a710e3de02677118e4a928586ec87f95d93d Mon Sep 17 00:00:00 2001 From: asmertpc-cloud Date: Mon, 22 Dec 2025 08:36:59 +0300 Subject: [PATCH 1/4] Update cuda_specs.py fix for windows --- bitsandbytes/cuda_specs.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 71e7568a9..ce661f692 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -6,6 +6,12 @@ from typing import Optional import torch +import sys + +if (sys.platform == "win32"): + rocminfo = "hipinfo" +else: + rocminfo = "rocminfo" @dataclasses.dataclass(frozen=True) @@ -83,7 +89,7 @@ def get_rocm_gpu_arch() -> str: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) + result = subprocess.run([rocminfo], capture_output=True, text=True) match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) if match: return "gfx" + match.group(1) @@ -107,7 +113,7 @@ def get_rocm_warpsize() -> int: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) + result = subprocess.run([rocminfo], capture_output=True, text=True) match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) if match: return int(match.group(1)) From 3364ae390d52856e7cc0da895ef54543b626989d Mon Sep 17 00:00:00 2001 From: asmertpc-cloud Date: Mon, 22 Dec 2025 19:42:48 +0300 Subject: [PATCH 2/4] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index ce661f692..908570e1b 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -108,15 +108,18 @@ def get_rocm_gpu_arch() -> str: return "unknown" +# Wavefront size (or warp size) in GPU computing is the number of threads that execute +# together in lockstep on a GPU core, typically 32 or 64, depending on the architecture +# (e.g., Nvidia is 32, older AMD GCN was 64, newer AMD RDNA can be 32 or 64). def get_rocm_warpsize() -> int: """Get ROCm warp size.""" logger = logging.getLogger(__name__) try: if torch.version.hip: result = subprocess.run([rocminfo], capture_output=True, text=True) - match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) + match = re.search(r"(wavefront\s|warp)size:\s+([1-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE) if match: - return int(match.group(1)) + return int(match.group(2)) else: # default to 64 to be safe return 64 From b700ae2481a51aa3c23d8f0a503b849ebf6b88b6 Mon Sep 17 00:00:00 2001 From: asmertpc-cloud Date: Mon, 22 Dec 2025 19:52:22 +0300 Subject: [PATCH 3/4] Update cuda_specs.py oops --- bitsandbytes/cuda_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 908570e1b..5e53eaf5f 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -117,7 +117,7 @@ def get_rocm_warpsize() -> int: try: if torch.version.hip: result = subprocess.run([rocminfo], capture_output=True, text=True) - match = re.search(r"(wavefront\s|warp)size:\s+([1-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE) + match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE) if match: return int(match.group(2)) else: From 653f673b8fd659fcbecc2e35b7cb6593c9907c36 Mon Sep 17 00:00:00 2001 From: asmertpc-cloud Date: Sun, 11 Jan 2026 10:41:37 +0300 Subject: [PATCH 4/4] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 5e53eaf5f..5ee22fc91 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -90,7 +90,7 @@ def get_rocm_gpu_arch() -> str: try: if torch.version.hip: result = subprocess.run([rocminfo], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + match = re.search(r"Name:\s+gfx([a-z\d]+)", result.stdout, re.IGNORECASE) if match: return "gfx" + match.group(1) else: @@ -117,7 +117,7 @@ def get_rocm_warpsize() -> int: try: if torch.version.hip: result = subprocess.run([rocminfo], capture_output=True, text=True) - match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE) + match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE) if match: return int(match.group(2)) else: