From 83ec2fb793b06b9ff947bb0388adad73b8bff23d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Dec 2025 11:10:41 +0530 Subject: [PATCH] support device type device_maps to work with offloading. --- src/diffusers/pipelines/pipeline_utils.py | 24 ++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 392d5fb3feb4..0ed17877c75f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -109,7 +109,7 @@ for library in LOADABLE_CLASSES: LIBRARIES.append(library) -SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()] +SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"] logger = logging.get_logger(__name__) @@ -462,8 +462,7 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = self._is_pipeline_device_mapped() if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." @@ -1164,7 +1163,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t """ self._maybe_raise_error_if_group_offload_active(raise_error=True) - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = self._is_pipeline_device_mapped() if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." 
@@ -1286,7 +1285,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") self.remove_all_hooks() - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + is_pipeline_device_mapped = self._is_pipeline_device_mapped() if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." @@ -2171,6 +2170,21 @@ def _maybe_raise_error_if_group_offload_active( return True return False + def _is_pipeline_device_mapped(self): + # We support passing `device_map="cuda"`, for example. This is helpful in case + # users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable + # in limited VRAM environments because quantized models often initialize directly on the accelerator. + device_map = self.hf_device_map + is_device_type_map = False + if isinstance(device_map, str): + try: + torch.device(device_map) + is_device_type_map = True + except RuntimeError: + pass + + return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 + class StableDiffusionMixin: r"""