From 1f026ad14eceed4e7d4329f7776921da531697ca Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 12 Nov 2025 10:17:54 +0530 Subject: [PATCH 01/12] update --- tests/models/test_modeling_common.py | 24 +- tests/models/testing_utils/__init__.py | 2 + tests/models/testing_utils/attention.py | 0 tests/models/testing_utils/common.py | 304 ++++++++++++++++++ tests/models/testing_utils/compile.py | 166 ++++++++++ tests/models/testing_utils/hub.py | 110 +++++++ tests/models/testing_utils/ip_adapter.py | 0 tests/models/testing_utils/lora.py | 0 tests/models/testing_utils/offloading.py | 0 tests/models/testing_utils/single_file.py | 252 +++++++++++++++ tests/models/testing_utils/training.py | 0 .../test_models_transformer_flux_.py | 154 +++++++++ 12 files changed, 1000 insertions(+), 12 deletions(-) create mode 100644 tests/models/testing_utils/__init__.py create mode 100644 tests/models/testing_utils/attention.py create mode 100644 tests/models/testing_utils/common.py create mode 100644 tests/models/testing_utils/compile.py create mode 100644 tests/models/testing_utils/hub.py create mode 100644 tests/models/testing_utils/ip_adapter.py create mode 100644 tests/models/testing_utils/lora.py create mode 100644 tests/models/testing_utils/offloading.py create mode 100644 tests/models/testing_utils/single_file.py create mode 100644 tests/models/testing_utils/training.py create mode 100644 tests/models/transformers/test_models_transformer_flux_.py diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6f4c3d544b45..520bd8f871a4 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -317,9 +317,9 @@ def test_local_files_only_with_sharded_checkpoint(self): repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True ) - assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), ( - "Model parameters don't match!" - ) + assert all( + torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters()) + ), "Model parameters don't match!" # Remove a shard file cached_shard_file = try_to_load_from_cache( @@ -335,9 +335,9 @@ def test_local_files_only_with_sharded_checkpoint(self): # Verify error mentions the missing shard error_msg = str(context.exception) - assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, ( - f"Expected error about missing shard, got: {error_msg}" - ) + assert ( + cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg + ), f"Expected error about missing shard, got: {error_msg}" @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners") @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.") @@ -354,9 +354,9 @@ def test_one_request_upon_cached(self): ) download_requests = [r.method for r in m.request_history] - assert download_requests.count("HEAD") == 3, ( - "3 HEAD requests one for config, one for model, and one for shard index file." - ) + assert ( + download_requests.count("HEAD") == 3 + ), "3 HEAD requests one for config, one for model, and one for shard index file." 
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" with requests_mock.mock(real_http=True) as m: @@ -368,9 +368,9 @@ def test_one_request_upon_cached(self): ) cache_requests = [r.method for r in m.request_history] - assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, ( - "We should call only `model_info` to check for commit hash and knowing if shard index is present." - ) + assert ( + "HEAD" == cache_requests[0] and len(cache_requests) == 2 + ), "We should call only `model_info` to check for commit hash and knowing if shard index is present." def test_weight_overwrite(self): with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py new file mode 100644 index 000000000000..7c64e0a04bac --- /dev/null +++ b/tests/models/testing_utils/__init__.py @@ -0,0 +1,2 @@ +from .common import ModelTesterMixin +from .single_file import SingleFileTesterMixin diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py new file mode 100644 index 000000000000..ec4245af1249 --- /dev/null +++ b/tests/models/testing_utils/common.py @@ -0,0 +1,304 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from typing import Dict, List, Tuple + +import pytest +import torch + +from ...testing_utils import torch_device + + +class ModelTesterMixin: + """ + Base mixin class for model testing with common test methods. + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states") + - base_precision: Default tolerance for floating point comparisons (default: 1e-3) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + model_class = None + base_precision = 1e-3 + + def get_init_dict(self): + raise NotImplementedError("get_init_dict must be implemented by subclasses. ") + + def get_dummy_inputs(self): + raise NotImplementedError( + "get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict." 
+ ) + + def check_device_map_is_respected(self, model, device_map): + """Helper method to check if device map is correctly applied to model parameters.""" + for param_name, param in model.named_parameters(): + # Find device in device_map + while len(param_name) > 0 and param_name not in device_map: + param_name = ".".join(param_name.split(".")[:-1]) + if param_name not in device_map: + raise ValueError("device map is incomplete, it does not contain any device for `param_name`.") + + param_device = device_map[param_name] + if param_device in ["cpu", "disk"]: + assert param.device == torch.device( + "meta" + ), f"Expected device 'meta' for {param_name}, got {param.device}" + else: + assert param.device == torch.device( + param_device + ), f"Expected device {param_device} for {param_name}, got {param.device}" + + def test_from_save_pretrained(self, expected_max_diff=5e-5): + """Test that model can be saved and loaded with save_pretrained/from_pretrained.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**self.get_dummy_inputs()) + + if isinstance(image, dict): + image = image.to_tuple()[0] + + new_image = new_model(**self.get_dummy_inputs()) + + if isinstance(new_image, dict): + new_image = new_image.to_tuple()[0] + + max_diff = (image - new_image).abs().max().item() + assert ( + max_diff <= expected_max_diff + ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + + def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): + """Test save_pretrained/from_pretrained with variant parameter.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, variant="fp16") + new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") + + # non-variant cannot be loaded + with pytest.raises(OSError) as exc_info: + self.model_class.from_pretrained(tmpdirname) + + # make sure that error message states what keys are missing + assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) + + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**self.get_dummy_inputs()) + if isinstance(image, dict): + image = image.to_tuple()[0] + + new_image = new_model(**self.get_dummy_inputs()) + + if isinstance(new_image, dict): + new_image = new_image.to_tuple()[0] + + max_diff = (image - new_image).abs().max().item() + assert ( + max_diff <= expected_max_diff + ), f"Models give different forward passes. 
Max diff: {max_diff}, expected: {expected_max_diff}" + + def test_from_save_pretrained_dtype(self): + """Test save_pretrained/from_pretrained preserves dtype correctly.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + if torch_device == "mps" and dtype == torch.bfloat16: + continue + with tempfile.TemporaryDirectory() as tmpdirname: + model.to(dtype) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) + assert new_model.dtype == dtype + if ( + hasattr(self.model_class, "_keep_in_fp32_modules") + and self.model_class._keep_in_fp32_modules is None + ): + new_model = self.model_class.from_pretrained( + tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype + ) + assert new_model.dtype == dtype + + def test_determinism(self, expected_max_diff=1e-5): + """Test that model outputs are deterministic across multiple forward passes.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + first = model(**self.get_dummy_inputs()) + if isinstance(first, dict): + first = first.to_tuple()[0] + + second = model(**self.get_dummy_inputs()) + if isinstance(second, dict): + second = second.to_tuple()[0] + + # Remove NaN values and compute max difference + first_flat = first.flatten() + second_flat = second.flatten() + + # Filter out NaN values + mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat)) + first_filtered = first_flat[mask] + second_filtered = second_flat[mask] + + max_diff = torch.abs(first_filtered - second_filtered).max().item() + assert ( + max_diff <= expected_max_diff + ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}" + + def test_output(self, expected_output_shape=None): + """Test that model produces output with expected shape.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + inputs_dict = self.get_dummy_inputs() + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + assert output is not None, "Model output is None" + assert ( + output.shape == expected_output_shape + ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}" + + def test_model_from_pretrained(self): + """Test that model loaded from pretrained matches original model.""" + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + # test if the model can be loaded from the config + # and has all the expected shape + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, safe_serialization=False) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + new_model.eval() + + # check if all parameters shape are the same + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + assert ( + param_1.shape == param_2.shape + ), f"Parameter shape mismatch for {param_name}. 
Original: {param_1.shape}, loaded: {param_2.shape}" + + with torch.no_grad(): + output_1 = model(**self.get_dummy_inputs()) + + if isinstance(output_1, dict): + output_1 = output_1.to_tuple()[0] + + output_2 = new_model(**self.get_dummy_inputs()) + + if isinstance(output_2, dict): + output_2 = output_2.to_tuple()[0] + + assert ( + output_1.shape == output_2.shape + ), f"Output shape mismatch. Original: {output_1.shape}, loaded: {output_2.shape}" + + def test_outputs_equivalence(self): + """Test that dict and tuple outputs are equivalent.""" + + def set_nan_tensor_to_zero(t): + # Temporary fallback until `aten::_index_put_impl_` is implemented in mps + # Track progress in https://github.com/pytorch/pytorch/issues/77764 + device = t.device + if device.type == "mps": + t = t.to("cpu") + t[t != t] = 0 + return t.to(device) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + assert torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), ( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." + ) + + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs_dict = model(**self.get_dummy_inputs()) + outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False) + + recursive_check(outputs_tuple, outputs_dict) + + def test_model_config_to_json_string(self): + """Test model config can be serialized to JSON string.""" + model = self.model_class(**self.get_init_dict()) + + json_string = model.config.to_json_string() + assert isinstance(json_string, str), "Config to_json_string should return a string" + assert len(json_string) > 0, "JSON string should not be empty" + + def test_keep_in_fp32_modules(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16 + Also ensures if inference works. + """ + if not hasattr(self.model_class, "_keep_in_fp32_modules"): + pytest.skip("Model does not have _keep_in_fp32_modules") + + fp32_modules = self.model_class._keep_in_fp32_modules + + for torch_dtype in [torch.bfloat16, torch.float16]: + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, torch_dtype=torch_dtype).to( + torch_device + ) + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + assert param.data == torch.float32 + else: + assert param.data == torch_dtype diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py new file mode 100644 index 000000000000..2c083176c585 --- /dev/null +++ b/tests/models/testing_utils/compile.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import tempfile + +import pytest +import torch + +from ...testing_utils import ( + backend_empty_cache, + is_torch_compile, + require_accelerator, + require_torch_version_greater, + torch_device, +) + + +@is_torch_compile +@require_accelerator +@require_torch_version_greater("2.7.1") +class TorchCompileTesterMixin: + """ + Mixin class for testing torch.compile functionality on models. + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + different_shapes_for_compilation = None + + def setup_method(self): + """Setup before each test method.""" + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + """Cleanup after each test method.""" + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def test_torch_compile_recompilation_and_graph_break(self): + """Test that model compiles without graph breaks and doesn't recompile unnecessarily.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True) + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + def test_torch_compile_repeated_blocks(self): + """Test compilation of repeated blocks if model supports it.""" + if self.model_class._repeated_blocks is None: + pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.") + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model.compile_repeated_blocks(fullgraph=True) + + recompile_limit = 1 + if self.model_class.__name__ == "UNet2DConditionModel": + recompile_limit = 2 + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(recompile_limit=recompile_limit), + torch.no_grad(), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + def test_compile_with_group_offloading(self): + """Test that compilation works with group offloading enabled.""" + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + + torch._dynamo.config.cache_size_limit = 10000 + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.eval() + + group_offload_kwargs = { + "onload_device": torch_device, + "offload_device": "cpu", + "offload_type": "block_level", + 
"num_blocks_per_group": 1, + "use_stream": True, + "non_blocking": True, + } + model.enable_group_offload(**group_offload_kwargs) + model.compile() + + with torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + def test_compile_on_different_shapes(self): + """Test dynamic compilation on different input shapes.""" + if self.different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + torch.fx.experimental._config.use_duck_shape = False + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True, dynamic=True) + + for height, width in self.different_shapes_for_compilation: + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + inputs_dict = self.get_dummy_inputs(height=height, width=width) + _ = model(**inputs_dict) + + def test_compile_works_with_aot(self): + """Test that model works with ahead-of-time compilation and packaging.""" + from torch._inductor.package import load_package + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdir: + package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") + _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) + assert os.path.exists(package_path), f"Package file not created at {package_path}" + loaded_binary = load_package(package_path, run_single_threaded=True) + + model.forward = loaded_binary + + with torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py new file mode 100644 index 000000000000..cbaded9ffffc --- /dev/null +++ b/tests/models/testing_utils/hub.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import uuid + +import pytest +import torch +from huggingface_hub import ModelCard, delete_repo +from huggingface_hub.utils import is_jinja_available + +from ...others.test_utils import TOKEN, USER, is_staging_test + + +@is_staging_test +class ModelPushToHubTesterMixin: + """ + Mixin class for testing push_to_hub functionality on models. 
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + """ + + identifier = uuid.uuid4() + repo_id = f"test-model-{identifier}" + org_repo_id = f"valid_org/{repo_id}-org" + + def test_push_to_hub(self): + """Test pushing model to hub and loading it back.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.push_to_hub(self.repo_id, token=TOKEN) + + new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained" + + # Reset repo + delete_repo(token=TOKEN, repo_id=self.repo_id) + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN) + + new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal( + p1, p2 + ), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained" + + # Reset repo + delete_repo(self.repo_id, token=TOKEN) + + def test_push_to_hub_in_organization(self): + """Test pushing model to hub in organization namespace.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.push_to_hub(self.org_repo_id, token=TOKEN) + + new_model = self.model_class.from_pretrained(self.org_repo_id) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained" + + # Reset repo + delete_repo(token=TOKEN, repo_id=self.org_repo_id) + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id) + + new_model = self.model_class.from_pretrained(self.org_repo_id) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + assert torch.equal( + p1, p2 + ), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained" + + # Reset repo + delete_repo(self.org_repo_id, token=TOKEN) + + def test_push_to_hub_library_name(self): + """Test that library_name in model card is set to 'diffusers'.""" + if not is_jinja_available(): + pytest.skip("Model card tests cannot be performed without Jinja installed.") + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.push_to_hub(self.repo_id, token=TOKEN) + + model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data + assert ( + model_card.library_name == "diffusers" + ), f"Expected library_name 'diffusers', got {model_card.library_name}" + + # Reset repo + delete_repo(self.repo_id, token=TOKEN) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/testing_utils/offloading.py b/tests/models/testing_utils/offloading.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py new file mode 100644 index 
000000000000..561dc3c56703 --- /dev/null +++ b/tests/models/testing_utils/single_file.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile + +import torch +from huggingface_hub import hf_hub_download, snapshot_download + +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name + +from ...testing_utils import ( + backend_empty_cache, + nightly, + require_torch_accelerator, + torch_device, +) + + +def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir): + """Download a single file checkpoint from the Hub to a temporary directory.""" + path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir) + return path + + +def download_diffusers_config(pretrained_model_name_or_path, tmpdir): + """Download diffusers config files (excluding weights) from a repository.""" + path = snapshot_download( + pretrained_model_name_or_path, + ignore_patterns=[ + "**/*.ckpt", + "*.ckpt", + "**/*.bin", + "*.bin", + "**/*.pt", + "*.pt", + "**/*.safetensors", + "*.safetensors", + ], + allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"], + local_dir=tmpdir, + ) + return path + + +@nightly +@require_torch_accelerator +@is_single_file +class SingleFileTesterMixin: + """ + Mixin class for testing single file loading for models. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - ckpt_path: Path or Hub path to the single file checkpoint + - subfolder: (Optional) Subfolder within the repo + - torch_dtype: (Optional) torch dtype to use for testing + """ + + pretrained_model_name_or_path = None + ckpt_path = None + + def setup_method(self): + """Setup before each test method.""" + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + """Cleanup after each test method.""" + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_model_config(self): + """Test that config matches between pretrained and single file loading.""" + pretrained_kwargs = {} + single_file_kwargs = {} + + pretrained_kwargs["device"] = torch_device + single_file_kwargs["device"] = torch_device + + if hasattr(self, "subfolder") and self.subfolder: + pretrained_kwargs["subfolder"] = self.subfolder + + if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading: " + f"pretrained={model.config[param_name]}, single_file={param_value}" + ) + + def test_single_file_model_parameters(self): + """Test that parameters match between pretrained and single file loading.""" + pretrained_kwargs = {} + single_file_kwargs = {} + + pretrained_kwargs["device"] = torch_device + single_file_kwargs["device"] = torch_device + + if hasattr(self, "subfolder") and self.subfolder: + pretrained_kwargs["subfolder"] = self.subfolder + + if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + + state_dict = model.state_dict() + state_dict_single_file = model_single_file.state_dict() + + assert set(state_dict.keys()) == set(state_dict_single_file.keys()), ( + "Model parameters keys differ between pretrained and single file loading. " + f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. 
" + f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}" + ) + + for key in state_dict.keys(): + param = state_dict[key] + param_single_file = state_dict_single_file[key] + + assert param.shape == param_single_file.shape, ( + f"Parameter shape mismatch for {key}: " + f"pretrained {param.shape} vs single file {param_single_file.shape}" + ) + + assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), ( + f"Parameter values differ for {key}: " + f"max difference {torch.max(torch.abs(param - param_single_file)).item()}" + ) + + def test_single_file_loading_local_files_only(self): + """Test single file loading with local_files_only=True.""" + single_file_kwargs = {} + + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + with tempfile.TemporaryDirectory() as tmpdir: + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir) + + model_single_file = self.model_class.from_single_file( + local_ckpt_path, local_files_only=True, **single_file_kwargs + ) + + assert model_single_file is not None, "Failed to load model with local_files_only=True" + + def test_single_file_loading_with_diffusers_config(self): + """Test single file loading with diffusers config.""" + single_file_kwargs = {} + + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + # Load with config parameter + model_single_file = self.model_class.from_single_file( + self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs + ) + + # Load pretrained for comparison + pretrained_kwargs = {} + if hasattr(self, "subfolder") and self.subfolder: + pretrained_kwargs["subfolder"] = self.subfolder + if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + + # Compare configs + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + + def test_single_file_loading_with_diffusers_config_local_files_only(self): + """Test single file loading with diffusers config and local_files_only=True.""" + single_file_kwargs = {} + + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + with tempfile.TemporaryDirectory() as tmpdir: + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir) + local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir) + + model_single_file = self.model_class.from_single_file( + local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs + ) + + assert model_single_file is not None, "Failed to load model with config and local_files_only=True" + + def test_single_file_loading_dtype(self): + """Test single file loading with different dtypes.""" + for dtype in [torch.float32, torch.float16]: + if 
torch_device == "mps" and dtype == torch.bfloat16: + continue + + model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype) + + assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}" + + # Cleanup + del model_single_file + gc.collect() + backend_empty_cache(torch_device) + + def test_checkpoint_variant_loading(self): + """Test loading checkpoints with alternate keys/variants if provided.""" + if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths: + return + + for ckpt_path in self.alternate_ckpt_paths: + backend_empty_cache(torch_device) + + single_file_kwargs = {} + if hasattr(self, "torch_dtype") and self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs) + + assert model is not None, f"Failed to load checkpoint from {ckpt_path}" + + del model + gc.collect() + backend_empty_cache(torch_device) diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/transformers/test_models_transformer_flux_.py b/tests/models/transformers/test_models_transformer_flux_.py new file mode 100644 index 000000000000..a67218548cbc --- /dev/null +++ b/tests/models/transformers/test_models_transformer_flux_.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from diffusers import FluxTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin +from ..testing_utils.common import ModelTesterMixin +from ..testing_utils.compile import TorchCompileTesterMixin +from ..testing_utils.single_file import SingleFileTesterMixin + + +enable_full_determinism() + + +class FluxTransformerTesterConfig: + model_class = FluxTransformer2DModel + + def get_init_dict(self): + """Return Flux model initialization arguments.""" + return { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + + def get_dummy_inputs(self): + batch_size = 1 + height = width = 4 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 8 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), + "pooled_projections": randn_tensor((batch_size, embedding_dim)), + "img_ids": randn_tensor((height * width, num_image_channels)), + "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + +class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): + def test_deprecated_inputs_img_txt_ids_3d(self): + """Test that deprecated 3D img_ids and txt_ids still work.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output_1 = model(**inputs_dict).to_tuple()[0] + + # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) + text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) + image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) + + assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" + assert image_ids_3d.ndim == 3, "image_ids_3d should be a 3d tensor" + + inputs_dict["txt_ids"] = text_ids_3d + inputs_dict["img_ids"] = image_ids_3d + + with torch.no_grad(): + output_2 = model(**inputs_dict).to_tuple()[0] + + assert output_1.shape == output_2.shape + assert torch.allclose(output_1, output_2, atol=1e-5), ( + "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) " + "does not match the output with 2d inputs" + ) + + +class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): + ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] + pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev" + subfolder = "transformer" + pass + + +class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height=4, width=4): + """Override to support dynamic height/width for compilation tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 8 + + return { + "hidden_states": 
randn_tensor((batch_size, height * width, num_latent_channels)), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), + "pooled_projections": randn_tensor((batch_size, embedding_dim)), + "img_ids": randn_tensor((height * width, num_image_channels)), + "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + +class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height=4, width=4): + """Override to support dynamic height/width for LoRA hotswap tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 48 + embedding_dim = 32 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), + "pooled_projections": randn_tensor((batch_size, embedding_dim)), + "img_ids": randn_tensor((height * width, num_image_channels)), + "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } From bffa3a9754a11cfed0a41184a8dee0933cdbccf8 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 14 Nov 2025 15:48:19 +0530 Subject: [PATCH 02/12] update --- tests/conftest.py | 14 + tests/models/testing_utils/__init__.py | 35 + tests/models/testing_utils/attention.py | 180 ++++ tests/models/testing_utils/common.py | 376 ++++++-- tests/models/testing_utils/compile.py | 10 +- tests/models/testing_utils/hub.py | 1 - tests/models/testing_utils/ip_adapter.py | 205 +++++ tests/models/testing_utils/lora.py | 220 +++++ tests/models/testing_utils/memory.py | 443 ++++++++++ tests/models/testing_utils/offloading.py | 0 tests/models/testing_utils/quantization.py | 833 ++++++++++++++++++ tests/models/testing_utils/single_file.py | 13 +- tests/models/testing_utils/training.py | 224 +++++ .../test_models_transformer_flux_.py | 190 +++- tests/quantization/gguf/test_gguf.py | 6 +- tests/testing_utils.py | 123 ++- 16 files changed, 2752 insertions(+), 121 deletions(-) create mode 100644 tests/models/testing_utils/memory.py delete mode 100644 tests/models/testing_utils/offloading.py create mode 100644 tests/models/testing_utils/quantization.py diff --git a/tests/conftest.py b/tests/conftest.py index fd76d1c84ee7..3744de27f3b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,20 @@ def pytest_configure(config): config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources") + config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality") + config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality") + config.addinivalue_line("markers", "training: marks tests for training functionality") + config.addinivalue_line("markers", "attention: marks tests for attention processor functionality") + config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality") + config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality") + config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality") + config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality") + config.addinivalue_line("markers", "single_file: marks tests for single file 
checkpoint loading") + config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality") + config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality") + config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality") + config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality") + config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality") def pytest_addoption(parser): diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index 7c64e0a04bac..7955982ca91c 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,2 +1,37 @@ +from .attention import AttentionTesterMixin from .common import ModelTesterMixin +from .compile import TorchCompileTesterMixin +from .ip_adapter import IPAdapterTesterMixin +from .lora import LoraTesterMixin +from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin +from .quantization import ( + BitsAndBytesTesterMixin, + GGUFTesterMixin, + ModelOptTesterMixin, + QuantizationTesterMixin, + QuantoTesterMixin, + TorchAoTesterMixin, +) from .single_file import SingleFileTesterMixin +from .training import TrainingTesterMixin + + +__all__ = [ + "AttentionTesterMixin", + "BitsAndBytesTesterMixin", + "CPUOffloadTesterMixin", + "GGUFTesterMixin", + "GroupOffloadTesterMixin", + "IPAdapterTesterMixin", + "LayerwiseCastingTesterMixin", + "LoraTesterMixin", + "MemoryTesterMixin", + "ModelOptTesterMixin", + "ModelTesterMixin", + "QuantizationTesterMixin", + "QuantoTesterMixin", + "SingleFileTesterMixin", + "TorchAoTesterMixin", + "TorchCompileTesterMixin", + "TrainingTesterMixin", +] diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index e69de29bb2d1..be88fd309b1f 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -0,0 +1,180 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers.models.attention import AttentionModuleMixin +from diffusers.models.attention_processor import ( + AttnProcessor, +) + +from ...testing_utils import is_attention, require_accelerator, torch_device + + +@is_attention +@require_accelerator +class AttentionTesterMixin: + """ + Mixin class for testing attention processor and module functionality on models. + + Tests functionality from AttentionModuleMixin including: + - Attention processor management (set/get) + - QKV projection fusion/unfusion + - Attention backends (XFormers, NPU, etc.) 
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - base_precision: Tolerance for floating point comparisons (default: 1e-3) + - uses_custom_attn_processor: Whether model uses custom attention processors (default: False) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: attention + Use `pytest -m "not attention"` to skip these tests + """ + + base_precision = 1e-3 + + def test_fuse_unfuse_qkv_projections(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + if not hasattr(model, "fuse_qkv_projections"): + pytest.skip("Model does not support QKV projection fusion.") + + # Get output before fusion + with torch.no_grad(): + output_before_fusion = model(**inputs_dict) + if isinstance(output_before_fusion, dict): + output_before_fusion = output_before_fusion.to_tuple()[0] + + # Fuse projections + model.fuse_qkv_projections() + + # Verify fusion occurred by checking for fused attributes + has_fused_projections = False + for module in model.modules(): + if isinstance(module, AttentionModuleMixin): + if hasattr(module, "to_qkv") or hasattr(module, "to_kv"): + has_fused_projections = True + assert module.fused_projections, "fused_projections flag should be True" + break + + if has_fused_projections: + # Get output after fusion + with torch.no_grad(): + output_after_fusion = model(**inputs_dict) + if isinstance(output_after_fusion, dict): + output_after_fusion = output_after_fusion.to_tuple()[0] + + # Verify outputs match + assert torch.allclose( + output_before_fusion, output_after_fusion, atol=self.base_precision + ), "Output should not change after fusing projections" + + # Unfuse projections + model.unfuse_qkv_projections() + + # Verify unfusion occurred + for module in model.modules(): + if isinstance(module, AttentionModuleMixin): + assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing" + assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing" + assert not module.fused_projections, "fused_projections flag should be False" + + # Get output after unfusion + with torch.no_grad(): + output_after_unfusion = model(**inputs_dict) + if isinstance(output_after_unfusion, dict): + output_after_unfusion = output_after_unfusion.to_tuple()[0] + + # Verify outputs still match + assert torch.allclose( + output_before_fusion, output_after_unfusion, atol=self.base_precision + ), "Output should match original after unfusing projections" + + def test_get_set_processor(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + # Check if model has attention processors + if not hasattr(model, "attn_processors"): + pytest.skip("Model does not have attention processors.") + + # Test getting processors + processors = model.attn_processors + assert isinstance(processors, dict), "attn_processors should return a dict" + assert len(processors) > 0, "Model should have at least one attention processor" + + # Test that all processors can be retrieved via get_processor + for module in model.modules(): + if isinstance(module, AttentionModuleMixin): + processor = module.get_processor() + assert processor is not None, "get_processor should return a processor" + + # Test setting a new processor + 
new_processor = AttnProcessor() + module.set_processor(new_processor) + retrieved_processor = module.get_processor() + assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set" + + def test_attention_processor_dict(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + if not hasattr(model, "set_attn_processor"): + pytest.skip("Model does not support setting attention processors.") + + # Get current processors + current_processors = model.attn_processors + + # Create a dict of new processors + new_processors = {key: AttnProcessor() for key in current_processors.keys()} + + # Set processors using dict + model.set_attn_processor(new_processors) + + # Verify all processors were set + updated_processors = model.attn_processors + for key in current_processors.keys(): + assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor" + + def test_attention_processor_count_mismatch_raises_error(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + if not hasattr(model, "set_attn_processor"): + pytest.skip("Model does not support setting attention processors.") + + # Get current processors + current_processors = model.attn_processors + + # Create a dict with wrong number of processors + wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()} + + # Verify error is raised + with pytest.raises(ValueError) as exc_info: + model.set_attn_processor(wrong_processors) + + assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index ec4245af1249..e4697f6200f6 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -13,15 +13,94 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import os import tempfile from typing import Dict, List, Tuple import pytest import torch +from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size + +from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant +from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator from ...testing_utils import torch_device +def compute_module_persistent_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +): + """ + Compute the size of each submodule of a given model (parameters + persistent buffers). 
+ """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + module_list = named_persistent_module_tensors(model, recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + +def calculate_expected_num_shards(index_map_path): + """ + Calculate expected number of shards from index file. + + Args: + index_map_path: Path to the sharded checkpoint index file + + Returns: + int: Expected number of shards + """ + with open(index_map_path) as f: + weight_map_dict = json.load(f)["weight_map"] + first_key = list(weight_map_dict.keys())[0] + weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors + expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) + return expected_num_shards + + +def check_device_map_is_respected(model, device_map): + for param_name, param in model.named_parameters(): + # Find device in device_map + while len(param_name) > 0 and param_name not in device_map: + param_name = ".".join(param_name.split(".")[:-1]) + if param_name not in device_map: + raise ValueError("device map is incomplete, it does not contain any device for `param_name`.") + + param_device = device_map[param_name] + if param_device in ["cpu", "disk"]: + assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}" + else: + assert param.device == torch.device( + param_device + ), f"Expected device {param_device} for {param_name}, got {param.device}" + + class ModelTesterMixin: """ Base mixin class for model testing with common test methods. @@ -38,6 +117,7 @@ class ModelTesterMixin: model_class = None base_precision = 1e-3 + model_split_percents = [0.5, 0.7] def get_init_dict(self): raise NotImplementedError("get_init_dict must be implemented by subclasses. ") @@ -47,27 +127,7 @@ def get_dummy_inputs(self): "get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict." 
) - def check_device_map_is_respected(self, model, device_map): - """Helper method to check if device map is correctly applied to model parameters.""" - for param_name, param in model.named_parameters(): - # Find device in device_map - while len(param_name) > 0 and param_name not in device_map: - param_name = ".".join(param_name.split(".")[:-1]) - if param_name not in device_map: - raise ValueError("device map is incomplete, it does not contain any device for `param_name`.") - - param_device = device_map[param_name] - if param_device in ["cpu", "disk"]: - assert param.device == torch.device( - "meta" - ), f"Expected device 'meta' for {param_name}, got {param.device}" - else: - assert param.device == torch.device( - param_device - ), f"Expected device {param_device} for {param_name}, got {param.device}" - def test_from_save_pretrained(self, expected_max_diff=5e-5): - """Test that model can be saved and loaded with save_pretrained/from_pretrained.""" model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -77,6 +137,14 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) + # check if all parameters shape are the same + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + assert ( + param_1.shape == param_2.shape + ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" + with torch.no_grad(): image = model(**self.get_dummy_inputs()) @@ -94,7 +162,6 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): - """Test save_pretrained/from_pretrained with variant parameter.""" model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -128,7 +195,6 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" def test_from_save_pretrained_dtype(self): - """Test save_pretrained/from_pretrained preserves dtype correctly.""" model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -145,13 +211,13 @@ def test_from_save_pretrained_dtype(self): hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None ): + # When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None new_model = self.model_class.from_pretrained( tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype ) assert new_model.dtype == dtype def test_determinism(self, expected_max_diff=1e-5): - """Test that model outputs are deterministic across multiple forward passes.""" model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -180,7 +246,6 @@ def test_determinism(self, expected_max_diff=1e-5): ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}" def test_output(self, expected_output_shape=None): - """Test that model produces output with expected shape.""" model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -197,46 +262,7 @@ def test_output(self, expected_output_shape=None): output.shape == expected_output_shape ), f"Output shape does not match expected. 
Expected {expected_output_shape}, got {output.shape}" - def test_model_from_pretrained(self): - """Test that model loaded from pretrained matches original model.""" - model = self.model_class(**self.get_init_dict()) - model.to(torch_device) - model.eval() - - # test if the model can be loaded from the config - # and has all the expected shape - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, safe_serialization=False) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - new_model.eval() - - # check if all parameters shape are the same - for param_name in model.state_dict().keys(): - param_1 = model.state_dict()[param_name] - param_2 = new_model.state_dict()[param_name] - assert ( - param_1.shape == param_2.shape - ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" - - with torch.no_grad(): - output_1 = model(**self.get_dummy_inputs()) - - if isinstance(output_1, dict): - output_1 = output_1.to_tuple()[0] - - output_2 = new_model(**self.get_dummy_inputs()) - - if isinstance(output_2, dict): - output_2 = output_2.to_tuple()[0] - - assert ( - output_1.shape == output_2.shape - ), f"Output shape mismatch. Original: {output_1.shape}, loaded: {output_2.shape}" - def test_outputs_equivalence(self): - """Test that dict and tuple outputs are equivalent.""" - def set_nan_tensor_to_zero(t): # Temporary fallback until `aten::_index_put_impl_` is implemented in mps # Track progress in https://github.com/pytorch/pytorch/issues/77764 @@ -276,29 +302,213 @@ def recursive_check(tuple_object, dict_object): recursive_check(outputs_tuple, outputs_dict) def test_model_config_to_json_string(self): - """Test model config can be serialized to JSON string.""" model = self.model_class(**self.get_init_dict()) json_string = model.config.to_json_string() assert isinstance(json_string, str), "Config to_json_string should return a string" assert len(json_string) > 0, "JSON string should not be empty" - def test_keep_in_fp32_modules(self): - r""" - A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16 - Also ensures if inference works. 
- """ - if not hasattr(self.model_class, "_keep_in_fp32_modules"): - pytest.skip("Model does not have _keep_in_fp32_modules") - - fp32_modules = self.model_class._keep_in_fp32_modules - - for torch_dtype in [torch.bfloat16, torch.float16]: - model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, torch_dtype=torch_dtype).to( - torch_device - ) - for name, param in model.named_parameters(): - if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): - assert param.data == torch.float32 - else: - assert param.data == torch_dtype + @require_accelerator + @pytest.mark.skipif(torch_device not in ["cuda", "xpu"]) + def test_from_save_pretrained_float16_bfloat16(self): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + fp32_modules = model._keep_in_fp32_modules + + with tempfile.TemporaryDirectory() as tmp_dir: + for torch_dtype in [torch.bfloat16, torch.float16]: + model.to(torch_dtype).save_pretrained(tmp_dir) + model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device) + + for name, param in model_loaded.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + assert param.data.dtype == torch.float32 + else: + assert param.data.dtype == torch_dtype + + with torch.no_grad(): + output = model(**get_dummy_inputs()) + output_loaded = model_loaded(**get_dummy_inputs()) + + assert torch.allclose( + output, output_loaded, atol=1e-4 + ), f"Loaded model output differs for {torch_dtype}" + + @require_accelerator + def test_sharded_checkpoints(self): + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + base_output = model(**inputs_dict) + + model_size = compute_module_persistent_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" + + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) + assert ( + actual_num_shards == expected_num_shards + ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + + new_model = self.model_class.from_pretrained(tmp_dir).eval() + new_model = new_model.to(torch_device) + + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new) + + assert torch.allclose( + base_output[0], new_output[0], atol=1e-5 + ), "Output should match after sharded save/load" + + @require_accelerator + def test_sharded_checkpoints_with_variant(self): + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + base_output = model(**inputs_dict) + + model_size = compute_module_persistent_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small + variant = "fp16" + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", 
variant=variant) + + index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + assert os.path.exists( + os.path.join(tmp_dir, index_filename) + ), f"Variant index file {index_filename} should exist" + + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename)) + actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) + assert ( + actual_num_shards == expected_num_shards + ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + + new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval() + new_model = new_model.to(torch_device) + + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new) + + assert torch.allclose( + base_output[0], new_output[0], atol=1e-5 + ), "Output should match after variant sharded save/load" + + @require_accelerator + def test_sharded_checkpoints_with_parallel_loading(self): + import time + + from diffusers.utils import constants + + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + base_output = model(**inputs_dict) + + model_size = compute_module_persistent_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small + + # Save original values to restore after test + original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING + original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None) + + try: + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" + + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) + assert ( + actual_num_shards == expected_num_shards + ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + + # Load without parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = False + start_time = time.time() + model_sequential = self.model_class.from_pretrained(tmp_dir).eval() + sequential_load_time = time.time() - start_time + model_sequential = model_sequential.to(torch_device) + + torch.manual_seed(0) + + # Load with parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = True + constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 + + start_time = time.time() + model_parallel = self.model_class.from_pretrained(tmp_dir).eval() + parallel_load_time = time.time() - start_time + model_parallel = model_parallel.to(torch_device) + + torch.manual_seed(0) + inputs_dict_parallel = self.get_dummy_inputs() + output_parallel = model_parallel(**inputs_dict_parallel) + + assert torch.allclose( + base_output[0], output_parallel[0], atol=1e-5 + ), "Output should match with parallel loading" + + # Verify parallel loading is faster or at least not significantly slower + # For small test models, the difference might be negligible or even slightly slower due to overhead + # so we just check that parallel loading completed successfully and outputs match + assert ( + parallel_load_time < sequential_load_time + ), f"Parallel loading took {parallel_load_time:.4f}s, sequential took 
{sequential_load_time:.4f}s" + finally: + # Restore original values + constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading + if original_parallel_workers is not None: + constants.HF_PARALLEL_WORKERS = original_parallel_workers + + @require_torch_multi_accelerator + def test_model_parallelism(self): + if self.model_class._no_split_modules is None: + pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will be on GPU 0 and GPU 1 + assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs" + + check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert torch.allclose( + base_output[0], new_output[0], atol=1e-5 + ), "Output should match with model parallelism" diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index 2c083176c585..8ff4c097b404 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -43,24 +43,24 @@ class TorchCompileTesterMixin: Expected methods to be implemented by subclasses: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: compile + Use `pytest -m "not compile"` to skip these tests """ different_shapes_for_compilation = None def setup_method(self): - """Setup before each test method.""" torch.compiler.reset() gc.collect() backend_empty_cache(torch_device) def teardown_method(self): - """Cleanup after each test method.""" torch.compiler.reset() gc.collect() backend_empty_cache(torch_device) def test_torch_compile_recompilation_and_graph_break(self): - """Test that model compiles without graph breaks and doesn't recompile unnecessarily.""" init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -77,7 +77,6 @@ def test_torch_compile_recompilation_and_graph_break(self): _ = model(**inputs_dict) def test_torch_compile_repeated_blocks(self): - """Test compilation of repeated blocks if model supports it.""" if self.model_class._repeated_blocks is None: pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.") @@ -101,7 +100,6 @@ def test_torch_compile_repeated_blocks(self): _ = model(**inputs_dict) def test_compile_with_group_offloading(self): - """Test that compilation works with group offloading enabled.""" if not self.model_class._supports_group_offloading: pytest.skip("Model does not support group offloading.") @@ -128,7 +126,6 @@ def test_compile_with_group_offloading(self): _ = model(**inputs_dict) def test_compile_on_different_shapes(self): - """Test dynamic compilation on different input shapes.""" if self.different_shapes_for_compilation is None: pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for 
{self.__class__.__name__}.") torch.fx.experimental._config.use_duck_shape = False @@ -144,7 +141,6 @@ def test_compile_on_different_shapes(self): _ = model(**inputs_dict) def test_compile_works_with_aot(self): - """Test that model works with ahead-of-time compilation and packaging.""" from torch._inductor.package import load_package init_dict = self.get_init_dict() diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py index cbaded9ffffc..e20c3ab1630e 100644 --- a/tests/models/testing_utils/hub.py +++ b/tests/models/testing_utils/hub.py @@ -18,7 +18,6 @@ import pytest import torch -from huggingface_hub import ModelCard, delete_repo from huggingface_hub.utils import is_jinja_available from ...others.test_utils import TOKEN, USER, is_staging_test diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py index e69de29bb2d1..be079df0614d 100644 --- a/tests/models/testing_utils/ip_adapter.py +++ b/tests/models/testing_utils/ip_adapter.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import torch + +from diffusers.models.attention_processor import IPAdapterAttnProcessor + +from ...testing_utils import is_ip_adapter, torch_device + + +def create_ip_adapter_state_dict(model): + """ + Create a dummy IP Adapter state dict for testing. + + Args: + model: The model to create IP adapter weights for + + Returns: + dict: IP adapter state dict with to_k_ip and to_v_ip weights + """ + ip_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + # Skip self-attention processors + cross_attention_dim = getattr(model.config, "cross_attention_dim", None) + if cross_attention_dim is None: + continue + + # Get hidden size based on model architecture + hidden_size = getattr(model.config, "hidden_size", cross_attention_dim) + + # Create IP adapter processor to get state dict structure + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).state_dict() + + ip_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + } + ) + key_id += 2 + + return {"ip_adapter": ip_state_dict} + + +def check_if_ip_adapter_correctly_set(model) -> bool: + """ + Check if IP Adapter processors are correctly set in the model. + + Args: + model: The model to check + + Returns: + bool: True if IP Adapter is correctly set, False otherwise + """ + for module in model.attn_processors.values(): + if isinstance(module, IPAdapterAttnProcessor): + return True + return False + + +@is_ip_adapter +class IPAdapterTesterMixin: + """ + Mixin class for testing IP Adapter functionality on models. 
+
+    Expected class attributes to be set by subclasses:
+    - model_class: The model class to test
+
+    Expected methods to be implemented by subclasses:
+    - get_init_dict(): Returns dict of arguments to initialize the model
+    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+
+    Pytest mark: ip_adapter
+    Use `pytest -m "not ip_adapter"` to skip these tests
+    """
+
+    def create_ip_adapter_state_dict(self, model):
+        raise NotImplementedError("Child class must implement this method to create an IP Adapter state dict.")
+
+    def test_load_ip_adapter(self):
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        torch.manual_seed(0)
+        output_no_adapter = model(**inputs_dict, return_dict=False)[0]
+
+        # Create dummy IP adapter state dict
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
+
+        # Load IP adapter
+        model._load_ip_adapter_weights([ip_adapter_state_dict])
+        assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
+
+        torch.manual_seed(0)
+        # Create dummy image embeds for IP adapter
+        cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
+        image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
+        inputs_dict_with_adapter = inputs_dict.copy()
+        inputs_dict_with_adapter["image_embeds"] = image_embeds
+
+        outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
+
+        assert not torch.allclose(
+            output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
+        ), "Output should differ with IP Adapter enabled"
+
+    def test_ip_adapter_scale(self):
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Create and load dummy IP adapter state dict
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
+        model._load_ip_adapter_weights([ip_adapter_state_dict])
+
+        # Create dummy image embeds for IP adapter
+        cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
+        image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
+        inputs_dict_with_adapter = inputs_dict.copy()
+        inputs_dict_with_adapter["image_embeds"] = image_embeds
+
+        # Test scale = 0.0 (no effect)
+        model.set_ip_adapter_scale(0.0)
+        torch.manual_seed(0)
+        output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
+
+        # Test scale = 1.0 (full effect)
+        model.set_ip_adapter_scale(1.0)
+        torch.manual_seed(0)
+        output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
+
+        # Outputs should differ with different scales
+        assert not torch.allclose(
+            output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
+        ), "Output should differ with different IP Adapter scales"
+
+    def test_unload_ip_adapter(self):
+        init_dict = self.get_init_dict()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Save original processors
+        original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
+
+        # Create and load IP adapter
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
+        model._load_ip_adapter_weights([ip_adapter_state_dict])
+        assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
+
+        # Unload IP adapter
+        model.unload_ip_adapter()
+        assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
+
+        # Verify processors are restored
+        current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
+        assert original_processors == current_processors, "Processors should be restored after unload"
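Subclasses only need to supply the state-dict hook; the rest of the mixin is generic. A minimal sketch of a concrete tester (the class name is hypothetical and not part of this patch) could simply delegate to the module-level helper defined earlier in this file:

    class MyModelIPAdapterTests(IPAdapterTesterMixin):
        # Hypothetical tester: `model_class`, `get_init_dict()` and `get_dummy_inputs()`
        # are assumed to come from the model-specific tester this mixin is combined with.
        def create_ip_adapter_state_dict(self, model):
            # Reuse the dummy-weight helper defined above in this module.
            return create_ip_adapter_state_dict(model)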
+    def test_ip_adapter_save_load(self):
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        # Create and load IP adapter
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
+        model._load_ip_adapter_weights([ip_adapter_state_dict])
+
+        # Create dummy image embeds for IP adapter
+        cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
+        image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
+        inputs_dict_with_adapter = inputs_dict.copy()
+        inputs_dict_with_adapter["image_embeds"] = image_embeds
+
+        torch.manual_seed(0)
+        output_before_save = model(**inputs_dict_with_adapter, return_dict=False)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Save the IP adapter weights
+            save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
+            import safetensors.torch
+
+            safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
+
+            # Unload and reload
+            model.unload_ip_adapter()
+            assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
+
+            # Reload from saved file
+            loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
+            model._load_ip_adapter_weights([loaded_state_dict])
+            assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
+
+            torch.manual_seed(0)
+            output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0]
+
+            # Outputs should match before and after save/load
+            assert torch.allclose(
+                output_before_save, output_after_load, atol=1e-4, rtol=1e-4
+            ), "Output should match before and after save/load"
diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py
index e69de29bb2d1..dfc3bd2955e5 100644
--- a/tests/models/testing_utils/lora.py
+++ b/tests/models/testing_utils/lora.py
@@ -0,0 +1,220 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import tempfile
+
+import pytest
+import safetensors.torch
+import torch
+
+from diffusers.utils.testing_utils import check_if_dicts_are_equal
+
+from ...testing_utils import is_lora, require_peft_backend, torch_device
+
+
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Check if LoRA layers are correctly set in the model.
+
+    Args:
+        model: The model to check
+
+    Returns:
+        bool: True if LoRA is correctly set, False otherwise
+    """
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
+@is_lora
+@require_peft_backend
+class LoraTesterMixin:
+    """
+    Mixin class for testing LoRA/PEFT functionality on models.
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: lora + Use `pytest -m "not lora"` to skip these tests + """ + + def setup_method(self): + from diffusers.loaders.peft import PeftAdapterMixin + + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") + + def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): + from peft import LoraConfig + from peft.utils import get_peft_model_state_dict + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + torch.manual_seed(0) + output_no_lora = model(**inputs_dict, return_dict=False)[0] + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + model.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + torch.manual_seed(0) + outputs_with_lora = model(**inputs_dict, return_dict=False)[0] + + assert not torch.allclose( + output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4 + ), "Output should differ with LoRA enabled" + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + assert os.path.isfile( + os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + ), "LoRA weights file not created" + + state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") + + for k in state_dict_loaded: + loaded_v = state_dict_loaded[k] + retrieved_v = state_dict_retrieved[k].to(loaded_v.device) + assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}" + + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload" + + torch.manual_seed(0) + outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] + + assert not torch.allclose( + output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4 + ), "Output should differ with LoRA enabled" + assert torch.allclose( + outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4 + ), "Outputs should match before and after save/load" + + def test_lora_wrong_adapter_name_raises_error(self): + from peft import LoraConfig + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + with tempfile.TemporaryDirectory() as tmpdir: + wrong_name = "foo" + with pytest.raises(ValueError) as exc_info: + model.save_lora_adapter(tmpdir, adapter_name=wrong_name) + + assert f"Adapter name {wrong_name} not found in the model." 
in str(exc_info.value) + + def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False): + from peft import LoraConfig + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + model.add_adapter(denoiser_lora_config) + metadata = model.peft_config["default"].to_dict() + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + parsed_metadata = model.peft_config["default_0"].to_dict() + check_if_dicts_are_equal(metadata, parsed_metadata) + + def test_lora_adapter_wrong_metadata_raises_error(self): + from peft import LoraConfig + + from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" + + # Perturb the metadata in the state dict + loaded_state_dict = safetensors.torch.load_file(model_file) + metadata = {"format": "pt"} + lora_adapter_metadata = denoiser_lora_config.to_dict() + lora_adapter_metadata.update({"foo": 1, "bar": 2}) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + with pytest.raises(TypeError) as exc_info: + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + assert "`LoraConfig` class could not be instantiated" in str(exc_info.value) diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py new file mode 100644 index 000000000000..d06a125dc600 --- /dev/null +++ b/tests/models/testing_utils/memory.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
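These testing_utils modules are meant to be mixed into per-model test classes rather than run directly. A rough sketch of that wiring is shown below; the model class, config values, tensor shapes, and import paths are placeholders (they depend on where the concrete test module lives) and are not part of this patch:

    import torch

    from diffusers.utils.testing_utils import torch_device

    from .common import ModelTesterMixin  # path shown relative to tests/models/testing_utils
    from .lora import LoraTesterMixin


    class MyTransformerModelTests(ModelTesterMixin, LoraTesterMixin):
        model_class = MyTransformer2DModel  # hypothetical diffusers model class

        def get_init_dict(self):
            # Tiny, illustrative config so the dummy model builds quickly.
            return {"num_layers": 1, "num_attention_heads": 2, "attention_head_dim": 8}

        def get_dummy_inputs(self):
            # Illustrative shapes; they must match what `model_class.forward` expects.
            return {
                "hidden_states": torch.randn(1, 4, 16, 16).to(torch_device),
                "timestep": torch.tensor([1]).to(torch_device),
            }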
+ +import gc +import glob +import inspect +import tempfile +from functools import wraps + +import pytest +import torch +from accelerate.utils.modeling import compute_module_sizes + +from diffusers.utils.testing_utils import _check_safetensors_serialization +from diffusers.utils.torch_utils import get_torch_cuda_device_capability + +from ...testing_utils import ( + backend_empty_cache, + backend_max_memory_allocated, + backend_reset_peak_memory_stats, + backend_synchronize, + is_cpu_offload, + is_group_offload, + is_memory, + require_accelerator, + torch_device, +) +from .common import check_device_map_is_respected + + +def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype): + """Helper to cast tensor inputs from one dtype to another.""" + for key, value in inputs_dict.items(): + if isinstance(value, torch.Tensor) and value.dtype == from_dtype: + inputs_dict[key] = value.to(to_dtype) + return inputs_dict + + +def require_offload_support(func): + """ + Decorator to skip tests if model doesn't support offloading (requires _no_split_modules). + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.model_class._no_split_modules is None: + pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + return func(self, *args, **kwargs) + + return wrapper + + +def require_group_offload_support(func): + """ + Decorator to skip tests if model doesn't support group offloading. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + return func(self, *args, **kwargs) + + return wrapper + + +@is_cpu_offload +class CPUOffloadTesterMixin: + """ + Mixin class for testing CPU offloading functionality. + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - model_split_percents: List of percentages for splitting model across devices + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cpu_offload + Use `pytest -m "not cpu_offload"` to skip these tests + """ + + model_split_percents = [0.5, 0.7] + + @require_offload_support + def test_cpu_offload(self): + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU" + + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert torch.allclose( + base_output[0], new_output[0], atol=1e-5 + ), "Output should match with CPU offloading" + + @require_offload_support + def test_disk_offload_without_safetensors(self): + config = self.get_init_dict() + 
inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + max_size = int(self.model_split_percents[0] * model_size) + # Force disk offload by setting very small CPU memory + max_memory = {0: max_size, "cpu": int(0.1 * max_size)} + + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir, safe_serialization=False) + # This errors out because it's missing an offload folder + with pytest.raises(ValueError): + new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + new_model = self.model_class.from_pretrained( + tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir + ) + + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading" + + @require_offload_support + def test_disk_offload_with_safetensors(self): + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + max_size = int(self.model_split_percents[0] * model_size) + max_memory = {0: max_size, "cpu": max_size} + new_model = self.model_class.from_pretrained( + tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory + ) + + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert torch.allclose( + base_output[0], new_output[0], atol=1e-5 + ), "Output should match with disk offloading (safetensors)" + + +@is_group_offload +class GroupOffloadTesterMixin: + """ + Mixin class for testing group offloading functionality. 
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: group_offload + Use `pytest -m "not group_offload"` to skip these tests + """ + + @require_group_offload_support + def test_group_offloading(self, record_stream=False): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + torch.manual_seed(0) + + @torch.no_grad() + def run_forward(model): + assert all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") + ), "Group offloading hook should be set" + model.eval() + return model(**inputs_dict)[0] + + model = self.model_class(**init_dict) + + model.to(torch_device) + output_without_group_offloading = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) + output_with_group_offloading1 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) + output_with_group_offloading2 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="leaf_level") + output_with_group_offloading3 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload( + torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream + ) + output_with_group_offloading4 = run_forward(model) + + assert torch.allclose( + output_without_group_offloading, output_with_group_offloading1, atol=1e-5 + ), "Output should match with block-level offloading" + assert torch.allclose( + output_without_group_offloading, output_with_group_offloading2, atol=1e-5 + ), "Output should match with non-blocking block-level offloading" + assert torch.allclose( + output_without_group_offloading, output_with_group_offloading3, atol=1e-5 + ), "Output should match with leaf-level offloading" + assert torch.allclose( + output_without_group_offloading, output_with_group_offloading4, atol=1e-5 + ), "Output should match with leaf-level offloading with stream" + + @require_group_offload_support + @torch.no_grad() + def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"): + torch.manual_seed(0) + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + + model.to(torch_device) + model.eval() + _ = model(**inputs_dict)[0] + + torch.manual_seed(0) + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + storage_dtype, compute_dtype = torch.float16, torch.float32 + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**init_dict) + model.eval() + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1} + model.enable_group_offload( + torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs + ) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + _ = model(**inputs_dict)[0] + + 
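Outside the test harness, the two features exercised above compose the same way on a full-size model. A minimal sketch, assuming a CUDA device and a Flux-style transformer checkpoint (the repository id below is a placeholder):

    import torch

    from diffusers import FluxTransformer2DModel

    # Placeholder checkpoint id; substitute a real repository.
    transformer = FluxTransformer2DModel.from_pretrained(
        "some/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
    )

    # Keep one group of blocks on the accelerator at a time, prefetching via a stream.
    transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,
    )

    # Store weights in fp8 while running compute in bf16.
    transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)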
@require_group_offload_support + @torch.no_grad() + @torch.inference_mode() + def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5): + def _has_generator_arg(model): + sig = inspect.signature(model.forward) + params = sig.parameters + return "generator" in params + + def _run_forward(model, inputs_dict): + accepts_generator = _has_generator_arg(model) + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + torch.manual_seed(0) + return model(**inputs_dict)[0] + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + torch.manual_seed(0) + model = self.model_class(**init_dict) + + model.eval() + model.to(torch_device) + output_without_group_offloading = _run_forward(model, inputs_dict) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.eval() + + num_blocks_per_group = None if offload_type == "leaf_level" else 1 + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group} + with tempfile.TemporaryDirectory() as tmpdir: + model.enable_group_offload( + torch_device, + offload_type=offload_type, + offload_to_disk_path=tmpdir, + use_stream=True, + record_stream=record_stream, + **additional_kwargs, + ) + has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") + assert has_safetensors, "No safetensors found in the directory." + + # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic + # in nature. So, skip it. + if offload_type != "leaf_level": + is_correct, extra_files, missing_files = _check_safetensors_serialization( + module=model, + offload_to_disk_path=tmpdir, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + ) + if not is_correct: + if extra_files: + raise ValueError(f"Found extra files: {', '.join(extra_files)}") + elif missing_files: + raise ValueError(f"Following files are missing: {', '.join(missing_files)}") + + output_with_group_offloading = _run_forward(model, inputs_dict) + assert torch.allclose( + output_without_group_offloading, output_with_group_offloading, atol=atol + ), "Output should match with disk-based group offloading" + + +class LayerwiseCastingTesterMixin: + """ + Mixin class for testing layerwise dtype casting for memory optimization. 
+ + Expected class attributes to be set by subclasses: + - model_class: The model class to test + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + @torch.no_grad() + def test_layerwise_casting_memory(self): + MB_TOLERANCE = 0.2 + LEAST_COMPUTE_CAPABILITY = 8.0 + + def reset_memory_stats(): + gc.collect() + backend_synchronize(torch_device) + backend_empty_cache(torch_device) + backend_reset_peak_memory_stats(torch_device) + + def get_memory_usage(storage_dtype, compute_dtype): + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**config).eval() + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + + reset_memory_stats() + model(**inputs_dict) + model_memory_footprint = model.get_memory_footprint() + peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2 + + return model_memory_footprint, peak_inference_memory_allocated_mb + + fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32) + fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32) + fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage( + torch.float8_e4m3fn, torch.bfloat16 + ) + + compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None + assert ( + fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint + ), "Memory footprint should decrease with lower precision storage" + + # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. + # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it. + if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY: + assert ( + fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory + ), "Peak memory should be lower with bf16 compute on newer GPUs" + + # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few + # bytes. This only happens for some models, so we allow a small tolerance. + # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32. + assert ( + fp8_e4m3_fp32_max_memory < fp32_max_memory + or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE + ), "Peak memory should be lower or within tolerance with fp8 storage" + + +@is_memory +@require_accelerator +class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin): + """ + Combined mixin class for all memory optimization tests including CPU/disk offloading, + group offloading, and layerwise dtype casting. 
+ + This mixin inherits from: + - CPUOffloadTesterMixin: CPU and disk offloading tests + - GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level) + - LayerwiseCastingTesterMixin: Layerwise dtype casting tests + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7]) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: memory + Use `pytest -m "not memory"` to skip these tests + """ + + pass diff --git a/tests/models/testing_utils/offloading.py b/tests/models/testing_utils/offloading.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py new file mode 100644 index 000000000000..6fd2c6152fb2 --- /dev/null +++ b/tests/models/testing_utils/quantization.py @@ -0,0 +1,833 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile + +import pytest +import torch + +from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig +from diffusers.utils.import_utils import ( + is_bitsandbytes_available, + is_gguf_available, + is_nvidia_modelopt_available, + is_optimum_quanto_available, +) + +from ...testing_utils import ( + backend_empty_cache, + is_bitsandbytes, + is_gguf, + is_modelopt, + is_quanto, + is_torchao, + nightly, + require_accelerate, + require_accelerator, + require_bitsandbytes_version_greater, + require_gguf_version_greater_or_equal, + require_quanto, + require_torchao_version_greater_or_equal, + torch_device, +) + + +if is_nvidia_modelopt_available(): + import modelopt.torch.quantization as mtq + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +if is_optimum_quanto_available(): + from optimum.quanto import QLinear + +if is_gguf_available(): + pass + +if is_torchao_available(): + + if is_torchao_version(">=", "0.9.0"): + pass + + +@require_accelerator +class QuantizationTesterMixin: + """ + Base mixin class providing common test implementations for quantization testing. + + Backend-specific mixins should: + 1. Implement _create_quantized_model(config_kwargs) + 2. Implement _verify_if_layer_quantized(name, module, config_kwargs) + 3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.) + 4. 
Use @pytest.mark.parametrize to create tests that call the common test methods below + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods in test classes: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + """ + Create a quantized model with the given config kwargs. + + Args: + config_kwargs: Quantization config parameters + **extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder) + """ + raise NotImplementedError("Subclass must implement _create_quantized_model") + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + raise NotImplementedError("Subclass must implement _verify_if_layer_quantized") + + def _is_module_quantized(self, module): + """ + Check if a module is quantized. Returns True if quantized, False otherwise. + Default implementation tries _verify_if_layer_quantized and catches exceptions. + Subclasses can override for more efficient checking. + """ + try: + self._verify_if_layer_quantized("", module, {}) + return True + except (AssertionError, AttributeError): + return False + + def _load_unquantized_model(self): + kwargs = getattr(self, "pretrained_model_kwargs", {}) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _test_quantization_num_parameters(self, config_kwargs): + model = self._load_unquantized_model() + num_params = model.num_parameters() + + model_quantized = self._create_quantized_model(config_kwargs) + num_params_quantized = model_quantized.num_parameters() + + assert ( + num_params == num_params_quantized + ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" + + def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): + model = self._load_unquantized_model() + mem = model.get_memory_footprint() + + model_quantized = self._create_quantized_model(config_kwargs) + mem_quantized = model_quantized.get_memory_footprint() + + ratio = mem / mem_quantized + assert ( + ratio >= expected_memory_reduction + ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" + + def _test_quantization_inference(self, config_kwargs): + model_quantized = self._create_quantized_model(config_kwargs) + + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model_quantized(**inputs) + + if isinstance(output, tuple): + output = output[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + def _test_quantization_dtype_assignment(self, config_kwargs): + model = self._create_quantized_model(config_kwargs) + + with pytest.raises(ValueError): + model.to(torch.float16) + + with pytest.raises(ValueError): + device_0 = f"{torch_device}:0" + model.to(device=device_0, dtype=torch.float16) + + with pytest.raises(ValueError): + model.float() + + with pytest.raises(ValueError): + model.half() + + model.to(torch_device) + + def _test_quantization_lora_inference(self, config_kwargs): + try: + from peft import LoraConfig + except ImportError: + pytest.skip("peft is not available") + + from diffusers.loaders.peft import PeftAdapterMixin + + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})") + + model = self._create_quantized_model(config_kwargs) + + lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + ) + model.add_adapter(lora_config) + + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model(**inputs) + + if isinstance(output, tuple): + output = output[0] + assert output is not None, "Model output is None with LoRA" + assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" + + def _test_quantization_serialization(self, config_kwargs): + model = self._create_quantized_model(config_kwargs) + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_pretrained(tmpdir, safe_serialization=True) + + model_loaded = self.model_class.from_pretrained(tmpdir) + + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model_loaded(**inputs) + if isinstance(output, tuple): + output = output[0] + assert not torch.isnan(output).any(), "Loaded model output contains NaN" + + def _test_quantized_layers(self, config_kwargs): + model_fp = self._load_unquantized_model() + num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear)) + + model_quantized = self._create_quantized_model(config_kwargs) + + num_fp32_modules = 0 + if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules: + for name, module in model_quantized.named_modules(): + if isinstance(module, torch.nn.Linear): + if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules): + num_fp32_modules += 1 + + expected_quantized_layers = num_linear_layers - num_fp32_modules + + num_quantized_layers = 0 + for name, module in model_quantized.named_modules(): + if isinstance(module, torch.nn.Linear): + if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules: + if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules): + continue + self._verify_if_layer_quantized(name, module, config_kwargs) + num_quantized_layers += 1 + + assert ( + num_quantized_layers > 0 + ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + assert ( + num_quantized_layers == expected_quantized_layers 
+ ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + + def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): + """ + Test that modules specified in modules_to_not_convert are not quantized. + + Args: + config_kwargs: Base quantization config kwargs + modules_to_not_convert: List of module names to exclude from quantization + """ + # Create config with modules_to_not_convert + config_kwargs_with_exclusion = config_kwargs.copy() + config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert + + model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion) + + # Find a module that should NOT be quantized + found_excluded = False + for name, module in model_with_exclusion.named_modules(): + if isinstance(module, torch.nn.Linear): + # Check if this module is in the exclusion list + if any(excluded in name for excluded in modules_to_not_convert): + found_excluded = True + # This module should NOT be quantized + assert not self._is_module_quantized( + module + ), f"Module {name} should not be quantized but was found to be quantized" + + assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" + + # Find a module that SHOULD be quantized (not in exclusion list) + found_quantized = False + for name, module in model_with_exclusion.named_modules(): + if isinstance(module, torch.nn.Linear): + # Check if this module is NOT in the exclusion list + if not any(excluded in name for excluded in modules_to_not_convert): + if self._is_module_quantized(module): + found_quantized = True + break + + assert found_quantized, "No quantized layers found outside of excluded modules" + + # Compare memory footprint with fully quantized model + model_fully_quantized = self._create_quantized_model(config_kwargs) + + mem_with_exclusion = model_with_exclusion.get_memory_footprint() + mem_fully_quantized = model_fully_quantized.get_memory_footprint() + + assert ( + mem_with_exclusion > mem_fully_quantized + ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + + def _test_quantization_device_map(self, config_kwargs): + """ + Test that quantized models work correctly with device_map="auto". + + Args: + config_kwargs: Base quantization config kwargs + """ + model = self._create_quantized_model(config_kwargs, device_map="auto") + + # Verify device map is set + assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute" + assert model.hf_device_map is not None, "hf_device_map should not be None" + + # Verify inference works + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model(**inputs) + if isinstance(output, tuple): + output = output[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + +@is_bitsandbytes +@nightly +@require_accelerator +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +class BitsAndBytesTesterMixin(QuantizationTesterMixin): + """ + Mixin class for testing BitsAndBytes quantization on models. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test + + Pytest mark: bitsandbytes + Use `pytest -m "not bitsandbytes"` to skip these tests + """ + + # Standard BnB configs tested for all models + # Subclasses can override to add or modify configs + BNB_CONFIGS = { + "4bit_nf4": { + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.float16, + }, + "4bit_fp4": { + "load_in_4bit": True, + "bnb_4bit_quant_type": "fp4", + "bnb_4bit_compute_dtype": torch.float16, + }, + "8bit": { + "load_in_8bit": True, + }, + } + + BNB_EXPECTED_MEMORY_REDUCTIONS = { + "4bit_nf4": 3.0, + "4bit_fp4": 3.0, + "8bit": 1.5, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = BitsAndBytesConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params + assert ( + module.weight.__class__ == expected_weight_class + ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + def test_bnb_quantization_num_parameters(self, config_name): + self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + def test_bnb_quantization_memory_footprint(self, config_name): + expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) + self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected) + + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + def test_bnb_quantization_inference(self, config_name): + self._test_quantization_inference(self.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"]) + def test_bnb_quantization_dtype_assignment(self, config_name): + self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"]) + def test_bnb_quantization_lora_inference(self, config_name): + self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"]) + def test_bnb_quantization_serialization(self, config_name): + self._test_quantization_serialization(self.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + def test_bnb_quantized_layers(self, config_name): + self._test_quantized_layers(self.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + def test_bnb_quantization_config_serialization(self, config_name): + model = self._create_quantized_model(self.BNB_CONFIGS[config_name]) + + assert "quantization_config" in model.config, "Missing quantization_config" + _ = 
model.config["quantization_config"].to_dict() + _ = model.config["quantization_config"].to_diff_dict() + _ = model.config["quantization_config"].to_json_string() + + def test_bnb_original_dtype(self): + config_name = list(self.BNB_CONFIGS.keys())[0] + config_kwargs = self.BNB_CONFIGS[config_name] + + model = self._create_quantized_model(config_kwargs) + + assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype" + assert model.config["_pre_quantization_dtype"] in [ + torch.float16, + torch.float32, + torch.bfloat16, + ], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}" + + def test_bnb_keep_modules_in_fp32(self): + if not hasattr(self.model_class, "_keep_in_fp32_modules"): + pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules") + + config_kwargs = self.BNB_CONFIGS["4bit_nf4"] + + original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None) + self.model_class._keep_in_fp32_modules = ["proj_out"] + + try: + model = self._create_quantized_model(config_kwargs) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): + assert ( + module.weight.dtype == torch.float32 + ), f"Module {name} should be FP32 but is {module.weight.dtype}" + else: + assert ( + module.weight.dtype == torch.uint8 + ), f"Module {name} should be uint8 but is {module.weight.dtype}" + + with torch.no_grad(): + inputs = self.get_dummy_inputs() + _ = model(**inputs) + finally: + if original_fp32_modules is not None: + self.model_class._keep_in_fp32_modules = original_fp32_modules + + def test_bnb_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude) + + def test_bnb_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"]) + + +@is_quanto +@nightly +@require_quanto +@require_accelerate +@require_accelerator +class QuantoTesterMixin(QuantizationTesterMixin): + """ + Mixin class for testing Quanto quantization on models. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype + + Pytest mark: quanto + Use `pytest -m "not quanto"` to skip these tests + """ + + QUANTO_WEIGHT_TYPES = { + "float8": {"weights_dtype": "float8"}, + "int8": {"weights_dtype": "int8"}, + "int4": {"weights_dtype": "int4"}, + "int2": {"weights_dtype": "int2"}, + } + + QUANTO_EXPECTED_MEMORY_REDUCTIONS = { + "float8": 1.5, + "int8": 1.5, + "int4": 3.0, + "int2": 7.0, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = QuantoConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}" + + @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys())) + def test_quanto_quantization_num_parameters(self, weight_type_name): + self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys())) + def test_quanto_quantization_memory_footprint(self, weight_type_name): + expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2) + self._test_quantization_memory_footprint( + self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys())) + def test_quanto_quantization_inference(self, weight_type_name): + self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"]) + def test_quanto_quantized_layers(self, weight_type_name): + self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"]) + def test_quanto_quantization_lora_inference(self, weight_type_name): + self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"]) + def test_quanto_quantization_serialization(self, weight_type_name): + self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name]) + + def test_quanto_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude) + + def test_quanto_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"]) + + +@is_torchao +@require_accelerator +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoTesterMixin(QuantizationTesterMixin): + """ + 
Mixin class for testing TorchAO quantization on models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - TORCHAO_QUANT_TYPES: Dict of quantization type strings to test + + Pytest mark: torchao + Use `pytest -m "not torchao"` to skip these tests + """ + + TORCHAO_QUANT_TYPES = { + "int4wo": {"quant_type": "int4_weight_only"}, + "int8wo": {"quant_type": "int8_weight_only"}, + "int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"}, + } + + TORCHAO_EXPECTED_MEMORY_REDUCTIONS = { + "int4wo": 3.0, + "int8wo": 1.5, + "int8dq": 1.5, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = TorchAoConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}" + + @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys())) + def test_torchao_quantization_num_parameters(self, quant_type): + self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys())) + def test_torchao_quantization_memory_footprint(self, quant_type): + expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2) + self._test_quantization_memory_footprint( + self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys())) + def test_torchao_quantization_inference(self, quant_type): + self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"]) + def test_torchao_quantized_layers(self, quant_type): + self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"]) + def test_torchao_quantization_lora_inference(self, quant_type): + self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"]) + def test_torchao_quantization_serialization(self, quant_type): + self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type]) + + def test_torchao_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + # Get a module name that exists in the model - this needs to be set by test classes + # For now, use a generic pattern that should work with transformer models + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert( + self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude + ) + + def test_torchao_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"]) + + 
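A brief usage sketch (hedged): each mixin above only contributes test methods, so a concrete test class is expected to pair it with a small config class supplying the attributes and `get_dummy_inputs()` described in the docstrings. The model class, repo id, and input shape below are hypothetical placeholders, not names from this patch; the Flux transformer tests later in this series follow the same pattern with real values.

class MyTransformerQuantTesterConfig:
    # Hypothetical model class and Hub repo id -- substitute a real tiny checkpoint.
    model_class = MyTransformer2DModel
    pretrained_model_name_or_path = "org/tiny-model-repo"
    pretrained_model_kwargs = {"subfolder": "transformer"}

    def get_dummy_inputs(self):
        # Illustrative shape only; return whatever the model's forward pass expects.
        return {"hidden_states": torch.randn((1, 16, 4))}


class TestMyTransformerTorchAo(MyTransformerQuantTesterConfig, TorchAoTesterMixin):
    pass  # inherits the parametrized TorchAO tests; select them with `pytest -m torchao`
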
+@is_gguf
+@nightly
+@require_accelerate
+@require_accelerator
+@require_gguf_version_greater_or_equal("0.10.0")
+class GGUFTesterMixin(QuantizationTesterMixin):
+    """
+    Mixin class for testing GGUF quantization on models.
+
+    Expected class attributes:
+    - model_class: The model class to test
+    - gguf_filename: URL or path to the GGUF file
+
+    Expected methods to be implemented by subclasses:
+    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+
+    Pytest mark: gguf
+    Use `pytest -m "not gguf"` to skip these tests
+    """
+
+    gguf_filename = None
+
+    def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
+        if config_kwargs is None:
+            config_kwargs = {"compute_dtype": torch.bfloat16}
+
+        config = GGUFQuantizationConfig(**config_kwargs)
+        kwargs = {
+            "quantization_config": config,
+            "torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
+        }
+        kwargs.update(extra_kwargs)
+        return self.model_class.from_single_file(self.gguf_filename, **kwargs)
+
+    def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
+        from diffusers.quantizers.gguf.utils import GGUFParameter
+
+        assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
+        assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
+        assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"
+
+    def test_gguf_quantization_inference(self):
+        self._test_quantization_inference({"compute_dtype": torch.bfloat16})
+
+    def test_gguf_keep_modules_in_fp32(self):
+        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
+            pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
+
+        _keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
+        self.model_class._keep_in_fp32_modules = ["proj_out"]
+
+        try:
+            model = self._create_quantized_model()
+
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
+                        assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
+        finally:
+            self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules
+
+    def test_gguf_quantization_dtype_assignment(self):
+        self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})
+
+    def test_gguf_quantization_lora_inference(self):
+        self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
+
+    def test_gguf_dequantize_model(self):
+        from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
+
+        model = self._create_quantized_model()
+        model.dequantize()
+
+        def _check_for_gguf_linear(model):
+            has_children = list(model.children())
+            if not has_children:
+                return
+
+            for name, module in model.named_children():
+                if isinstance(module, torch.nn.Linear):
+                    assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
+                    assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
+
+            for name, module in model.named_children():
+                _check_for_gguf_linear(module)
+
+        # Recursively verify that no GGUF layers or GGUF parameters remain after dequantization.
+        _check_for_gguf_linear(model)
+
+    def test_gguf_quantized_layers(self):
+        self._test_quantized_layers({"compute_dtype": torch.bfloat16})
+
+
+@is_modelopt
+@nightly
+@require_accelerator
+@require_accelerate
+@require_modelopt_version_greater_or_equal("0.33.1")
+class ModelOptTesterMixin(QuantizationTesterMixin):
+    """
+    Mixin class for testing NVIDIA ModelOpt quantization on models.
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test + + Pytest mark: modelopt + Use `pytest -m "not modelopt"` to skip these tests + """ + + MODELOPT_CONFIGS = { + "fp8": {"quant_type": "FP8"}, + "int8": {"quant_type": "INT8"}, + "int4": {"quant_type": "INT4"}, + } + + MODELOPT_EXPECTED_MEMORY_REDUCTIONS = { + "fp8": 1.5, + "int8": 1.5, + "int4": 3.0, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = NVIDIAModelOptConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)" + + @pytest.mark.parametrize("config_name", ["fp8"]) + def test_modelopt_quantization_num_parameters(self, config_name): + self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys())) + def test_modelopt_quantization_memory_footprint(self, config_name): + expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) + self._test_quantization_memory_footprint( + self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys())) + def test_modelopt_quantization_inference(self, config_name): + self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"]) + def test_modelopt_quantization_dtype_assignment(self, config_name): + self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"]) + def test_modelopt_quantization_lora_inference(self, config_name): + self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"]) + def test_modelopt_quantization_serialization(self, config_name): + self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"]) + def test_modelopt_quantized_layers(self, config_name): + self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name]) + + def test_modelopt_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude) + + def test_modelopt_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"]) diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py 
index 561dc3c56703..cb0ae7a4e7e4 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -23,6 +23,7 @@ from ...testing_utils import ( backend_empty_cache, + is_single_file, nightly, require_torch_accelerator, torch_device, @@ -68,23 +69,23 @@ class SingleFileTesterMixin: - ckpt_path: Path or Hub path to the single file checkpoint - subfolder: (Optional) Subfolder within the repo - torch_dtype: (Optional) torch dtype to use for testing + + Pytest mark: single_file + Use `pytest -m "not single_file"` to skip these tests """ pretrained_model_name_or_path = None ckpt_path = None def setup_method(self): - """Setup before each test method.""" gc.collect() backend_empty_cache(torch_device) def teardown_method(self): - """Cleanup after each test method.""" gc.collect() backend_empty_cache(torch_device) def test_single_file_model_config(self): - """Test that config matches between pretrained and single file loading.""" pretrained_kwargs = {} single_file_kwargs = {} @@ -111,7 +112,6 @@ def test_single_file_model_config(self): ) def test_single_file_model_parameters(self): - """Test that parameters match between pretrained and single file loading.""" pretrained_kwargs = {} single_file_kwargs = {} @@ -152,7 +152,6 @@ def test_single_file_model_parameters(self): ) def test_single_file_loading_local_files_only(self): - """Test single file loading with local_files_only=True.""" single_file_kwargs = {} if hasattr(self, "torch_dtype") and self.torch_dtype: @@ -169,7 +168,6 @@ def test_single_file_loading_local_files_only(self): assert model_single_file is not None, "Failed to load model with local_files_only=True" def test_single_file_loading_with_diffusers_config(self): - """Test single file loading with diffusers config.""" single_file_kwargs = {} if hasattr(self, "torch_dtype") and self.torch_dtype: @@ -199,7 +197,6 @@ def test_single_file_loading_with_diffusers_config(self): ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" def test_single_file_loading_with_diffusers_config_local_files_only(self): - """Test single file loading with diffusers config and local_files_only=True.""" single_file_kwargs = {} if hasattr(self, "torch_dtype") and self.torch_dtype: @@ -217,7 +214,6 @@ def test_single_file_loading_with_diffusers_config_local_files_only(self): assert model_single_file is not None, "Failed to load model with config and local_files_only=True" def test_single_file_loading_dtype(self): - """Test single file loading with different dtypes.""" for dtype in [torch.float32, torch.float16]: if torch_device == "mps" and dtype == torch.bfloat16: continue @@ -232,7 +228,6 @@ def test_single_file_loading_dtype(self): backend_empty_cache(torch_device) def test_checkpoint_variant_loading(self): - """Test loading checkpoints with alternate keys/variants if provided.""" if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths: return diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index e69de29bb2d1..f301b5a6d0c2 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import pytest +import torch + +from diffusers.training_utils import EMAModel + +from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device + + +@is_training +@require_torch_accelerator_with_training +class TrainingTesterMixin: + """ + Mixin class for testing training functionality on models. + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Expected properties to be implemented by subclasses: + - output_shape: Tuple defining the expected output shape + + Pytest mark: training + Use `pytest -m "not training"` to skip these tests + """ + + def test_training(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + + def test_training_with_ema(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + ema_model = EMAModel(model.parameters()) + + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + ema_model.step(model.parameters()) + + def test_gradient_checkpointing(self): + if not self.model_class._supports_gradient_checkpointing: + pytest.skip("Gradient checkpointing is not supported.") + + init_dict = self.get_init_dict() + + # at init model should have gradient checkpointing disabled + model = self.model_class(**init_dict) + assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init" + + # check enable works + model.enable_gradient_checkpointing() + assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled" + + # check disable works + model.disable_gradient_checkpointing() + assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled" + + def test_gradient_checkpointing_is_applied(self, expected_set=None): + if not self.model_class._supports_gradient_checkpointing: + pytest.skip("Gradient checkpointing is not supported.") + + if expected_set is None: + pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.") + + init_dict = self.get_init_dict() + + model_class_copy = copy.copy(self.model_class) + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() + + modules_with_gc_enabled = {} + for submodule in model.modules(): + if hasattr(submodule, "gradient_checkpointing"): 
+ assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled" + modules_with_gc_enabled[submodule.__class__.__name__] = True + + assert set(modules_with_gc_enabled.keys()) == expected_set, ( + f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} " + f"do not match expected set {expected_set}" + ) + assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled" + + def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None): + if not self.model_class._supports_gradient_checkpointing: + pytest.skip("Gradient checkpointing is not supported.") + + if skip is None: + skip = set() + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + inputs_dict_copy = copy.deepcopy(inputs_dict) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict) + if isinstance(out, dict): + out = out.sample if hasattr(out, "sample") else out.to_tuple()[0] + + # run the backwards pass on the model + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + torch.manual_seed(0) + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict_copy) + if isinstance(out_2, dict): + out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0] + + # run the backwards pass on the model + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + assert ( + loss - loss_2 + ).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}" + + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + + for name, param in named_params.items(): + if "post_quant_conv" in name: + continue + if name in skip: + continue + if param.grad is None: + continue + + assert torch_all_close( + param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol + ), f"Gradient mismatch for {name}" + + def test_mixed_precision_training(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + + # Test with float16 + if torch.device(torch_device).type != "cpu": + with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + + loss.backward() + + # Test with bfloat16 + if torch.device(torch_device).type != "cpu": + model.zero_grad() + with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + + loss.backward() diff --git a/tests/models/transformers/test_models_transformer_flux_.py 
b/tests/models/transformers/test_models_transformer_flux_.py index a67218548cbc..e250bf33425b 100644 --- a/tests/models/transformers/test_models_transformer_flux_.py +++ b/tests/models/transformers/test_models_transformer_flux_.py @@ -16,13 +16,26 @@ import torch from diffusers import FluxTransformer2DModel +from diffusers.models.embeddings import ImageProjection from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import LoraHotSwappingForModelTesterMixin -from ..testing_utils.common import ModelTesterMixin -from ..testing_utils.compile import TorchCompileTesterMixin -from ..testing_utils.single_file import SingleFileTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BitsAndBytesTesterMixin, + GGUFTesterMixin, + IPAdapterTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelOptTesterMixin, + ModelTesterMixin, + QuantoTesterMixin, + SingleFileTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() @@ -30,6 +43,8 @@ class FluxTransformerTesterConfig: model_class = FluxTransformer2DModel + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" + pretrained_model_kwargs = {"subfolder": "transformer"} def get_init_dict(self): """Return Flux model initialization arguments.""" @@ -104,19 +119,91 @@ def test_deprecated_inputs_img_txt_ids_3d(self): ) -class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): - ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" - alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] - pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev" - subfolder = "transformer" +class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for Flux Transformer.""" + pass -class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): +class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin): + """Training tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): + """IP Adapter tests for Flux Transformer.""" + + def create_ip_adapter_state_dict(self, model): + from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor + + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=( + model.config["pooled_projection_dim"] if "pooled_projection_dim" in 
model.config.keys() else 768 + ), + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + +class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for Flux Transformer.""" + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def get_dummy_inputs(self, height=4, width=4): - """Override to support dynamic height/width for compilation tests.""" + """Override to support dynamic height/width for LoRA hotswap tests.""" batch_size = 1 num_latent_channels = 4 num_image_channels = 3 @@ -133,16 +220,16 @@ def get_dummy_inputs(self, height=4, width=4): } -class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): +class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def get_dummy_inputs(self, height=4, width=4): - """Override to support dynamic height/width for LoRA hotswap tests.""" + """Override to support dynamic height/width for compilation tests.""" batch_size = 1 num_latent_channels = 4 num_image_channels = 3 - sequence_length = 48 - embedding_dim = 32 + sequence_length = 24 + embedding_dim = 8 return { "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), @@ -152,3 +239,78 @@ def get_dummy_inputs(self, height=4, width=4): "txt_ids": randn_tensor((sequence_length, num_image_channels)), "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), } + + +class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): + ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] + pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev" + subfolder = "transformer" + pass + + +class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): + def get_dummy_inputs(self): + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): + def get_dummy_inputs(self): + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): + def get_dummy_inputs(self): + return { + 
"hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): + gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf" + + def get_dummy_inputs(self): + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin): + def get_dummy_inputs(self): + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 0f4fd408a7c1..c3bb71e79415 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -98,9 +98,9 @@ def test_cuda_kernels_vs_native(self): output_native = linear.forward_native(x) output_cuda = linear.forward_cuda(x) - assert torch.allclose(output_native, output_cuda, 1e-2), ( - f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" - ) + assert torch.allclose( + output_native, output_cuda, 1e-2 + ), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" @nightly diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 6ed7e3467d7f..bd82a19259f9 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -241,7 +241,6 @@ def parse_flag_from_env(key, default=False): _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) _run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False) -_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False) def floats_tensor(shape, scale=1.0, rng=None, name=None): @@ -282,12 +281,128 @@ def nightly(test_case): def is_torch_compile(test_case): """ - Decorator marking a test that runs compile tests in the diffusers CI. + Decorator marking a test as a torch.compile test. These tests can be filtered using: + pytest -m "not compile" to skip + pytest -m compile to run only these tests + """ + return pytest.mark.compile(test_case) + + +def is_single_file(test_case): + """ + Decorator marking a test as a single file loading test. These tests can be filtered using: + pytest -m "not single_file" to skip + pytest -m single_file to run only these tests + """ + return pytest.mark.single_file(test_case) + + +def is_lora(test_case): + """ + Decorator marking a test as a LoRA test. These tests can be filtered using: + pytest -m "not lora" to skip + pytest -m lora to run only these tests + """ + return pytest.mark.lora(test_case) + + +def is_ip_adapter(test_case): + """ + Decorator marking a test as an IP Adapter test. 
These tests can be filtered using: + pytest -m "not ip_adapter" to skip + pytest -m ip_adapter to run only these tests + """ + return pytest.mark.ip_adapter(test_case) + + +def is_training(test_case): + """ + Decorator marking a test as a training test. These tests can be filtered using: + pytest -m "not training" to skip + pytest -m training to run only these tests + """ + return pytest.mark.training(test_case) + + +def is_attention(test_case): + """ + Decorator marking a test as an attention test. These tests can be filtered using: + pytest -m "not attention" to skip + pytest -m attention to run only these tests + """ + return pytest.mark.attention(test_case) + + +def is_memory(test_case): + """ + Decorator marking a test as a memory optimization test. These tests can be filtered using: + pytest -m "not memory" to skip + pytest -m memory to run only these tests + """ + return pytest.mark.memory(test_case) - Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them. +def is_cpu_offload(test_case): + """ + Decorator marking a test as a CPU offload test. These tests can be filtered using: + pytest -m "not cpu_offload" to skip + pytest -m cpu_offload to run only these tests + """ + return pytest.mark.cpu_offload(test_case) + + +def is_group_offload(test_case): + """ + Decorator marking a test as a group offload test. These tests can be filtered using: + pytest -m "not group_offload" to skip + pytest -m group_offload to run only these tests + """ + return pytest.mark.group_offload(test_case) + + +def is_bitsandbytes(test_case): + """ + Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using: + pytest -m "not bitsandbytes" to skip + pytest -m bitsandbytes to run only these tests + """ + return pytest.mark.bitsandbytes(test_case) + + +def is_quanto(test_case): + """ + Decorator marking a test as a Quanto quantization test. These tests can be filtered using: + pytest -m "not quanto" to skip + pytest -m quanto to run only these tests + """ + return pytest.mark.quanto(test_case) + + +def is_torchao(test_case): + """ + Decorator marking a test as a TorchAO quantization test. These tests can be filtered using: + pytest -m "not torchao" to skip + pytest -m torchao to run only these tests + """ + return pytest.mark.torchao(test_case) + + +def is_gguf(test_case): + """ + Decorator marking a test as a GGUF quantization test. These tests can be filtered using: + pytest -m "not gguf" to skip + pytest -m gguf to run only these tests + """ + return pytest.mark.gguf(test_case) + + +def is_modelopt(test_case): + """ + Decorator marking a test as a NVIDIA ModelOpt quantization test. 
These tests can be filtered using: + pytest -m "not modelopt" to skip + pytest -m modelopt to run only these tests """ - return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case) + return pytest.mark.modelopt(test_case) def require_torch(test_case): From aa29af8f0e02fc99fc0dd0c3ad36e497e5e278eb Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 19 Nov 2025 08:51:38 +0530 Subject: [PATCH 03/12] update --- tests/models/testing_utils/attention.py | 3 +- tests/models/testing_utils/common.py | 11 ++-- tests/models/testing_utils/quantization.py | 7 ++- .../test_models_transformer_flux_.py | 50 +++++++++++-------- tests/testing_utils.py | 14 ++++++ 5 files changed, 53 insertions(+), 32 deletions(-) diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index be88fd309b1f..22512c945882 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -21,11 +21,10 @@ AttnProcessor, ) -from ...testing_utils import is_attention, require_accelerator, torch_device +from ...testing_utils import is_attention, torch_device @is_attention -@require_accelerator class AttentionTesterMixin: """ Mixin class for testing attention processor and module functionality on models. diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index e4697f6200f6..aa3954fbbd41 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -16,10 +16,10 @@ import json import os import tempfile -from typing import Dict, List, Tuple import pytest import torch +import torch.nn as nn from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant @@ -30,8 +30,8 @@ def compute_module_persistent_sizes( model: nn.Module, - dtype: Optional[Union[str, torch.device]] = None, - special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + dtype: str | torch.device | None = None, + special_dtypes: dict[str, str | torch.device] | None = None, ): """ Compute the size of each submodule of a given model (parameters + persistent buffers). 
@@ -128,6 +128,7 @@ def get_dummy_inputs(self): ) def test_from_save_pretrained(self, expected_max_diff=5e-5): + torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -273,10 +274,10 @@ def set_nan_tensor_to_zero(t): return t.to(device) def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): + if isinstance(tuple_object, (list, tuple)): for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): + elif isinstance(tuple_object, dict): for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): recursive_check(tuple_iterable_value, dict_iterable_value) elif tuple_object is None: diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 6fd2c6152fb2..0a7d1018035f 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -25,6 +25,7 @@ is_gguf_available, is_nvidia_modelopt_available, is_optimum_quanto_available, + is_torchao_available, ) from ...testing_utils import ( @@ -41,6 +42,7 @@ require_gguf_version_greater_or_equal, require_quanto, require_torchao_version_greater_or_equal, + require_modelopt_version_greater_or_equal, torch_device, ) @@ -58,7 +60,6 @@ pass if is_torchao_available(): - if is_torchao_version(">=", "0.9.0"): pass @@ -644,9 +645,7 @@ def test_torchao_modules_to_not_convert(self): if modules_to_exclude is None: pytest.skip("modules_to_not_convert_for_test not defined for this model") - self._test_quantization_modules_to_not_convert( - self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude - ) + self._test_quantization_modules_to_not_convert(self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude) def test_torchao_device_map(self): """Test that device_map='auto' works correctly with quantization.""" diff --git a/tests/models/transformers/test_models_transformer_flux_.py b/tests/models/transformers/test_models_transformer_flux_.py index e250bf33425b..6526c78b021f 100644 --- a/tests/models/transformers/test_models_transformer_flux_.py +++ b/tests/models/transformers/test_models_transformer_flux_.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from diffusers import FluxTransformer2DModel @@ -46,7 +48,11 @@ class FluxTransformerTesterConfig: pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" pretrained_model_kwargs = {"subfolder": "transformer"} - def get_init_dict(self): + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: """Return Flux model initialization arguments.""" return { "patch_size": 1, @@ -60,30 +66,32 @@ def get_init_dict(self): "axes_dims_rope": [4, 4, 8], } - def get_dummy_inputs(self): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: batch_size = 1 height = width = 4 num_latent_channels = 4 num_image_channels = 3 - sequence_length = 24 - embedding_dim = 8 + sequence_length = 48 + embedding_dim = 32 return { - "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), - "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), - "pooled_projections": randn_tensor((batch_size, embedding_dim)), - "img_ids": randn_tensor((height * width, num_image_channels)), - "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator + ), + "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator), + "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator), + "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator), "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), } @property - def input_shape(self): - return (16, 4) + def input_shape(self) -> tuple[int, int]: + return (1, 16, 4) @property - def output_shape(self): - return (16, 4) + def output_shape(self) -> tuple[int, int]: + return (1, 16, 4) class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): @@ -140,7 +148,7 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): """IP Adapter tests for Flux Transformer.""" - def create_ip_adapter_state_dict(self, model): + def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor ip_cross_attn_state_dict = {} @@ -202,7 +210,7 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - def get_dummy_inputs(self, height=4, width=4): + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: """Override to support dynamic height/width for LoRA hotswap tests.""" batch_size = 1 num_latent_channels = 4 @@ -223,7 +231,7 @@ def get_dummy_inputs(self, height=4, width=4): class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - def get_dummy_inputs(self, height=4, width=4): + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: """Override to support dynamic height/width for compilation tests.""" batch_size = 1 num_latent_channels = 4 @@ -250,7 +258,7 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): 
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): - def get_dummy_inputs(self): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: return { "hidden_states": randn_tensor((1, 4096, 64)), "encoder_hidden_states": randn_tensor((1, 512, 4096)), @@ -263,7 +271,7 @@ def get_dummy_inputs(self): class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): - def get_dummy_inputs(self): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: return { "hidden_states": randn_tensor((1, 4096, 64)), "encoder_hidden_states": randn_tensor((1, 512, 4096)), @@ -276,7 +284,7 @@ def get_dummy_inputs(self): class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): - def get_dummy_inputs(self): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: return { "hidden_states": randn_tensor((1, 4096, 64)), "encoder_hidden_states": randn_tensor((1, 512, 4096)), @@ -291,7 +299,7 @@ def get_dummy_inputs(self): class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf" - def get_dummy_inputs(self): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: return { "hidden_states": randn_tensor((1, 4096, 64)), "encoder_hidden_states": randn_tensor((1, 512, 4096)), @@ -304,7 +312,7 @@ def get_dummy_inputs(self): class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin): - def get_dummy_inputs(self): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: return { "hidden_states": randn_tensor((1, 4096, 64)), "encoder_hidden_states": randn_tensor((1, 512, 4096)), diff --git a/tests/testing_utils.py b/tests/testing_utils.py index bd82a19259f9..bd1acd41c0de 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -37,6 +37,7 @@ is_flax_available, is_gguf_available, is_kernels_available, + is_nvidia_modelopt_available, is_note_seq_available, is_onnx_available, is_opencv_available, @@ -765,6 +766,19 @@ def decorator(test_case): return decorator +def require_modelopt_version_greater_or_equal(modelopt_version): + def decorator(test_case): + correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse( + version.parse(importlib.metadata.version("modelopt")).base_version + ) >= version.parse(modelopt_version) + return pytest.mark.skipif( + not correct_nvidia_modelopt_version, + f"Test requires modelopt with version greater than {modelopt_version}.", + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend From 0f1a4e0c141889e4f4ea3e9b674b9634bb7c20a6 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 19 Nov 2025 21:59:20 +0530 Subject: [PATCH 04/12] update --- tests/models/testing_utils/common.py | 14 +++++--------- tests/models/testing_utils/ip_adapter.py | 10 +++++++--- tests/models/testing_utils/quantization.py | 2 +- .../transformers/test_models_transformer_flux_.py | 10 ++++++++-- tests/testing_utils.py | 2 +- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index aa3954fbbd41..7ec8dbbd8b0b 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -260,7 +260,7 @@ def test_output(self, expected_output_shape=None): assert output is not None, "Model output is None" assert ( - output.shape == expected_output_shape + output[0].shape == 
expected_output_shape or self.output_shape ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}" def test_outputs_equivalence(self): @@ -302,15 +302,11 @@ def recursive_check(tuple_object, dict_object): recursive_check(outputs_tuple, outputs_dict) - def test_model_config_to_json_string(self): - model = self.model_class(**self.get_init_dict()) - - json_string = model.config.to_json_string() - assert isinstance(json_string, str), "Config to_json_string should return a string" - assert len(json_string) > 0, "JSON string should not be empty" - @require_accelerator - @pytest.mark.skipif(torch_device not in ["cuda", "xpu"]) + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be use for inference with an accelerator", + ) def test_from_save_pretrained_float16_bfloat16(self): model = self.model_class(**self.get_init_dict()) model.to(torch_device) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py index be079df0614d..aff2cf18643b 100644 --- a/tests/models/testing_utils/ip_adapter.py +++ b/tests/models/testing_utils/ip_adapter.py @@ -100,6 +100,7 @@ def test_load_ip_adapter(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) + self.prepare_model(model) torch.manual_seed(0) output_no_adapter = model(**inputs_dict, return_dict=False)[0] @@ -128,9 +129,10 @@ def test_ip_adapter_scale(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) + # self.prepare_model(model) # Create and load dummy IP adapter state dict - ip_adapter_state_dict = create_ip_adapter_state_dict(model) + ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) # Test scale = 0.0 (no effect) @@ -151,12 +153,13 @@ def test_ip_adapter_scale(self): def test_unload_ip_adapter(self): init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) + self.prepare_model(model) # Save original processors original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} # Create and load IP adapter - ip_adapter_state_dict = create_ip_adapter_state_dict(model) + ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set" @@ -172,9 +175,10 @@ def test_ip_adapter_save_load(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) + self.prepare_model(model) # Create and load IP adapter - ip_adapter_state_dict = self.create_ip_adapter_state_dict() + ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) torch.manual_seed(0) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 0a7d1018035f..15d6a3206946 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -40,9 +40,9 @@ require_accelerator, require_bitsandbytes_version_greater, require_gguf_version_greater_or_equal, + require_modelopt_version_greater_or_equal, require_quanto, require_torchao_version_greater_or_equal, - require_modelopt_version_greater_or_equal, torch_device, ) diff --git 
a/tests/models/transformers/test_models_transformer_flux_.py b/tests/models/transformers/test_models_transformer_flux_.py index 6526c78b021f..e79974459fd4 100644 --- a/tests/models/transformers/test_models_transformer_flux_.py +++ b/tests/models/transformers/test_models_transformer_flux_.py @@ -19,6 +19,7 @@ from diffusers import FluxTransformer2DModel from diffusers.models.embeddings import ImageProjection +from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device @@ -87,11 +88,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: @property def input_shape(self) -> tuple[int, int]: - return (1, 16, 4) + return (16, 4) @property def output_shape(self) -> tuple[int, int]: - return (1, 16, 4) + return (16, 4) class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): @@ -148,6 +149,11 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): """IP Adapter tests for Flux Transformer.""" + def prepare_model(self, model): + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + model.set_attn_processor(FluxIPAdapterAttnProcessor(hidden_size, joint_attention_dim, scale=1.0)) + def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor diff --git a/tests/testing_utils.py b/tests/testing_utils.py index bd1acd41c0de..ae69a21cf8a8 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -37,8 +37,8 @@ is_flax_available, is_gguf_available, is_kernels_available, - is_nvidia_modelopt_available, is_note_seq_available, + is_nvidia_modelopt_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, From fe451c367b7191e790e3cef6571d0cfdfd53d68e Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 11 Dec 2025 11:04:47 +0530 Subject: [PATCH 05/12] update --- tests/conftest.py | 1 + tests/models/test_modeling_common.py | 24 +- tests/models/testing_utils/__init__.py | 3 +- tests/models/testing_utils/attention.py | 98 ++++- tests/models/testing_utils/common.py | 205 ++++++--- tests/models/testing_utils/hub.py | 20 +- tests/models/testing_utils/ip_adapter.py | 95 ++-- tests/models/testing_utils/lora.py | 24 +- tests/models/testing_utils/memory.py | 84 ++-- tests/models/testing_utils/quantization.py | 55 +-- tests/models/testing_utils/single_file.py | 6 +- tests/models/testing_utils/training.py | 15 +- .../test_models_transformer_flux.py | 416 +++++++++++------- .../test_models_transformer_flux_.py | 330 -------------- tests/quantization/gguf/test_gguf.py | 6 +- tests/testing_utils.py | 9 + 16 files changed, 682 insertions(+), 709 deletions(-) delete mode 100644 tests/models/transformers/test_models_transformer_flux_.py diff --git a/tests/conftest.py b/tests/conftest.py index 3744de27f3b2..0f7b9ef984ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ def pytest_configure(config): config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality") config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality") config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality") + 
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality") def pytest_addoption(parser): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 520bd8f871a4..6f4c3d544b45 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -317,9 +317,9 @@ def test_local_files_only_with_sharded_checkpoint(self): repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True ) - assert all( - torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters()) - ), "Model parameters don't match!" + assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), ( + "Model parameters don't match!" + ) # Remove a shard file cached_shard_file = try_to_load_from_cache( @@ -335,9 +335,9 @@ def test_local_files_only_with_sharded_checkpoint(self): # Verify error mentions the missing shard error_msg = str(context.exception) - assert ( - cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg - ), f"Expected error about missing shard, got: {error_msg}" + assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, ( + f"Expected error about missing shard, got: {error_msg}" + ) @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners") @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.") @@ -354,9 +354,9 @@ def test_one_request_upon_cached(self): ) download_requests = [r.method for r in m.request_history] - assert ( - download_requests.count("HEAD") == 3 - ), "3 HEAD requests one for config, one for model, and one for shard index file." + assert download_requests.count("HEAD") == 3, ( + "3 HEAD requests one for config, one for model, and one for shard index file." + ) assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model" with requests_mock.mock(real_http=True) as m: @@ -368,9 +368,9 @@ def test_one_request_upon_cached(self): ) cache_requests = [r.method for r in m.request_history] - assert ( - "HEAD" == cache_requests[0] and len(cache_requests) == 2 - ), "We should call only `model_info` to check for commit hash and knowing if shard index is present." + assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, ( + "We should call only `model_info` to check for commit hash and knowing if shard index is present." 
+ ) def test_weight_overwrite(self): with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index 7955982ca91c..e72a3c928b64 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,4 +1,4 @@ -from .attention import AttentionTesterMixin +from .attention import AttentionTesterMixin, ContextParallelTesterMixin from .common import ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin @@ -17,6 +17,7 @@ __all__ = [ + "ContextParallelTesterMixin", "AttentionTesterMixin", "BitsAndBytesTesterMixin", "CPUOffloadTesterMixin", diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index 22512c945882..f794a7a0aa4a 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -13,15 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch +import torch.multiprocessing as mp +from diffusers.models._modeling_parallel import ContextParallelConfig from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import ( AttnProcessor, ) -from ...testing_utils import is_attention, torch_device +from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device @is_attention @@ -85,9 +89,9 @@ def test_fuse_unfuse_qkv_projections(self): output_after_fusion = output_after_fusion.to_tuple()[0] # Verify outputs match - assert torch.allclose( - output_before_fusion, output_after_fusion, atol=self.base_precision - ), "Output should not change after fusing projections" + assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), ( + "Output should not change after fusing projections" + ) # Unfuse projections model.unfuse_qkv_projections() @@ -106,9 +110,9 @@ def test_fuse_unfuse_qkv_projections(self): output_after_unfusion = output_after_unfusion.to_tuple()[0] # Verify outputs still match - assert torch.allclose( - output_before_fusion, output_after_unfusion, atol=self.base_precision - ), "Output should match original after unfusing projections" + assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), ( + "Output should match original after unfusing projections" + ) def test_get_set_processor(self): init_dict = self.get_init_dict() @@ -177,3 +181,83 @@ def test_attention_processor_count_mismatch_raises_error(self): model.set_attn_processor(wrong_processors) assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" + + +def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue): + try: + # Setup distributed environment + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + torch.distributed.init_process_group( + backend="nccl", + init_method="env://", + world_size=world_size, + rank=rank, + ) + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + + model = model_class(**init_dict) + model.to(device) + model.eval() + + inputs_on_device = {} + for key, value in inputs_dict.items(): + if isinstance(value, torch.Tensor): + inputs_on_device[key] = value.to(device) + else: + inputs_on_device[key] = value + + cp_config = 
ContextParallelConfig(**cp_dict) + model.enable_parallelism(config=cp_config) + + with torch.no_grad(): + output = model(**inputs_on_device) + if isinstance(output, dict): + output = output.to_tuple()[0] + + if rank == 0: + result_queue.put(("success", output.shape)) + + except Exception as e: + if rank == 0: + result_queue.put(("error", str(e))) + finally: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +@is_context_parallel +@require_torch_multi_accelerator +class ContextParallelTesterMixin: + base_precision = 1e-3 + + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) + def test_context_parallel_inference(self, cp_type): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Context parallel requires at least 2 CUDA devices.") + + if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + cp_dict = {cp_type: world_size} + + ctx = mp.get_context("spawn") + result_queue = ctx.Queue() + + mp.spawn( + _context_parallel_worker, + args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue), + nprocs=world_size, + join=True, + ) + + status, result = result_queue.get(timeout=60) + assert status == "success", f"Context parallel inference failed: {result}" diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 7ec8dbbd8b0b..9f4ae271f97f 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -16,16 +16,45 @@ import json import os import tempfile +from collections import defaultdict import pytest import torch import torch.nn as nn from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size -from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant +from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator -from ...testing_utils import torch_device +from ...testing_utils import CaptureLogger, torch_device + + +def named_persistent_module_tensors( + module: nn.Module, + recurse: bool = False, +): + """ + A helper function that gathers all the tensors (parameters + persistent buffers) of a given module. + + Args: + module (`torch.nn.Module`): + The module we want the tensors on. + recurse (`bool`, *optional`, defaults to `False`): + Whether or not to go look in every submodule or just return the direct parameters and buffers. + """ + yield from module.named_parameters(recurse=recurse) + + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + # Get parent by splitting on dots and traversing the model + parent = module + if "." 
in name: + parent_name = name.rsplit(".", 1)[0] + for part in parent_name.split("."): + parent = getattr(parent, part) + name = name.split(".")[-1] + if name not in parent._non_persistent_buffers_set: + yield named_buffer def compute_module_persistent_sizes( @@ -96,9 +125,9 @@ def check_device_map_is_respected(model, device_map): if param_device in ["cpu", "disk"]: assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}" else: - assert param.device == torch.device( - param_device - ), f"Expected device {param_device} for {param_name}, got {param.device}" + assert param.device == torch.device(param_device), ( + f"Expected device {param_device} for {param_name}, got {param.device}" + ) class ModelTesterMixin: @@ -123,9 +152,7 @@ def get_init_dict(self): raise NotImplementedError("get_init_dict must be implemented by subclasses. ") def get_dummy_inputs(self): - raise NotImplementedError( - "get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict." - ) + raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.") def test_from_save_pretrained(self, expected_max_diff=5e-5): torch.manual_seed(0) @@ -142,9 +169,9 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): for param_name in model.state_dict().keys(): param_1 = model.state_dict()[param_name] param_2 = new_model.state_dict()[param_name] - assert ( - param_1.shape == param_2.shape - ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" + assert param_1.shape == param_2.shape, ( + f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" + ) with torch.no_grad(): image = model(**self.get_dummy_inputs()) @@ -158,9 +185,9 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().max().item() - assert ( - max_diff <= expected_max_diff - ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + assert max_diff <= expected_max_diff, ( + f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + ) def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): model = self.model_class(**self.get_init_dict()) @@ -191,9 +218,9 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): new_image = new_image.to_tuple()[0] max_diff = (image - new_image).abs().max().item() - assert ( - max_diff <= expected_max_diff - ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + assert max_diff <= expected_max_diff, ( + f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" + ) def test_from_save_pretrained_dtype(self): model = self.model_class(**self.get_init_dict()) @@ -242,9 +269,9 @@ def test_determinism(self, expected_max_diff=1e-5): second_filtered = second_flat[mask] max_diff = torch.abs(first_filtered - second_filtered).max().item() - assert ( - max_diff <= expected_max_diff - ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}" + assert max_diff <= expected_max_diff, ( + f"Model outputs are not deterministic. 
Max diff: {max_diff}, expected: {expected_max_diff}" + ) def test_output(self, expected_output_shape=None): model = self.model_class(**self.get_init_dict()) @@ -259,9 +286,9 @@ def test_output(self, expected_output_shape=None): output = output.to_tuple()[0] assert output is not None, "Model output is None" - assert ( - output[0].shape == expected_output_shape or self.output_shape - ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}" + assert output[0].shape == expected_output_shape or self.output_shape, ( + f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}" + ) def test_outputs_equivalence(self): def set_nan_tensor_to_zero(t): @@ -302,6 +329,71 @@ def recursive_check(tuple_object, dict_object): recursive_check(outputs_tuple, outputs_dict) + def test_getattr_is_correct(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + # save some things to test + model.dummy_attribute = 5 + model.register_to_config(test_attribute=5) + + logger = logging.get_logger("diffusers.models.modeling_utils") + # 30 for warning + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + assert hasattr(model, "dummy_attribute") + assert getattr(model, "dummy_attribute") == 5 + assert model.dummy_attribute == 5 + + # no warning should be thrown + assert cap_logger.out == "" + + logger = logging.get_logger("diffusers.models.modeling_utils") + # 30 for warning + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + assert hasattr(model, "save_pretrained") + fn = model.save_pretrained + fn_1 = getattr(model, "save_pretrained") + + assert fn == fn_1 + # no warning should be thrown + assert cap_logger.out == "" + + # warning should be thrown for config attributes accessed directly + with pytest.warns(FutureWarning): + assert model.test_attribute == 5 + + with pytest.warns(FutureWarning): + assert getattr(model, "test_attribute") == 5 + + with pytest.raises(AttributeError) as error: + model.does_not_exist + + assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" + + @require_accelerator + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be used with an accelerator", + ) + def test_keep_in_fp32_modules(self): + model = self.model_class(**self.get_init_dict()) + fp32_modules = model._keep_in_fp32_modules + + if fp32_modules is None or len(fp32_modules) == 0: + pytest.skip("Model does not have _keep_in_fp32_modules defined.") + + # Test with float16 + model.to(torch_device) + model.to(torch.float16) + + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}" + else: + assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}" + @require_accelerator @pytest.mark.skipif( torch_device not in ["cuda", "xpu"], @@ -324,12 +416,12 @@ def test_from_save_pretrained_float16_bfloat16(self): assert param.data.dtype == torch_dtype with torch.no_grad(): - output = model(**get_dummy_inputs()) - output_loaded = model_loaded(**get_dummy_inputs()) + output = model(**self.get_dummy_inputs()) + output_loaded = model_loaded(**self.get_dummy_inputs()) - assert torch.allclose( - output, output_loaded, atol=1e-4 - ), f"Loaded model output differs for {torch_dtype}" + assert torch.allclose(output, 
output_loaded, atol=1e-4), ( + f"Loaded model output differs for {torch_dtype}" + ) @require_accelerator def test_sharded_checkpoints(self): @@ -350,9 +442,9 @@ def test_sharded_checkpoints(self): # Check if the right number of shards exists expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert ( - actual_num_shards == expected_num_shards - ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) new_model = self.model_class.from_pretrained(tmp_dir).eval() new_model = new_model.to(torch_device) @@ -361,9 +453,9 @@ def test_sharded_checkpoints(self): inputs_dict_new = self.get_dummy_inputs() new_output = new_model(**inputs_dict_new) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match after sharded save/load" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match after sharded save/load" + ) @require_accelerator def test_sharded_checkpoints_with_variant(self): @@ -382,16 +474,16 @@ def test_sharded_checkpoints_with_variant(self): model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - assert os.path.exists( - os.path.join(tmp_dir, index_filename) - ), f"Variant index file {index_filename} should exist" + assert os.path.exists(os.path.join(tmp_dir, index_filename)), ( + f"Variant index file {index_filename} should exist" + ) # Check if the right number of shards exists expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert ( - actual_num_shards == expected_num_shards - ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval() new_model = new_model.to(torch_device) @@ -400,11 +492,10 @@ def test_sharded_checkpoints_with_variant(self): inputs_dict_new = self.get_dummy_inputs() new_output = new_model(**inputs_dict_new) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match after variant sharded save/load" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match after variant sharded save/load" + ) - @require_accelerator def test_sharded_checkpoints_with_parallel_loading(self): import time @@ -433,9 +524,9 @@ def test_sharded_checkpoints_with_parallel_loading(self): # Check if the right number of shards exists expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert ( - actual_num_shards == expected_num_shards - ), f"Expected {expected_num_shards} shards, got {actual_num_shards}" + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) # Load without parallel loading constants.HF_ENABLE_PARALLEL_LOADING = False @@ -459,16 +550,14 @@ def test_sharded_checkpoints_with_parallel_loading(self): 
inputs_dict_parallel = self.get_dummy_inputs() output_parallel = model_parallel(**inputs_dict_parallel) - assert torch.allclose( - base_output[0], output_parallel[0], atol=1e-5 - ), "Output should match with parallel loading" + assert torch.allclose(base_output[0], output_parallel[0], atol=1e-5), ( + "Output should match with parallel loading" + ) # Verify parallel loading is faster or at least not significantly slower - # For small test models, the difference might be negligible or even slightly slower due to overhead - # so we just check that parallel loading completed successfully and outputs match - assert ( - parallel_load_time < sequential_load_time - ), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s" + assert parallel_load_time < sequential_load_time, ( + f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s" + ) finally: # Restore original values constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading @@ -506,6 +595,6 @@ def test_model_parallelism(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match with model parallelism" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match with model parallelism" + ) diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py index e20c3ab1630e..40d8777c33b1 100644 --- a/tests/models/testing_utils/hub.py +++ b/tests/models/testing_utils/hub.py @@ -18,7 +18,7 @@ import pytest import torch -from huggingface_hub.utils import is_jinja_available +from huggingface_hub.utils import ModelCard, delete_repo, is_jinja_available from ...others.test_utils import TOKEN, USER, is_staging_test @@ -58,9 +58,9 @@ def test_push_to_hub(self): new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal( - p1, p2 - ), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained" + assert torch.equal(p1, p2), ( + "Parameters don't match after save_pretrained with push_to_hub and from_pretrained" + ) # Reset repo delete_repo(self.repo_id, token=TOKEN) @@ -84,9 +84,9 @@ def test_push_to_hub_in_organization(self): new_model = self.model_class.from_pretrained(self.org_repo_id) for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal( - p1, p2 - ), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained" + assert torch.equal(p1, p2), ( + "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained" + ) # Reset repo delete_repo(self.org_repo_id, token=TOKEN) @@ -101,9 +101,9 @@ def test_push_to_hub_library_name(self): model.push_to_hub(self.repo_id, token=TOKEN) model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data - assert ( - model_card.library_name == "diffusers" - ), f"Expected library_name 'diffusers', got {model_card.library_name}" + assert model_card.library_name == "diffusers", ( + f"Expected library_name 'diffusers', got {model_card.library_name}" + ) # Reset repo delete_repo(self.repo_id, token=TOKEN) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py index aff2cf18643b..13e141869c3a 100644 --- a/tests/models/testing_utils/ip_adapter.py +++ b/tests/models/testing_utils/ip_adapter.py @@ -13,9 +13,8 @@ # See the License for the specific 
language governing permissions and # limitations under the License. -import os -import tempfile +import pytest import torch from diffusers.models.attention_processor import IPAdapterAttnProcessor @@ -61,7 +60,7 @@ def create_ip_adapter_state_dict(model): return {"ip_adapter": ip_state_dict} -def check_if_ip_adapter_correctly_set(model) -> bool: +def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool: """ Check if IP Adapter processors are correctly set in the model. @@ -72,7 +71,7 @@ def check_if_ip_adapter_correctly_set(model) -> bool: bool: True if IP Adapter is correctly set, False otherwise """ for module in model.attn_processors.values(): - if isinstance(module, IPAdapterAttnProcessor): + if isinstance(module, processor_cls): return True return False @@ -93,48 +92,49 @@ class IPAdapterTesterMixin: Use `pytest -m "not ip_adapter"` to skip these tests """ + ip_adapter_processor_cls = None + def create_ip_adapter_state_dict(self, model): raise NotImplementedError("child class must implement method to create IPAdapter State Dict") + def modify_inputs_for_ip_adapter(self, model, inputs_dict): + raise NotImplementedError("child class must implement method to create IPAdapter model inputs") + def test_load_ip_adapter(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - self.prepare_model(model) torch.manual_seed(0) output_no_adapter = model(**inputs_dict, return_dict=False)[0] - # Create dummy IP adapter state dict ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) - # Load IP adapter model._load_ip_adapter_weights([ip_adapter_state_dict]) - assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly" - - torch.manual_seed(0) - # Create dummy image embeds for IP adapter - cross_attention_dim = getattr(model.config, "cross_attention_dim", 32) - image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device) - inputs_dict_with_adapter = inputs_dict.copy() - inputs_dict_with_adapter["image_embeds"] = image_embeds + assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), ( + "IP Adapter processors not set correctly" + ) + inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy()) outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0] - assert not torch.allclose( - output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4 - ), "Output should differ with IP Adapter enabled" + assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), ( + "Output should differ with IP Adapter enabled" + ) + @pytest.mark.skip( + reason="Setting IP Adapter scale is not defined at the model level. 
Enable this test after refactoring" + ) def test_ip_adapter_scale(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # self.prepare_model(model) - # Create and load dummy IP adapter state dict ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) + inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy()) + # Test scale = 0.0 (no effect) model.set_ip_adapter_scale(0.0) torch.manual_seed(0) @@ -146,14 +146,16 @@ def test_ip_adapter_scale(self): output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0] # Outputs should differ with different scales - assert not torch.allclose( - output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4 - ), "Output should differ with different IP Adapter scales" + assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), ( + "Output should differ with different IP Adapter scales" + ) + @pytest.mark.skip( + reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring" + ) def test_unload_ip_adapter(self): init_dict = self.get_init_dict() model = self.model_class(**init_dict).to(torch_device) - self.prepare_model(model) # Save original processors original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} @@ -161,49 +163,16 @@ def test_unload_ip_adapter(self): # Create and load IP adapter ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) model._load_ip_adapter_weights([ip_adapter_state_dict]) - assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set" + + assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set" # Unload IP adapter model.unload_ip_adapter() - assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded" + + assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), ( + "IP Adapter should be unloaded" + ) # Verify processors are restored current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} assert original_processors == current_processors, "Processors should be restored after unload" - - def test_ip_adapter_save_load(self): - init_dict = self.get_init_dict() - inputs_dict = self.get_dummy_inputs() - model = self.model_class(**init_dict).to(torch_device) - self.prepare_model(model) - - # Create and load IP adapter - ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) - model._load_ip_adapter_weights([ip_adapter_state_dict]) - - torch.manual_seed(0) - output_before_save = model(**inputs_dict, return_dict=False)[0] - - with tempfile.TemporaryDirectory() as tmpdir: - # Save the IP adapter weights - save_path = os.path.join(tmpdir, "ip_adapter.safetensors") - import safetensors.torch - - safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path) - - # Unload and reload - model.unload_ip_adapter() - assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded" - - # Reload from saved file - loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)} - model._load_ip_adapter_weights([loaded_state_dict]) - assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded" - - torch.manual_seed(0) - output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0] - - # Outputs should match before and after save/load - 
assert torch.allclose( - output_before_save, output_after_load, atol=1e-4, rtol=1e-4 - ), "Output should match before and after save/load" diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index dfc3bd2955e5..6777c164f280 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -91,15 +91,15 @@ def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): torch.manual_seed(0) outputs_with_lora = model(**inputs_dict, return_dict=False)[0] - assert not torch.allclose( - output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4 - ), "Output should differ with LoRA enabled" + assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4), ( + "Output should differ with LoRA enabled" + ) with tempfile.TemporaryDirectory() as tmpdir: model.save_lora_adapter(tmpdir) - assert os.path.isfile( - os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - ), "LoRA weights file not created" + assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), ( + "LoRA weights file not created" + ) state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) @@ -119,12 +119,12 @@ def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] - assert not torch.allclose( - output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4 - ), "Output should differ with LoRA enabled" - assert torch.allclose( - outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4 - ), "Outputs should match before and after save/load" + assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( + "Output should differ with LoRA enabled" + ) + assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( + "Outputs should match before and after save/load" + ) def test_lora_wrong_adapter_name_raises_error(self): from peft import LoraConfig diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index d06a125dc600..6cdc72b004c9 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -122,9 +122,9 @@ def test_cpu_offload(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match with CPU offloading" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match with CPU offloading" + ) @require_offload_support def test_disk_offload_without_safetensors(self): @@ -183,9 +183,9 @@ def test_disk_offload_with_safetensors(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose( - base_output[0], new_output[0], atol=1e-5 - ), "Output should match with disk offloading (safetensors)" + assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( + "Output should match with disk offloading (safetensors)" + ) @is_group_offload @@ -247,18 +247,18 @@ def run_forward(model): ) output_with_group_offloading4 = run_forward(model) - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading1, atol=1e-5 - ), "Output should match with block-level offloading" - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading2, atol=1e-5 - ), "Output should match with non-blocking block-level offloading" - assert torch.allclose( - 
output_without_group_offloading, output_with_group_offloading3, atol=1e-5 - ), "Output should match with leaf-level offloading" - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading4, atol=1e-5 - ), "Output should match with leaf-level offloading with stream" + assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), ( + "Output should match with block-level offloading" + ) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), ( + "Output should match with non-blocking block-level offloading" + ) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5), ( + "Output should match with leaf-level offloading" + ) + assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), ( + "Output should match with leaf-level offloading with stream" + ) @require_group_offload_support @torch.no_grad() @@ -345,9 +345,9 @@ def _run_forward(model, inputs_dict): raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) - assert torch.allclose( - output_without_group_offloading, output_with_group_offloading, atol=atol - ), "Output should match with disk-based group offloading" + assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), ( + "Output should match with disk-based group offloading" + ) class LayerwiseCastingTesterMixin: @@ -396,16 +396,16 @@ def get_memory_usage(storage_dtype, compute_dtype): ) compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None - assert ( - fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint - ), "Memory footprint should decrease with lower precision storage" + assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, ( + "Memory footprint should decrease with lower precision storage" + ) # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it. if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY: - assert ( - fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory - ), "Peak memory should be lower with bf16 compute on newer GPUs" + assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, ( + "Peak memory should be lower with bf16 compute on newer GPUs" + ) # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few # bytes. This only happens for some models, so we allow a small tolerance. 
@@ -415,6 +415,36 @@ def get_memory_usage(storage_dtype, compute_dtype): or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ), "Peak memory should be lower or within tolerance with fp8 storage" + def test_layerwise_casting_training(self): + def test_fn(storage_dtype, compute_dtype): + if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16: + pytest.skip("Skipping test because CPU doesn't go well with bfloat16.") + + model = self.model_class(**self.get_init_dict()) + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + model.train() + + inputs_dict = self.get_inputs_dict() + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + with torch.amp.autocast(device_type=torch.device(torch_device).type): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + input_tensor = inputs_dict[self.main_input_name] + noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device) + noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype) + loss = torch.nn.functional.mse_loss(output, noise) + + loss.backward() + + test_fn(torch.float16, torch.float32) + test_fn(torch.float8_e4m3fn, torch.float32) + test_fn(torch.float8_e5m2, torch.float32) + test_fn(torch.float8_e4m3fn, torch.bfloat16) + @is_memory @require_accelerator diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 15d6a3206946..866d572f9d92 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -26,6 +26,7 @@ is_nvidia_modelopt_available, is_optimum_quanto_available, is_torchao_available, + is_torchao_version, ) from ...testing_utils import ( @@ -128,9 +129,9 @@ def _test_quantization_num_parameters(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) num_params_quantized = model_quantized.num_parameters() - assert ( - num_params == num_params_quantized - ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" + assert num_params == num_params_quantized, ( + f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" + ) def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): model = self._load_unquantized_model() @@ -140,9 +141,9 @@ def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_red mem_quantized = model_quantized.get_memory_footprint() ratio = mem / mem_quantized - assert ( - ratio >= expected_memory_reduction - ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}" + assert ratio >= expected_memory_reduction, ( + f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" + ) def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) @@ -243,12 +244,12 @@ def _test_quantized_layers(self, config_kwargs): self._verify_if_layer_quantized(name, module, config_kwargs) num_quantized_layers += 1 - assert ( - num_quantized_layers > 0 - ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" - assert ( - num_quantized_layers == expected_quantized_layers - ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + assert num_quantized_layers > 0, ( + f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + ) + assert num_quantized_layers == expected_quantized_layers, ( + f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + ) def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): """ @@ -272,9 +273,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no if any(excluded in name for excluded in modules_to_not_convert): found_excluded = True # This module should NOT be quantized - assert not self._is_module_quantized( - module - ), f"Module {name} should not be quantized but was found to be quantized" + assert not self._is_module_quantized(module), ( + f"Module {name} should not be quantized but was found to be quantized" + ) assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" @@ -296,9 +297,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no mem_with_exclusion = model_with_exclusion.get_memory_footprint() mem_fully_quantized = model_fully_quantized.get_memory_footprint() - assert ( - mem_with_exclusion > mem_fully_quantized - ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + assert mem_with_exclusion > mem_fully_quantized, ( + f"Model with exclusions should be larger. 
With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + ) def _test_quantization_device_map(self, config_kwargs): """ @@ -380,9 +381,9 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs): def _verify_if_layer_quantized(self, name, module, config_kwargs): expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params - assert ( - module.weight.__class__ == expected_weight_class - ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + assert module.weight.__class__ == expected_weight_class, ( + f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + ) @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) def test_bnb_quantization_num_parameters(self, config_name): @@ -450,13 +451,13 @@ def test_bnb_keep_modules_in_fp32(self): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert ( - module.weight.dtype == torch.float32 - ), f"Module {name} should be FP32 but is {module.weight.dtype}" + assert module.weight.dtype == torch.float32, ( + f"Module {name} should be FP32 but is {module.weight.dtype}" + ) else: - assert ( - module.weight.dtype == torch.uint8 - ), f"Module {name} should be uint8 but is {module.weight.dtype}" + assert module.weight.dtype == torch.uint8, ( + f"Module {name} should be uint8 but is {module.weight.dtype}" + ) with torch.no_grad(): inputs = self.get_dummy_inputs() diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index cb0ae7a4e7e4..67d770849f00 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -192,9 +192,9 @@ def test_single_file_loading_with_diffusers_config(self): for param_name, param_value in model_single_file.config.items(): if param_name in PARAMS_TO_IGNORE: continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + assert model.config[param_name] == param_value, ( + f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + ) def test_single_file_loading_with_diffusers_config_local_files_only(self): single_file_kwargs = {} diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index f301b5a6d0c2..7e4193d59e84 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -116,8 +116,7 @@ def test_gradient_checkpointing_is_applied(self, expected_set=None): modules_with_gc_enabled[submodule.__class__.__name__] = True assert set(modules_with_gc_enabled.keys()) == expected_set, ( - f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} " - f"do not match expected set {expected_set}" + f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}" ) assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled" @@ -169,9 +168,9 @@ def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_gra loss_2.backward() # compare the output and parameters gradients - assert ( - loss - loss_2 - ).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}" + assert (loss - loss_2).abs() < loss_tolerance, ( + f"Loss difference {(loss - loss_2).abs()} exceeds tolerance 
{loss_tolerance}" + ) named_params = dict(model.named_parameters()) named_params_2 = dict(model_2.named_parameters()) @@ -184,9 +183,9 @@ def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_gra if param.grad is None: continue - assert torch_all_close( - param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol - ), f"Gradient mismatch for {name}" + assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), ( + f"Gradient mismatch for {name}" + ) def test_mixed_precision_training(self): init_dict = self.get_init_dict() diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 3ab02f797b5b..43e02db448eb 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -13,119 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from typing import Any import torch from diffusers import FluxTransformer2DModel -from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 from diffusers.models.embeddings import ImageProjection - -from ...testing_utils import enable_full_determinism, is_peft_available, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import LoraHotSwappingForModelTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + GGUFTesterMixin, + IPAdapterTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelOptTesterMixin, + ModelTesterMixin, + QuantoTesterMixin, + SingleFileTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -def create_flux_ip_adapter_state_dict(model): - # "ip_adapter" (cross-attention weights) - ip_cross_attn_state_dict = {} - key_id = 0 - - for name in model.attn_processors.keys(): - if name.startswith("single_transformer_blocks"): - continue - - joint_attention_dim = model.config["joint_attention_dim"] - hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - sd = FluxIPAdapterJointAttnProcessor2_0( - hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 - ).state_dict() - ip_cross_attn_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], - f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], - } - ) - - key_id += 1 - - # "image_proj" (ImageProjection layer weights) - - image_projection = ImageProjection( - cross_attention_dim=model.config["joint_attention_dim"], - image_embed_dim=( - model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 - ), - num_image_text_embeds=4, - ) - - ip_image_projection_state_dict = {} - sd = image_projection.state_dict() - ip_image_projection_state_dict.update( - { - "proj.weight": sd["image_embeds.weight"], - "proj.bias": sd["image_embeds.bias"], - "norm.weight": sd["norm.weight"], - "norm.bias": sd["norm.bias"], - } - ) - - del sd - ip_state_dict = {} - 
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict - - -class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): +class FluxTransformerTesterConfig: model_class = FluxTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.7, 0.6, 0.6] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True - - @property - def dummy_input(self): - return self.prepare_dummy_input() + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" + pretrained_model_kwargs = {"subfolder": "transformer"} @property - def input_shape(self): - return (16, 4) - - @property - def output_shape(self): - return (16, 4) - - def prepare_dummy_input(self, height=4, width=4): - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 48 - embedding_dim = 32 - - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + def generator(self): + return torch.Generator("cpu").manual_seed(0) + def get_init_dict(self) -> dict[str, int | list[int]]: + """Return Flux model initialization arguments.""" return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "img_ids": image_ids, - "txt_ids": text_ids, - "pooled_projections": pooled_prompt_embeds, - "timestep": timestep, - } - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { "patch_size": 1, "in_channels": 4, "num_layers": 1, @@ -137,11 +68,40 @@ def prepare_init_args_and_inputs_for_common(self): "axes_dims_rope": [4, 4, 8], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + height = width = 4 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 48 + embedding_dim = 32 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator + ), + "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator), + "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator), + "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 4) + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 4) + +class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): def test_deprecated_inputs_img_txt_ids_3d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + """Test that deprecated 3D img_ids and txt_ids still work.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) model.to(torch_device) model.eval() 
@@ -162,63 +122,223 @@ def test_deprecated_inputs_img_txt_ids_3d(self): with torch.no_grad(): output_2 = model(**inputs_dict).to_tuple()[0] - self.assertEqual(output_1.shape, output_2.shape) - self.assertTrue( - torch.allclose(output_1, output_2, atol=1e-5), - msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + assert output_1.shape == output_2.shape + assert torch.allclose(output_1, output_2, atol=1e-5), ( + "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) " + "are not equal as them as 2d inputs" ) - def test_gradient_checkpointing_is_applied(self): - expected_set = {"FluxTransformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - # The test exists for cases like - # https://github.com/huggingface/diffusers/issues/11874 - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_exclude_modules(self): - from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict - - lora_rank = 4 - target_module = "single_transformer_blocks.0.proj_out" - adapter_name = "foo" - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - state_dict = model.state_dict() - target_mod_shape = state_dict[f"{target_module}.weight"].shape - lora_state_dict = { - f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22, - f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33, - } - # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). - config = LoraConfig( - r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"] + +class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin): + """Training tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin): + """Context Parallel inference tests for Flux Transformer""" + + pass + + +class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): + """IP Adapter tests for Flux Transformer.""" + + ip_adapter_processor_cls = FluxIPAdapterAttnProcessor + + def modify_inputs_for_ip_adapter(self, model, inputs_dict): + torch.manual_seed(0) + # Create dummy image embeds for IP adapter + cross_attention_dim = getattr(model.config, "joint_attention_dim", 32) + image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device) + + inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}}) + + return inputs_dict + + def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: + from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor + + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterAttnProcessor( + 
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + + key_id += 1 + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=( + model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 + ), + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } ) - inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) - set_peft_model_state_dict(model, lora_state_dict, adapter_name) - retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) - assert len(retrieved_lora_state_dict) == len(lora_state_dict) - assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all() - assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all() + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + +class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for Flux Transformer.""" + + pass + + +class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for Flux Transformer.""" -class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = FluxTransformer2DModel different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - def prepare_init_args_and_inputs_for_common(self): - return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for LoRA hotswap tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 8 - def prepare_dummy_input(self, height, width): - return FluxTransformerTests().prepare_dummy_input(height=height, width=width) + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), + "pooled_projections": randn_tensor((batch_size, embedding_dim)), + "img_ids": randn_tensor((height * width, num_image_channels)), + "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } -class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = FluxTransformer2DModel +class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - def prepare_init_args_and_inputs_for_common(self): - return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + 
"""Override to support dynamic height/width for compilation tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 8 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), + "pooled_projections": randn_tensor((batch_size, embedding_dim)), + "img_ids": randn_tensor((height * width, num_image_channels)), + "txt_ids": randn_tensor((sequence_length, num_image_channels)), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + +class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): + ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] + pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev" + subfolder = "transformer" + pass - def prepare_dummy_input(self, height, width): - return FluxTransformerTests().prepare_dummy_input(height=height, width=width) + +class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): + gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf" + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } + + +class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin): + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + return { + "hidden_states": randn_tensor((1, 4096, 64)), + "encoder_hidden_states": randn_tensor((1, 512, 4096)), + "pooled_projections": randn_tensor((1, 768)), + "timestep": 
torch.tensor([1.0]).to(torch_device), + "img_ids": randn_tensor((4096, 3)), + "txt_ids": randn_tensor((512, 3)), + "guidance": torch.tensor([3.5]).to(torch_device), + } diff --git a/tests/models/transformers/test_models_transformer_flux_.py b/tests/models/transformers/test_models_transformer_flux_.py deleted file mode 100644 index e79974459fd4..000000000000 --- a/tests/models/transformers/test_models_transformer_flux_.py +++ /dev/null @@ -1,330 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from diffusers import FluxTransformer2DModel -from diffusers.models.embeddings import ImageProjection -from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor -from diffusers.utils.torch_utils import randn_tensor - -from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin -from ..testing_utils import ( - AttentionTesterMixin, - BitsAndBytesTesterMixin, - GGUFTesterMixin, - IPAdapterTesterMixin, - LoraTesterMixin, - MemoryTesterMixin, - ModelOptTesterMixin, - ModelTesterMixin, - QuantoTesterMixin, - SingleFileTesterMixin, - TorchAoTesterMixin, - TorchCompileTesterMixin, - TrainingTesterMixin, -) - - -enable_full_determinism() - - -class FluxTransformerTesterConfig: - model_class = FluxTransformer2DModel - pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" - pretrained_model_kwargs = {"subfolder": "transformer"} - - @property - def generator(self): - return torch.Generator("cpu").manual_seed(0) - - def get_init_dict(self) -> dict[str, int | list[int]]: - """Return Flux model initialization arguments.""" - return { - "patch_size": 1, - "in_channels": 4, - "num_layers": 1, - "num_single_layers": 1, - "attention_head_dim": 16, - "num_attention_heads": 2, - "joint_attention_dim": 32, - "pooled_projection_dim": 32, - "axes_dims_rope": [4, 4, 8], - } - - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - batch_size = 1 - height = width = 4 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 48 - embedding_dim = 32 - - return { - "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator), - "encoder_hidden_states": randn_tensor( - (batch_size, sequence_length, embedding_dim), generator=self.generator - ), - "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator), - "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator), - "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator), - "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), - } - - @property - def input_shape(self) -> tuple[int, int]: - return (16, 4) - - @property - def output_shape(self) -> tuple[int, int]: - return (16, 4) - - -class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): - def 
test_deprecated_inputs_img_txt_ids_3d(self): - """Test that deprecated 3D img_ids and txt_ids still work.""" - init_dict = self.get_init_dict() - inputs_dict = self.get_dummy_inputs() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output_1 = model(**inputs_dict).to_tuple()[0] - - # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) - text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) - image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) - - assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor" - assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor" - - inputs_dict["txt_ids"] = text_ids_3d - inputs_dict["img_ids"] = image_ids_3d - - with torch.no_grad(): - output_2 = model(**inputs_dict).to_tuple()[0] - - assert output_1.shape == output_2.shape - assert torch.allclose(output_1, output_2, atol=1e-5), ( - "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) " - "are not equal as them as 2d inputs" - ) - - -class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin): - """Memory optimization tests for Flux Transformer.""" - - pass - - -class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin): - """Training tests for Flux Transformer.""" - - pass - - -class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin): - """Attention processor tests for Flux Transformer.""" - - pass - - -class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): - """IP Adapter tests for Flux Transformer.""" - - def prepare_model(self, model): - joint_attention_dim = model.config["joint_attention_dim"] - hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - model.set_attn_processor(FluxIPAdapterAttnProcessor(hidden_size, joint_attention_dim, scale=1.0)) - - def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: - from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor - - ip_cross_attn_state_dict = {} - key_id = 0 - - for name in model.attn_processors.keys(): - if name.startswith("single_transformer_blocks"): - continue - - joint_attention_dim = model.config["joint_attention_dim"] - hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - sd = FluxIPAdapterAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 - ).state_dict() - ip_cross_attn_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], - f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], - } - ) - - key_id += 1 - - image_projection = ImageProjection( - cross_attention_dim=model.config["joint_attention_dim"], - image_embed_dim=( - model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 - ), - num_image_text_embeds=4, - ) - - ip_image_projection_state_dict = {} - sd = image_projection.state_dict() - ip_image_projection_state_dict.update( - { - "proj.weight": sd["image_embeds.weight"], - "proj.bias": sd["image_embeds.bias"], - "norm.weight": sd["norm.weight"], - "norm.bias": sd["norm.bias"], - } - ) - - del sd - ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict - - -class 
TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): - """LoRA adapter tests for Flux Transformer.""" - - pass - - -class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): - """LoRA hot-swapping tests for Flux Transformer.""" - - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - - def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: - """Override to support dynamic height/width for LoRA hotswap tests.""" - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 24 - embedding_dim = 8 - - return { - "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), - "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), - "pooled_projections": randn_tensor((batch_size, embedding_dim)), - "img_ids": randn_tensor((height * width, num_image_channels)), - "txt_ids": randn_tensor((sequence_length, num_image_channels)), - "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), - } - - -class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - - def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: - """Override to support dynamic height/width for compilation tests.""" - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 24 - embedding_dim = 8 - - return { - "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)), - "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)), - "pooled_projections": randn_tensor((batch_size, embedding_dim)), - "img_ids": randn_tensor((height * width, num_image_channels)), - "txt_ids": randn_tensor((sequence_length, num_image_channels)), - "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), - } - - -class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): - ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" - alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] - pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev" - subfolder = "transformer" - pass - - -class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - return { - "hidden_states": randn_tensor((1, 4096, 64)), - "encoder_hidden_states": randn_tensor((1, 512, 4096)), - "pooled_projections": randn_tensor((1, 768)), - "timestep": torch.tensor([1.0]).to(torch_device), - "img_ids": randn_tensor((4096, 3)), - "txt_ids": randn_tensor((512, 3)), - "guidance": torch.tensor([3.5]).to(torch_device), - } - - -class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - return { - "hidden_states": randn_tensor((1, 4096, 64)), - "encoder_hidden_states": randn_tensor((1, 512, 4096)), - "pooled_projections": randn_tensor((1, 768)), - "timestep": torch.tensor([1.0]).to(torch_device), - "img_ids": randn_tensor((4096, 3)), - "txt_ids": randn_tensor((512, 3)), - "guidance": torch.tensor([3.5]).to(torch_device), - } - - -class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin): - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - return { - "hidden_states": randn_tensor((1, 
4096, 64)), - "encoder_hidden_states": randn_tensor((1, 512, 4096)), - "pooled_projections": randn_tensor((1, 768)), - "timestep": torch.tensor([1.0]).to(torch_device), - "img_ids": randn_tensor((4096, 3)), - "txt_ids": randn_tensor((512, 3)), - "guidance": torch.tensor([3.5]).to(torch_device), - } - - -class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): - gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf" - - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - return { - "hidden_states": randn_tensor((1, 4096, 64)), - "encoder_hidden_states": randn_tensor((1, 512, 4096)), - "pooled_projections": randn_tensor((1, 768)), - "timestep": torch.tensor([1.0]).to(torch_device), - "img_ids": randn_tensor((4096, 3)), - "txt_ids": randn_tensor((512, 3)), - "guidance": torch.tensor([3.5]).to(torch_device), - } - - -class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin): - def get_dummy_inputs(self) -> dict[str, torch.Tensor]: - return { - "hidden_states": randn_tensor((1, 4096, 64)), - "encoder_hidden_states": randn_tensor((1, 512, 4096)), - "pooled_projections": randn_tensor((1, 768)), - "timestep": torch.tensor([1.0]).to(torch_device), - "img_ids": randn_tensor((4096, 3)), - "txt_ids": randn_tensor((512, 3)), - "guidance": torch.tensor([3.5]).to(torch_device), - } diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index c3bb71e79415..0f4fd408a7c1 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -98,9 +98,9 @@ def test_cuda_kernels_vs_native(self): output_native = linear.forward_native(x) output_cuda = linear.forward_cuda(x) - assert torch.allclose( - output_native, output_cuda, 1e-2 - ), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" + assert torch.allclose(output_native, output_cuda, 1e-2), ( + f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" + ) @nightly diff --git a/tests/testing_utils.py b/tests/testing_utils.py index ae69a21cf8a8..9860d64dc119 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -406,6 +406,15 @@ def is_modelopt(test_case): return pytest.mark.modelopt(test_case) +def is_context_parallel(test_case): + """ + Decorator marking a test as a context parallel inference test. These tests can be filtered using: + pytest -m "not context_parallel" to skip + pytest -m context_parallel to run only these tests + """ + return pytest.mark.context_parallel(test_case) + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. 
From 489480b02a4c4df5f118e44b756943ddd723c02e Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 11 Dec 2025 11:27:59 +0530 Subject: [PATCH 06/12] update --- tests/models/testing_utils/quantization.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 866d572f9d92..140f6db6d2ff 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -36,7 +36,6 @@ is_modelopt, is_quanto, is_torchao, - nightly, require_accelerate, require_accelerator, require_bitsandbytes_version_greater, @@ -325,7 +324,6 @@ def _test_quantization_device_map(self, config_kwargs): @is_bitsandbytes -@nightly @require_accelerator @require_bitsandbytes_version_greater("0.43.2") @require_accelerate @@ -480,7 +478,6 @@ def test_bnb_device_map(self): @is_quanto -@nightly @require_quanto @require_accelerate @require_accelerator @@ -654,7 +651,6 @@ def test_torchao_device_map(self): @is_gguf -@nightly @require_accelerate @require_accelerator @require_gguf_version_greater_or_equal("0.10.0") @@ -744,7 +740,6 @@ def test_gguf_quantized_layers(self): @is_modelopt -@nightly @require_accelerator @require_accelerate @require_modelopt_version_greater_or_equal("0.33.1") From 0fdd9d3a609aa230473109ff7cfb0a5bed1544bc Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 11 Dec 2025 11:41:17 +0530 Subject: [PATCH 07/12] update --- utils/generate_model_tests.py | 489 ++++++++++++++++++++++++++++++++++ 1 file changed, 489 insertions(+) create mode 100644 utils/generate_model_tests.py diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py new file mode 100644 index 000000000000..ffd600dfdf29 --- /dev/null +++ b/utils/generate_model_tests.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility script to generate test suites for diffusers model classes. + +Usage: + python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py + +This will analyze the model file and generate a test file with appropriate +test classes based on the model's mixins and attributes. 
+""" + +import argparse +import ast +import sys +from pathlib import Path + + +MIXIN_TO_TESTER = { + "ModelMixin": "ModelTesterMixin", + "PeftAdapterMixin": "LoraTesterMixin", +} + +ATTRIBUTE_TO_TESTER = { + "_cp_plan": "ContextParallelTesterMixin", + "_supports_gradient_checkpointing": "TrainingTesterMixin", +} + +ALWAYS_INCLUDE_TESTERS = [ + "ModelTesterMixin", + "MemoryTesterMixin", + "AttentionTesterMixin", + "TorchCompileTesterMixin", +] + +OPTIONAL_TESTERS = [ + ("BitsAndBytesTesterMixin", "bnb"), + ("QuantoTesterMixin", "quanto"), + ("TorchAoTesterMixin", "torchao"), + ("GGUFTesterMixin", "gguf"), + ("ModelOptTesterMixin", "modelopt"), + ("SingleFileTesterMixin", "single_file"), + ("IPAdapterTesterMixin", "ip_adapter"), +] + + +class ModelAnalyzer(ast.NodeVisitor): + def __init__(self): + self.model_classes = [] + self.current_class = None + + def visit_ClassDef(self, node: ast.ClassDef): + base_names = [] + for base in node.bases: + if isinstance(base, ast.Name): + base_names.append(base.id) + elif isinstance(base, ast.Attribute): + base_names.append(base.attr) + + if "ModelMixin" in base_names: + class_info = { + "name": node.name, + "bases": base_names, + "attributes": {}, + "has_forward": False, + "init_params": [], + } + + for item in node.body: + if isinstance(item, ast.Assign): + for target in item.targets: + if isinstance(target, ast.Name): + attr_name = target.id + if attr_name.startswith("_"): + class_info["attributes"][attr_name] = self._get_value(item.value) + + elif isinstance(item, ast.FunctionDef): + if item.name == "forward": + class_info["has_forward"] = True + class_info["forward_params"] = self._extract_func_params(item) + elif item.name == "__init__": + class_info["init_params"] = self._extract_func_params(item) + + self.model_classes.append(class_info) + + self.generic_visit(node) + + def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]: + params = [] + args = func_node.args + + num_defaults = len(args.defaults) + num_args = len(args.args) + first_default_idx = num_args - num_defaults + + for i, arg in enumerate(args.args): + if arg.arg == "self": + continue + + param_info = {"name": arg.arg, "type": None, "default": None} + + if arg.annotation: + param_info["type"] = self._get_annotation_str(arg.annotation) + + default_idx = i - first_default_idx + if default_idx >= 0 and default_idx < len(args.defaults): + param_info["default"] = self._get_value(args.defaults[default_idx]) + + params.append(param_info) + + return params + + def _get_annotation_str(self, node) -> str: + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Constant): + return repr(node.value) + elif isinstance(node, ast.Subscript): + base = self._get_annotation_str(node.value) + if isinstance(node.slice, ast.Tuple): + args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts) + else: + args = self._get_annotation_str(node.slice) + return f"{base}[{args}]" + elif isinstance(node, ast.Attribute): + return f"{self._get_annotation_str(node.value)}.{node.attr}" + elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + left = self._get_annotation_str(node.left) + right = self._get_annotation_str(node.right) + return f"{left} | {right}" + elif isinstance(node, ast.Tuple): + return ", ".join(self._get_annotation_str(el) for el in node.elts) + return "Any" + + def _get_value(self, node): + if isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Name): + if node.id == "None": + return None + elif node.id == 
"True": + return True + elif node.id == "False": + return False + return node.id + elif isinstance(node, ast.List): + return [self._get_value(el) for el in node.elts] + elif isinstance(node, ast.Dict): + return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)} + return "" + + +def analyze_model_file(filepath: str) -> list[dict]: + with open(filepath) as f: + source = f.read() + + tree = ast.parse(source) + analyzer = ModelAnalyzer() + analyzer.visit(tree) + + return analyzer.model_classes + + +def determine_testers(model_info: dict, include_optional: list[str]) -> list[str]: + testers = list(ALWAYS_INCLUDE_TESTERS) + + for base in model_info["bases"]: + if base in MIXIN_TO_TESTER: + tester = MIXIN_TO_TESTER[base] + if tester not in testers: + testers.append(tester) + + for attr, tester in ATTRIBUTE_TO_TESTER.items(): + if attr in model_info["attributes"]: + value = model_info["attributes"][attr] + if value is not None and value is not False: + if tester not in testers: + testers.append(tester) + + if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None: + if "ContextParallelTesterMixin" not in testers: + testers.append("ContextParallelTesterMixin") + + for tester, flag in OPTIONAL_TESTERS: + if flag in include_optional: + if tester not in testers: + testers.append(tester) + + return testers + + +def generate_config_class(model_info: dict, model_name: str) -> str: + class_name = f"{model_name}TesterConfig" + model_class = model_info["name"] + forward_params = model_info.get("forward_params", []) + init_params = model_info.get("init_params", []) + + lines = [ + f"class {class_name}:", + f" model_class = {model_class}", + ' pretrained_model_name_or_path = ""', + ' pretrained_model_kwargs = {"subfolder": "transformer"}', + "", + " @property", + " def generator(self):", + ' return torch.Generator("cpu").manual_seed(0)', + "", + " def get_init_dict(self) -> dict[str, int | list[int]]:", + ] + + if init_params: + lines.append(" # __init__ parameters:") + for param in init_params: + type_str = f": {param['type']}" if param["type"] else "" + default_str = f" = {param['default']}" if param["default"] is not None else "" + lines.append(f" # {param['name']}{type_str}{default_str}") + + lines.extend( + [ + " return {}", + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + ] + ) + + if forward_params: + lines.append(" # forward() parameters:") + for param in forward_params: + type_str = f": {param['type']}" if param["type"] else "" + default_str = f" = {param['default']}" if param["default"] is not None else "" + lines.append(f" # {param['name']}{type_str}{default_str}") + + lines.extend( + [ + " # TODO: Fill in dummy inputs", + " return {}", + "", + " @property", + " def input_shape(self) -> tuple[int, ...]:", + " return (1, 1)", + "", + " @property", + " def output_shape(self) -> tuple[int, ...]:", + " return (1, 1)", + ] + ) + + return "\n".join(lines) + + +def generate_test_class(model_name: str, config_class: str, tester: str) -> str: + tester_short = tester.replace("TesterMixin", "") + class_name = f"Test{model_name}{tester_short}" + + lines = [f"class {class_name}({config_class}, {tester}):"] + + if tester == "TorchCompileTesterMixin": + lines.extend( + [ + " different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]", + "", + " def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:", + " # TODO: Implement dynamic input generation", + " return {}", + ] + ) + elif tester == 
"IPAdapterTesterMixin": + lines.extend( + [ + " ip_adapter_processor_cls = None # TODO: Set processor class", + "", + " def modify_inputs_for_ip_adapter(self, model, inputs_dict):", + " # TODO: Add IP adapter image embeds to inputs", + " return inputs_dict", + "", + " def create_ip_adapter_state_dict(self, model):", + " # TODO: Create IP adapter state dict", + " return {}", + ] + ) + elif tester == "SingleFileTesterMixin": + lines.extend( + [ + ' ckpt_path = "" # TODO: Set checkpoint path', + " alternate_keys_ckpt_paths = []", + ' pretrained_model_name_or_path = ""', + ' subfolder = "transformer"', + ] + ) + elif tester == "GGUFTesterMixin": + lines.extend( + [ + ' gguf_filename = "" # TODO: Set GGUF filename', + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization tests", + " return {}", + ] + ) + elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]: + lines.extend( + [ + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization tests", + " return {}", + ] + ) + elif tester == "LoraHotSwappingForModelTesterMixin": + lines.extend( + [ + " different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]", + "", + " def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:", + " # TODO: Implement dynamic input generation", + " return {}", + ] + ) + else: + lines.append(" pass") + + return "\n".join(lines) + + +def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str]) -> str: + model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "") + testers = determine_testers(model_info, include_optional) + tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"}) + + lines = [ + "# coding=utf-8", + "# Copyright 2025 HuggingFace Inc.", + "#", + '# Licensed under the Apache License, Version 2.0 (the "License");', + "# you may not use this file except in compliance with the License.", + "# You may obtain a copy of the License at", + "#", + "# http://www.apache.org/licenses/LICENSE-2.0", + "#", + "# Unless required by applicable law or agreed to in writing, software", + '# distributed under the License is distributed on an "AS IS" BASIS,', + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "# See the License for the specific language governing permissions and", + "# limitations under the License.", + "", + "import torch", + "", + f"from diffusers import {model_info['name']}", + "from diffusers.utils.torch_utils import randn_tensor", + "", + "from ...testing_utils import enable_full_determinism, torch_device", + ] + + if "LoraTesterMixin" in testers: + lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin") + + lines.extend( + [ + "from ..testing_utils import (", + *[f" {tester}," for tester in sorted(tester_imports)], + ")", + "", + "", + "enable_full_determinism()", + "", + "", + ] + ) + + config_class = f"{model_name}TesterConfig" + lines.append(generate_config_class(model_info, model_name)) + lines.append("") + lines.append("") + + for tester in testers: + lines.append(generate_test_class(model_name, config_class, tester)) + lines.append("") + lines.append("") + + if "LoraTesterMixin" in testers: + lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin")) + lines.append("") + lines.append("") + + return 
"\n".join(lines).rstrip() + "\n" + + +def get_test_output_path(model_filepath: str) -> str: + path = Path(model_filepath) + model_filename = path.stem + + if "transformers" in path.parts: + return f"tests/models/transformers/test_models_{model_filename}.py" + elif "unets" in path.parts: + return f"tests/models/unets/test_models_{model_filename}.py" + elif "autoencoders" in path.parts: + return f"tests/models/autoencoders/test_models_{model_filename}.py" + else: + return f"tests/models/test_models_{model_filename}.py" + + +def main(): + parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class") + parser.add_argument( + "model_filepath", + type=str, + help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)", + ) + parser.add_argument( + "--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)" + ) + parser.add_argument( + "--include", + "-i", + type=str, + nargs="*", + default=[], + choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"], + help="Optional testers to include", + ) + parser.add_argument( + "--class-name", + "-c", + type=str, + default=None, + help="Specific model class to generate tests for (default: first model class found)", + ) + parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file") + + args = parser.parse_args() + + if not Path(args.model_filepath).exists(): + print(f"Error: File not found: {args.model_filepath}", file=sys.stderr) + sys.exit(1) + + model_classes = analyze_model_file(args.model_filepath) + + if not model_classes: + print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr) + sys.exit(1) + + if args.class_name: + model_info = next((m for m in model_classes if m["name"] == args.class_name), None) + if not model_info: + available = [m["name"] for m in model_classes] + print(f"Error: Class '{args.class_name}' not found. 
Available: {available}", file=sys.stderr) + sys.exit(1) + else: + model_info = model_classes[0] + if len(model_classes) > 1: + print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr) + print("Use --class-name to specify a different class", file=sys.stderr) + + include_optional = args.include + if "all" in include_optional: + include_optional = [flag for _, flag in OPTIONAL_TESTERS] + + generated_code = generate_test_file(model_info, args.model_filepath, include_optional) + + if args.dry_run: + print(generated_code) + else: + output_path = args.output or get_test_output_path(args.model_filepath) + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + f.write(generated_code) + + print(f"Generated test file: {output_path}") + print(f"Model class: {model_info['name']}") + print(f"Detected attributes: {list(model_info['attributes'].keys())}") + + +if __name__ == "__main__": + main() From c366b5a817c0671ef9967a58cf6fcec10952f3ae Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 11 Dec 2025 13:37:06 +0530 Subject: [PATCH 08/12] update --- .../test_models_transformer_flux.py | 98 ++++++++++--------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 43e02db448eb..30193088316c 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -45,6 +45,55 @@ enable_full_determinism() +# TODO: This standalone function maintains backward compatibility with pipeline tests +# (tests/pipelines/test_pipelines_common.py) and will be refactored. +def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]: + """Create a dummy IP Adapter state dict for Flux transformer testing.""" + ip_cross_attn_state_dict = {} + key_id = 0 + + for name in model.attn_processors.keys(): + if name.startswith("single_transformer_blocks"): + continue + + joint_attention_dim = model.config["joint_attention_dim"] + hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] + sd = FluxIPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], + f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], + f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], + } + ) + key_id += 1 + + image_projection = ImageProjection( + cross_attention_dim=model.config["joint_attention_dim"], + image_embed_dim=( + model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 + ), + num_image_text_embeds=4, + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict} + + class FluxTransformerTesterConfig: model_class = FluxTransformer2DModel pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" @@ -169,54 +218,7 @@ def modify_inputs_for_ip_adapter(self, model, inputs_dict): return inputs_dict def create_ip_adapter_state_dict(self, model: Any) -> 
dict[str, dict[str, Any]]: - from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor - - ip_cross_attn_state_dict = {} - key_id = 0 - - for name in model.attn_processors.keys(): - if name.startswith("single_transformer_blocks"): - continue - - joint_attention_dim = model.config["joint_attention_dim"] - hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - sd = FluxIPAdapterAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 - ).state_dict() - ip_cross_attn_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"], - f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], - } - ) - - key_id += 1 - - image_projection = ImageProjection( - cross_attention_dim=model.config["joint_attention_dim"], - image_embed_dim=( - model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768 - ), - num_image_text_embeds=4, - ) - - ip_image_projection_state_dict = {} - sd = image_projection.state_dict() - ip_image_projection_state_dict.update( - { - "proj.weight": sd["image_embeds.weight"], - "proj.bias": sd["image_embeds.bias"], - "norm.weight": sd["norm.weight"], - "norm.bias": sd["norm.bias"], - } - ) - - del sd - ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict + return create_flux_ip_adapter_state_dict(model) class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): From d08e0bb545741c118f7f3eb5864164c733ea788e Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 15 Dec 2025 14:19:27 +0530 Subject: [PATCH 09/12] update --- tests/models/testing_utils/__init__.py | 8 +- tests/models/testing_utils/attention.py | 27 +- tests/models/testing_utils/common.py | 447 +++++++++-------- tests/models/testing_utils/lora.py | 448 +++++++++++++++--- tests/models/testing_utils/memory.py | 57 ++- tests/models/testing_utils/single_file.py | 6 +- .../test_models_transformer_flux.py | 35 +- tests/testing_utils.py | 53 +++ 8 files changed, 794 insertions(+), 287 deletions(-) diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index e72a3c928b64..229179737a9a 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,8 +1,8 @@ from .attention import AttentionTesterMixin, ContextParallelTesterMixin -from .common import ModelTesterMixin +from .common import BaseModelTesterConfig, ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin -from .lora import LoraTesterMixin +from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin from .quantization import ( BitsAndBytesTesterMixin, @@ -17,14 +17,16 @@ __all__ = [ - "ContextParallelTesterMixin", "AttentionTesterMixin", + "BaseModelTesterConfig", "BitsAndBytesTesterMixin", + "ContextParallelTesterMixin", "CPUOffloadTesterMixin", "GGUFTesterMixin", "GroupOffloadTesterMixin", "IPAdapterTesterMixin", "LayerwiseCastingTesterMixin", + "LoraHotSwappingForModelTesterMixin", "LoraTesterMixin", "MemoryTesterMixin", "ModelOptTesterMixin", diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index f794a7a0aa4a..45443046fb50 100644 
--- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -25,7 +25,13 @@ AttnProcessor, ) -from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device +from ...testing_utils import ( + assert_tensors_close, + is_attention, + is_context_parallel, + require_torch_multi_accelerator, + torch_device, +) @is_attention @@ -89,8 +95,12 @@ def test_fuse_unfuse_qkv_projections(self): output_after_fusion = output_after_fusion.to_tuple()[0] # Verify outputs match - assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), ( - "Output should not change after fusing projections" + assert_tensors_close( + output_before_fusion, + output_after_fusion, + atol=self.base_precision, + rtol=0, + msg="Output should not change after fusing projections", ) # Unfuse projections @@ -110,8 +120,12 @@ def test_fuse_unfuse_qkv_projections(self): output_after_unfusion = output_after_unfusion.to_tuple()[0] # Verify outputs still match - assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), ( - "Output should match original after unfusing projections" + assert_tensors_close( + output_before_fusion, + output_after_unfusion, + atol=self.base_precision, + rtol=0, + msg="Output should match original after unfusing projections", ) def test_get_set_processor(self): @@ -238,9 +252,6 @@ def test_context_parallel_inference(self, cp_type): if not torch.distributed.is_available(): pytest.skip("torch.distributed is not available.") - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - pytest.skip("Context parallel requires at least 2 CUDA devices.") - if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 9f4ae271f97f..11c10c4557af 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -15,8 +15,8 @@ import json import os -import tempfile from collections import defaultdict +from typing import Any, Dict, Optional, Type import pytest import torch @@ -26,7 +26,7 @@ from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator -from ...testing_utils import CaptureLogger, torch_device +from ...testing_utils import assert_tensors_close, torch_device def named_persistent_module_tensors( @@ -130,40 +130,144 @@ def check_device_map_is_respected(model, device_map): ) -class ModelTesterMixin: +class BaseModelTesterConfig: """ - Base mixin class for model testing with common test methods. + Base class defining the configuration interface for model testing. + + This class defines the contract that all model test classes must implement. + It provides a consistent interface for accessing model configuration, initialization + parameters, and test inputs across all testing mixins. 
- Expected class attributes to be set by subclasses: + Required properties (must be implemented by subclasses): - model_class: The model class to test - - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states") + + Optional properties (can be overridden, have sensible defaults): + - pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None) + - pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {}) + - output_shape: Expected output shape for output validation tests (default: None) - base_precision: Default tolerance for floating point comparisons (default: 1e-3) + - model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7]) - Expected methods to be implemented by subclasses: + Required methods (must be implemented by subclasses): - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Example usage: + class MyModelTestConfig(BaseModelTesterConfig): + @property + def model_class(self): + return MyModel + + @property + def pretrained_model_name_or_path(self): + return "org/my-model" + + @property + def output_shape(self): + return (1, 3, 32, 32) + + def get_init_dict(self): + return {"in_channels": 3, "out_channels": 3} + + def get_dummy_inputs(self): + return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)} + + class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin): + pass + """ + + # ==================== Required Properties ==================== + + @property + def model_class(self) -> Type[nn.Module]: + """The model class to test. Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `model_class` property.") + + # ==================== Optional Properties ==================== + + @property + def pretrained_model_name_or_path(self) -> Optional[str]: + """Hub repository ID for the pretrained model (used for quantization and hub tests).""" + return None + + @property + def pretrained_model_kwargs(self) -> Dict[str, Any]: + """Additional kwargs to pass to from_pretrained (e.g., subfolder, variant).""" + return {} + + @property + def output_shape(self) -> Optional[tuple]: + """Expected output shape for output validation tests.""" + return None + + @property + def model_split_percents(self) -> list: + """Percentages for model parallelism tests.""" + return [0.5, 0.7] + + # ==================== Required Methods ==================== + + def get_init_dict(self) -> Dict[str, Any]: + """ + Returns dict of arguments to initialize the model. + + Returns: + Dict[str, Any]: Initialization arguments for the model constructor. + + Example: + return { + "in_channels": 3, + "out_channels": 3, + "sample_size": 32, + } + """ + raise NotImplementedError("Subclasses must implement `get_init_dict()`.") + + def get_dummy_inputs(self) -> Dict[str, Any]: + """ + Returns dict of inputs to pass to the model forward pass. + + Returns: + Dict[str, Any]: Input tensors/values for model.forward(). + + Example: + return { + "sample": torch.randn(1, 3, 32, 32, device=torch_device), + "timestep": torch.tensor([1], device=torch_device), + } + """ + raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.") + + +class ModelTesterMixin: """ + Base mixin class for model testing with common test methods. 
- model_class = None - base_precision = 1e-3 - model_split_percents = [0.5, 0.7] + This mixin expects the test class to also inherit from BaseModelTesterConfig + (or implement its interface) which provides: + - model_class: The model class to test + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass - def get_init_dict(self): - raise NotImplementedError("get_init_dict must be implemented by subclasses. ") + Example: + class MyModelTestConfig(BaseModelTesterConfig): + model_class = MyModel + def get_init_dict(self): ... + def get_dummy_inputs(self): ... - def get_dummy_inputs(self): - raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.") + class TestMyModel(MyModelTestConfig, ModelTesterMixin): + pass + """ - def test_from_save_pretrained(self, expected_max_diff=5e-5): + def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) + model.save_pretrained(tmp_path) + new_model = self.model_class.from_pretrained(tmp_path) + new_model.to(torch_device) # check if all parameters shape are the same for param_name in model.state_dict().keys(): @@ -184,28 +288,24 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5): if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] - max_diff = (image - new_image).abs().max().item() - assert max_diff <= expected_max_diff, ( - f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}" - ) + assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") - def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): + def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, variant="fp16") - new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") + model.save_pretrained(tmp_path, variant="fp16") + new_model = self.model_class.from_pretrained(tmp_path, variant="fp16") - # non-variant cannot be loaded - with pytest.raises(OSError) as exc_info: - self.model_class.from_pretrained(tmpdirname) + # non-variant cannot be loaded + with pytest.raises(OSError) as exc_info: + self.model_class.from_pretrained(tmp_path) - # make sure that error message states what keys are missing - assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) + # make sure that error message states what keys are missing + assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) - new_model.to(torch_device) + new_model.to(torch_device) with torch.no_grad(): image = model(**self.get_dummy_inputs()) @@ -217,35 +317,27 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] - max_diff = (image - new_image).abs().max().item() - assert max_diff <= expected_max_diff, ( - f"Models give different forward passes. 
Max diff: {max_diff}, expected: {expected_max_diff}" - ) + assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") - def test_from_save_pretrained_dtype(self): + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) + def test_from_save_pretrained_dtype(self, tmp_path, dtype): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() - for dtype in [torch.float32, torch.float16, torch.bfloat16]: - if torch_device == "mps" and dtype == torch.bfloat16: - continue - with tempfile.TemporaryDirectory() as tmpdirname: - model.to(dtype) - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) - assert new_model.dtype == dtype - if ( - hasattr(self.model_class, "_keep_in_fp32_modules") - and self.model_class._keep_in_fp32_modules is None - ): - # When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None - new_model = self.model_class.from_pretrained( - tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype - ) - assert new_model.dtype == dtype - - def test_determinism(self, expected_max_diff=1e-5): + if torch_device == "mps" and dtype == torch.bfloat16: + pytest.skip(reason=f"{dtype} is not supported on {torch_device}") + + model.to(dtype) + model.save_pretrained(tmp_path) + new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype) + assert new_model.dtype == dtype + if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None: + # When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None + new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype) + assert new_model.dtype == dtype + + def test_determinism(self, atol=1e-5, rtol=0): model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -259,18 +351,15 @@ def test_determinism(self, expected_max_diff=1e-5): if isinstance(second, dict): second = second.to_tuple()[0] - # Remove NaN values and compute max difference + # Filter out NaN values before comparison first_flat = first.flatten() second_flat = second.flatten() - - # Filter out NaN values mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat)) first_filtered = first_flat[mask] second_filtered = second_flat[mask] - max_diff = torch.abs(first_filtered - second_filtered).max().item() - assert max_diff <= expected_max_diff, ( - f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}" + assert_tensors_close( + first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic" ) def test_output(self, expected_output_shape=None): @@ -310,13 +399,12 @@ def recursive_check(tuple_object, dict_object): elif tuple_object is None: return else: - assert torch.allclose( - set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 - ), ( - "Tuple and dict output are not equal. Difference:" - f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" - f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
+ assert_tensors_close( + set_nan_tensor_to_zero(tuple_object), + set_nan_tensor_to_zero(dict_object), + atol=1e-5, + rtol=0, + msg="Tuple and dict output are not equal", ) model = self.model_class(**self.get_init_dict()) @@ -329,7 +417,7 @@ def recursive_check(tuple_object, dict_object): recursive_check(outputs_tuple, outputs_dict) - def test_getattr_is_correct(self): + def test_getattr_is_correct(self, caplog): init_dict = self.get_init_dict() model = self.model_class(**init_dict) @@ -337,28 +425,26 @@ def test_getattr_is_correct(self): model.dummy_attribute = 5 model.register_to_config(test_attribute=5) - logger = logging.get_logger("diffusers.models.modeling_utils") - # 30 for warning - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: + logger_name = "diffusers.models.modeling_utils" + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() assert hasattr(model, "dummy_attribute") assert getattr(model, "dummy_attribute") == 5 assert model.dummy_attribute == 5 # no warning should be thrown - assert cap_logger.out == "" + assert caplog.text == "" - logger = logging.get_logger("diffusers.models.modeling_utils") - # 30 for warning - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() assert hasattr(model, "save_pretrained") fn = model.save_pretrained fn_1 = getattr(model, "save_pretrained") assert fn == fn_1 + # no warning should be thrown - assert cap_logger.out == "" + assert caplog.text == "" # warning should be thrown for config attributes accessed directly with pytest.warns(FutureWarning): @@ -399,32 +485,34 @@ def test_keep_in_fp32_modules(self): torch_device not in ["cuda", "xpu"], reason="float16 and bfloat16 can only be use for inference with an accelerator", ) - def test_from_save_pretrained_float16_bfloat16(self): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): model = self.model_class(**self.get_init_dict()) model.to(torch_device) fp32_modules = model._keep_in_fp32_modules - with tempfile.TemporaryDirectory() as tmp_dir: - for torch_dtype in [torch.bfloat16, torch.float16]: - model.to(torch_dtype).save_pretrained(tmp_dir) - model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device) + model.to(dtype).save_pretrained(tmp_path) + model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) + + for name, param in model_loaded.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + assert param.data.dtype == torch.float32 + else: + assert param.data.dtype == dtype - for name, param in model_loaded.named_parameters(): - if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): - assert param.data.dtype == torch.float32 - else: - assert param.data.dtype == torch_dtype + with torch.no_grad(): + output = model(**self.get_dummy_inputs()) + if isinstance(output, dict): + output = output.to_tuple()[0] - with torch.no_grad(): - output = model(**self.get_dummy_inputs()) - output_loaded = model_loaded(**self.get_dummy_inputs()) + output_loaded = model_loaded(**self.get_dummy_inputs()) + if isinstance(output_loaded, dict): + output_loaded = output_loaded.to_tuple()[0] - assert torch.allclose(output, output_loaded, atol=1e-4), ( - f"Loaded model output differs for {torch_dtype}" - ) + 
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}") @require_accelerator - def test_sharded_checkpoints(self): + def test_sharded_checkpoints(self, tmp_path): torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -435,30 +523,30 @@ def test_sharded_checkpoints(self): model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - # Check if the right number of shards exists - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) - actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards, ( - f"Expected {expected_num_shards} shards, got {actual_num_shards}" - ) + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - new_model = self.model_class.from_pretrained(tmp_dir).eval() - new_model = new_model.to(torch_device) + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) - torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_model = self.model_class.from_pretrained(tmp_path).eval() + new_model = new_model.to(torch_device) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match after sharded save/load" - ) + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new) + + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load" + ) @require_accelerator - def test_sharded_checkpoints_with_variant(self): + def test_sharded_checkpoints_with_variant(self, tmp_path): torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() @@ -470,35 +558,33 @@ def test_sharded_checkpoints_with_variant(self): model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small variant = "fp16" - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant) - index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - assert os.path.exists(os.path.join(tmp_dir, index_filename)), ( - f"Variant index file {index_filename} should exist" - ) + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant) - # Check if the right number of shards exists - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename)) - actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards, ( - f"Expected 
{expected_num_shards} shards, got {actual_num_shards}" - ) + index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + assert os.path.exists(os.path.join(tmp_path, index_filename)), ( + f"Variant index file {index_filename} should exist" + ) - new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval() - new_model = new_model.to(torch_device) + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) - torch.manual_seed(0) - inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval() + new_model = new_model.to(torch_device) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match after variant sharded save/load" - ) + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new) - def test_sharded_checkpoints_with_parallel_loading(self): - import time + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load" + ) + def test_sharded_checkpoints_with_parallel_loading(self, tmp_path): from diffusers.utils import constants torch.manual_seed(0) @@ -517,47 +603,37 @@ def test_sharded_checkpoints_with_parallel_loading(self): original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None) try: - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB") - assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - - # Check if the right number of shards exists - expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) - actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) - assert actual_num_shards == expected_num_shards, ( - f"Expected {expected_num_shards} shards, got {actual_num_shards}" - ) + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" - # Load without parallel loading - constants.HF_ENABLE_PARALLEL_LOADING = False - start_time = time.time() - model_sequential = self.model_class.from_pretrained(tmp_dir).eval() - sequential_load_time = time.time() - start_time - model_sequential = model_sequential.to(torch_device) + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) - torch.manual_seed(0) + # Load without parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = False + model_sequential = self.model_class.from_pretrained(tmp_path).eval() + model_sequential = model_sequential.to(torch_device) - # Load with parallel loading - constants.HF_ENABLE_PARALLEL_LOADING = True - constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 + # Load with parallel loading + 
constants.HF_ENABLE_PARALLEL_LOADING = True + constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 - start_time = time.time() - model_parallel = self.model_class.from_pretrained(tmp_dir).eval() - parallel_load_time = time.time() - start_time - model_parallel = model_parallel.to(torch_device) + torch.manual_seed(0) + model_parallel = self.model_class.from_pretrained(tmp_path).eval() + model_parallel = model_parallel.to(torch_device) - torch.manual_seed(0) - inputs_dict_parallel = self.get_dummy_inputs() - output_parallel = model_parallel(**inputs_dict_parallel) + torch.manual_seed(0) + inputs_dict_parallel = self.get_dummy_inputs() + output_parallel = model_parallel(**inputs_dict_parallel) - assert torch.allclose(base_output[0], output_parallel[0], atol=1e-5), ( - "Output should match with parallel loading" - ) + assert_tensors_close( + base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading" + ) - # Verify parallel loading is faster or at least not significantly slower - assert parallel_load_time < sequential_load_time, ( - f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s" - ) finally: # Restore original values constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading @@ -565,7 +641,7 @@ def test_sharded_checkpoints_with_parallel_loading(self): constants.HF_PARALLEL_WORKERS = original_parallel_workers @require_torch_multi_accelerator - def test_model_parallelism(self): + def test_model_parallelism(self, tmp_path): if self.model_class._no_split_modules is None: pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") @@ -581,20 +657,19 @@ def test_model_parallelism(self): model_size = compute_module_sizes(model)[""] max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir) + model.cpu().save_pretrained(tmp_path) - for max_size in max_gpu_sizes: - max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} - new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) - # Making sure part of the model will be on GPU 0 and GPU 1 - assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs" + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory) + # Making sure part of the model will be on GPU 0 and GPU 1 + assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs" - check_device_map_is_respected(new_model, new_model.hf_device_map) + check_device_map_is_respected(new_model, new_model.hf_device_map) - torch.manual_seed(0) - new_output = new_model(**inputs_dict) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match with model parallelism" - ) + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism" + ) diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index 6777c164f280..b790e3ea263b 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -13,17 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import json import os -import tempfile +import re import pytest import safetensors.torch import torch +import torch.nn as nn +from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import check_if_dicts_are_equal -from ...testing_utils import is_lora, require_peft_backend, torch_device +from ...testing_utils import ( + assert_tensors_close, + backend_empty_cache, + is_lora, + is_torch_compile, + require_peft_backend, + require_peft_version_greater, + require_torch_accelerator, + require_torch_version_greater, + torch_device, +) + + +if is_peft_available(): + from diffusers.loaders.peft import PeftAdapterMixin def check_if_lora_correctly_set(model) -> bool: @@ -67,7 +84,7 @@ def setup_method(self): if not issubclass(self.model_class, PeftAdapterMixin): pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") - def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): + def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False): from peft import LoraConfig from peft.utils import get_peft_model_state_dict @@ -95,26 +112,25 @@ def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): "Output should differ with LoRA enabled" ) - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), ( - "LoRA weights file not created" - ) + model.save_lora_adapter(tmp_path) + assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), ( + "LoRA weights file not created" + ) - state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")) - model.unload_lora() - assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") - for k in state_dict_loaded: - loaded_v = state_dict_loaded[k] - retrieved_v = state_dict_retrieved[k].to(loaded_v.device) - assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}" + for k in state_dict_loaded: + loaded_v = state_dict_loaded[k] + retrieved_v = state_dict_retrieved[k].to(loaded_v.device) + assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}") - assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload" + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload" torch.manual_seed(0) outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] @@ -122,11 +138,15 @@ def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False): assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( "Output should differ with LoRA enabled" ) - assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), ( - "Outputs should match before and after save/load" + assert_tensors_close( + outputs_with_lora, + outputs_with_lora_2, + atol=1e-4, + rtol=1e-4, + msg="Outputs should match 
before and after save/load", ) - def test_lora_wrong_adapter_name_raises_error(self): + def test_lora_wrong_adapter_name_raises_error(self, tmp_path): from peft import LoraConfig init_dict = self.get_init_dict() @@ -142,14 +162,13 @@ def test_lora_wrong_adapter_name_raises_error(self): model.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with tempfile.TemporaryDirectory() as tmpdir: - wrong_name = "foo" - with pytest.raises(ValueError) as exc_info: - model.save_lora_adapter(tmpdir, adapter_name=wrong_name) + wrong_name = "foo" + with pytest.raises(ValueError) as exc_info: + model.save_lora_adapter(tmp_path, adapter_name=wrong_name) - assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value) + assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value) - def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False): + def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False): from peft import LoraConfig init_dict = self.get_init_dict() @@ -166,19 +185,18 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, u metadata = model.peft_config["default"].to_dict() assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - assert os.path.isfile(model_file), "LoRA weights file not created" + model.save_lora_adapter(tmp_path) + model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" - model.unload_lora() - assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - parsed_metadata = model.peft_config["default_0"].to_dict() - check_if_dicts_are_equal(metadata, parsed_metadata) + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + parsed_metadata = model.peft_config["default_0"].to_dict() + check_if_dicts_are_equal(metadata, parsed_metadata) - def test_lora_adapter_wrong_metadata_raises_error(self): + def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path): from peft import LoraConfig from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY @@ -196,25 +214,337 @@ def test_lora_adapter_wrong_metadata_raises_error(self): model.add_adapter(denoiser_lora_config) assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - assert os.path.isfile(model_file), "LoRA weights file not created" - - # Perturb the metadata in the state dict - loaded_state_dict = safetensors.torch.load_file(model_file) - metadata = {"format": "pt"} - lora_adapter_metadata = denoiser_lora_config.to_dict() - lora_adapter_metadata.update({"foo": 1, "bar": 2}) - for key, value in lora_adapter_metadata.items(): - if isinstance(value, set): - lora_adapter_metadata[key] = list(value) - metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) - safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) - - model.unload_lora() - 
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" - - with pytest.raises(TypeError) as exc_info: - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - assert "`LoraConfig` class could not be instantiated" in str(exc_info.value) + model.save_lora_adapter(tmp_path) + model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" + + # Perturb the metadata in the state dict + loaded_state_dict = safetensors.torch.load_file(model_file) + metadata = {"format": "pt"} + lora_adapter_metadata = denoiser_lora_config.to_dict() + lora_adapter_metadata.update({"foo": 1, "bar": 2}) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + with pytest.raises(TypeError) as exc_info: + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + assert "`LoraConfig` class could not be instantiated" in str(exc_info.value) + + +@is_lora +@is_torch_compile +@require_peft_backend +@require_peft_version_greater("0.14.0") +@require_torch_version_greater("2.7.1") +@require_torch_accelerator +class LoraHotSwappingForModelTesterMixin: + """ + Mixin class for testing LoRA hot swapping functionality on models. + + Test that hotswapping does not result in recompilation on the model directly. + We're not extensively testing the hotswapping functionality since it is implemented in PEFT + and is extensively tested there. The goal of this test is specifically to ensure that + hotswapping with diffusers does not require recompilation. + + See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252 + for the analogous PEFT test. + + Expected class attributes to be set by subclasses: + - model_class: The model class to test + - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest marks: lora, torch_compile + Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests + """ + + different_shapes_for_compilation = None + + def setup_method(self): + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") + + def teardown_method(self): + # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, + # there will be recompilation errors, as torch caches the model when run in the same process. 
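+        # `torch.compiler.reset()` clears the dynamo/inductor caches, while `gc.collect()` and
+        # `backend_empty_cache()` release any model weights still held on the accelerator.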
+ torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def _get_lora_config(self, lora_rank, lora_alpha, target_modules): + from peft import LoraConfig + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules, + init_lora_weights=False, + use_dora=False, + ) + return lora_config + + def _get_linear_module_name_other_than_attn(self, model): + linear_names = [ + name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name + ] + return linear_names[0] + + def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None): + """ + Check that hotswapping works on a model. + + Steps: + - create 2 LoRA adapters and save them + - load the first adapter + - hotswap the second adapter + - check that the outputs are correct + - optionally compile the model + - optionally check if recompilations happen on different shapes + + Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would + fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is + fine. + """ + different_shapes = self.different_shapes_for_compilation + # create 2 adapters with different ranks and alphas + torch.manual_seed(0) + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + alpha0, alpha1 = rank0, rank1 + max_rank = max([rank0, rank1]) + if target_modules1 is None: + target_modules1 = target_modules0[:] + lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1) + + model.add_adapter(lora_config0, adapter_name="adapter0") + with torch.inference_mode(): + torch.manual_seed(0) + output0_before = model(**inputs_dict)["sample"] + + model.add_adapter(lora_config1, adapter_name="adapter1") + model.set_adapter("adapter1") + with torch.inference_mode(): + torch.manual_seed(0) + output1_before = model(**inputs_dict)["sample"] + + # sanity checks: + tol = 5e-3 + assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol) + assert not (output0_before == 0).all() + assert not (output1_before == 0).all() + + # save the adapter checkpoints + model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0") + model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1") + del model + + # load the first adapter + torch.manual_seed(0) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + model.enable_lora_hotswap(target_rank=max_rank) + + file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors") + model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) + + if do_compile: + model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None) + + with torch.inference_mode(): + # additionally check if dynamic compilation works. 
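+            # When `different_shapes_for_compilation` is set, run one forward pass per shape to exercise the
+            # dynamic-shape path; otherwise compare against the output captured before compilation.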
+ if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output0_after = model(**inputs_dict)["sample"] + assert_tensors_close( + output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0" + ) + + # hotswap the 2nd adapter + model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) + + # we need to call forward to potentially trigger recompilation + with torch.inference_mode(): + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output1_after = model(**inputs_dict)["sample"] + assert_tensors_close( + output1_before, + output1_after, + atol=tol, + rtol=tol, + msg="Output mismatch after hotswapping to adapter1", + ) + + # check error when not passing valid adapter name + name = "does-not-exist" + msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" + with pytest.raises(ValueError, match=re.escape(msg)): + model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_model(self, tmp_path, rank0, rank1): + self._check_model_hotswap( + tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"] + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1): + if "unet" not in self.model_class.__name__.lower(): + pytest.skip("Test only applies to UNet.") + + # It's important to add this context to raise an error on recompilation + target_modules = ["conv", "conv1", "conv2"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1): + if "unet" not in self.model_class.__name__.lower(): + pytest.skip("Test only applies to UNet.") + + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "conv"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1): + # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping + # with `torch.compile()` for 
models that have both linear and conv layers. In this test, we check + # if we can target a linear layer from the transformer blocks and another linear layer from non-attention + # block. + target_modules = ["to_q"] + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + target_modules.append(self._get_linear_module_name_other_than_attn(model)) + del model + + # It's important to add this context to raise an error on recompilation + with torch._dynamo.config.patch(error_on_recompile=True): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + def test_enable_lora_hotswap_called_after_adapter_added_raises(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + + msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") + with pytest.raises(RuntimeError, match=msg): + model.enable_lora_hotswap(target_rank=32) + + def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): + # ensure that enable_lora_hotswap is called before loading the first adapter + import logging + + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = ( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="warn") + assert any(msg in record.message for record in caplog.records) + + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog): + # check possibility to ignore the error/warning + import logging + + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") + assert len(caplog.records) == 0 + + def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): + # check that wrong argument value raises an error + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") + with pytest.raises(ValueError, match=msg): + model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") + + def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog): + # check the error and log + import logging + + # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers + target_modules0 = ["to_q"] + target_modules1 = ["to_q", "to_k"] + with pytest.raises(RuntimeError): # peft raises RuntimeError + with caplog.at_level(logging.ERROR): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=8, + rank1=8, + target_modules0=target_modules0, + target_modules1=target_modules1, + ) + assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records) + + 
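For orientation, a model test module is expected to consume this mixin roughly as sketched below, following the pattern of the Flux transformer tests later in this patch. `FluxTransformerTesterConfig` is the config class defined in that test module; the `different_shapes_for_compilation` values and the `prepare_dummy_input` expectation are illustrative assumptions rather than part of the diff.

from ..testing_utils import LoraHotSwappingForModelTesterMixin


class TestFluxTransformerLoraHotSwapping(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    # The config base class supplies `model_class`, `get_init_dict()` and `get_dummy_inputs()`;
    # the mixin contributes the hotswap tests. Setting this optional attribute opts in to
    # `test_hotswapping_compile_on_different_shapes`, which also expects a
    # `prepare_dummy_input(height=..., width=...)` helper on the class.
    different_shapes_for_compilation = [(4, 4), (8, 8)]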
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + @require_torch_version_greater("2.7.1") + def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1): + different_shapes_for_compilation = self.different_shapes_for_compilation + if different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic + # variable to represent input sizes that are the same. For more details, + # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + torch.fx.experimental._config.use_duck_shape = False + + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=rank0, + rank1=rank1, + target_modules0=target_modules, + ) diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 6cdc72b004c9..ebd76656f023 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -27,6 +27,7 @@ from diffusers.utils.torch_utils import get_torch_cuda_device_capability from ...testing_utils import ( + assert_tensors_close, backend_empty_cache, backend_max_memory_allocated, backend_reset_peak_memory_stats, @@ -122,8 +123,8 @@ def test_cpu_offload(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match with CPU offloading" + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading" ) @require_offload_support @@ -156,7 +157,9 @@ def test_disk_offload_without_safetensors(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading" + assert_tensors_close( + base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading" + ) @require_offload_support def test_disk_offload_with_safetensors(self): @@ -183,8 +186,12 @@ def test_disk_offload_with_safetensors(self): torch.manual_seed(0) new_output = new_model(**inputs_dict) - assert torch.allclose(base_output[0], new_output[0], atol=1e-5), ( - "Output should match with disk offloading (safetensors)" + assert_tensors_close( + base_output[0], + new_output[0], + atol=1e-5, + rtol=0, + msg="Output should match with disk offloading (safetensors)", ) @@ -247,17 +254,33 @@ def run_forward(model): ) output_with_group_offloading4 = run_forward(model) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), ( - "Output should match with block-level offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading1, + atol=1e-5, + rtol=0, + msg="Output should match with block-level offloading", ) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), ( - "Output should match with non-blocking block-level offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading2, + atol=1e-5, + rtol=0, + msg="Output should match with non-blocking block-level offloading", ) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5), ( - "Output should 
match with leaf-level offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading3, + atol=1e-5, + rtol=0, + msg="Output should match with leaf-level offloading", ) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), ( - "Output should match with leaf-level offloading with stream" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading4, + atol=1e-5, + rtol=0, + msg="Output should match with leaf-level offloading with stream", ) @require_group_offload_support @@ -345,8 +368,12 @@ def _run_forward(model, inputs_dict): raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) - assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), ( - "Output should match with disk-based group offloading" + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading, + atol=atol, + rtol=0, + msg="Output should match with disk-based group offloading", ) diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index 67d770849f00..992e6dd8d982 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -22,6 +22,7 @@ from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from ...testing_utils import ( + assert_tensors_close, backend_empty_cache, is_single_file, nightly, @@ -146,9 +147,8 @@ def test_single_file_model_parameters(self): f"pretrained {param.shape} vs single file {param_single_file.shape}" ) - assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), ( - f"Parameter values differ for {key}: " - f"max difference {torch.max(torch.abs(param - param_single_file)).item()}" + assert_tensors_close( + param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}" ) def test_single_file_loading_local_files_only(self): diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 30193088316c..e0b38eda7f43 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -23,13 +23,14 @@ from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin from ..testing_utils import ( AttentionTesterMixin, + BaseModelTesterConfig, BitsAndBytesTesterMixin, ContextParallelTesterMixin, GGUFTesterMixin, IPAdapterTesterMixin, + LoraHotSwappingForModelTesterMixin, LoraTesterMixin, MemoryTesterMixin, ModelOptTesterMixin, @@ -94,10 +95,26 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]: return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict} -class FluxTransformerTesterConfig: - model_class = FluxTransformer2DModel - pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" - pretrained_model_kwargs = {"subfolder": "transformer"} +class FluxTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return FluxTransformer2DModel + + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-flux-pipe" + + @property + def pretrained_model_kwargs(self): + return {"subfolder": "transformer"} + + 
@property
+    def output_shape(self) -> tuple[int, int]:
+        return (16, 4)
+
+    @property
+    def input_shape(self) -> tuple[int, int]:
+        return (16, 4)
 
     @property
     def generator(self):
@@ -136,14 +153,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
             "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
         }
 
-    @property
-    def input_shape(self) -> tuple[int, int]:
-        return (16, 4)
-
-    @property
-    def output_shape(self) -> tuple[int, int]:
-        return (16, 4)
-
 
 class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
     def test_deprecated_inputs_img_txt_ids_3d(self):
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 9860d64dc119..4c97bbc14c6b 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -131,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
     return True
 
 
+def assert_tensors_close(
+    actual: "torch.Tensor",
+    expected: "torch.Tensor",
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+    msg: str = "",
+) -> None:
+    """
+    Assert that two tensors are close within tolerance.
+
+    Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
+    Provides concise, actionable error messages without dumping full tensors.
+
+    Args:
+        actual: The actual tensor from the computation.
+        expected: The expected tensor to compare against.
+        atol: Absolute tolerance.
+        rtol: Relative tolerance.
+        msg: Optional message prefix for the assertion error.
+
+    Raises:
+        AssertionError: If tensors have different shapes or values exceed tolerance.
+
+    Example:
+        >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
+    """
+    if not is_torch_available():
+        raise ValueError("PyTorch needs to be installed to use this function.")
+
+    if actual.shape != expected.shape:
+        raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
+
+    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
+        abs_diff = (actual - expected).abs()
+        max_diff = abs_diff.max().item()
+
+        flat_idx = abs_diff.argmax().item()
+        max_idx = tuple(i.item() for i in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
+
+        threshold = atol + rtol * expected.abs()
+        mismatched = (abs_diff > threshold).sum().item()
+        total = actual.numel()
+
+        raise AssertionError(
+            f"{msg}\n"
+            f"Tensors not close! 
Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n" + f" Max diff: {max_diff:.6e} at index {max_idx}\n" + f" Actual: {actual.flatten()[flat_idx].item():.6e}\n" + f" Expected: {expected.flatten()[flat_idx].item():.6e}\n" + f" atol: {atol:.6e}, rtol: {rtol:.6e}" + ) + + def numpy_cosine_similarity_distance(a, b): similarity = np.dot(a, b) / (norm(a) * norm(b)) distance = 1.0 - similarity.mean() From eae75437128e407cf1593223256f3f999553bedb Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 15 Dec 2025 16:02:38 +0530 Subject: [PATCH 10/12] update --- tests/models/testing_utils/__init__.py | 3 +- tests/models/testing_utils/attention.py | 95 +-------------- tests/models/testing_utils/common.py | 70 ++++------- tests/models/testing_utils/hub.py | 109 ----------------- tests/models/testing_utils/ip_adapter.py | 40 ------ tests/models/testing_utils/lora.py | 2 - tests/models/testing_utils/memory.py | 5 +- tests/models/testing_utils/quantization.py | 135 ++++++++++++--------- tests/models/testing_utils/training.py | 28 +---- utils/generate_model_tests.py | 36 ++++-- 10 files changed, 140 insertions(+), 383 deletions(-) delete mode 100644 tests/models/testing_utils/hub.py diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index 229179737a9a..6dfb77c71378 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,9 +1,10 @@ -from .attention import AttentionTesterMixin, ContextParallelTesterMixin +from .attention import AttentionTesterMixin from .common import BaseModelTesterConfig, ModelTesterMixin from .compile import TorchCompileTesterMixin from .ip_adapter import IPAdapterTesterMixin from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin +from .parallelism import ContextParallelTesterMixin from .quantization import ( BitsAndBytesTesterMixin, GGUFTesterMixin, diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index 45443046fb50..d732195c7e2f 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -13,13 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest import torch -import torch.multiprocessing as mp -from diffusers.models._modeling_parallel import ContextParallelConfig from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import ( AttnProcessor, @@ -28,8 +24,6 @@ from ...testing_utils import ( assert_tensors_close, is_attention, - is_context_parallel, - require_torch_multi_accelerator, torch_device, ) @@ -71,9 +65,7 @@ def test_fuse_unfuse_qkv_projections(self): # Get output before fusion with torch.no_grad(): - output_before_fusion = model(**inputs_dict) - if isinstance(output_before_fusion, dict): - output_before_fusion = output_before_fusion.to_tuple()[0] + output_before_fusion = model(**inputs_dict, return_dict=False)[0] # Fuse projections model.fuse_qkv_projections() @@ -90,9 +82,7 @@ def test_fuse_unfuse_qkv_projections(self): if has_fused_projections: # Get output after fusion with torch.no_grad(): - output_after_fusion = model(**inputs_dict) - if isinstance(output_after_fusion, dict): - output_after_fusion = output_after_fusion.to_tuple()[0] + output_after_fusion = model(**inputs_dict, return_dict=False)[0] # Verify outputs match assert_tensors_close( @@ -115,9 +105,7 @@ def test_fuse_unfuse_qkv_projections(self): # Get output after unfusion with torch.no_grad(): - output_after_unfusion = model(**inputs_dict) - if isinstance(output_after_unfusion, dict): - output_after_unfusion = output_after_unfusion.to_tuple()[0] + output_after_unfusion = model(**inputs_dict, return_dict=False)[0] # Verify outputs still match assert_tensors_close( @@ -195,80 +183,3 @@ def test_attention_processor_count_mismatch_raises_error(self): model.set_attn_processor(wrong_processors) assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" - - -def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue): - try: - # Setup distributed environment - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - torch.distributed.init_process_group( - backend="nccl", - init_method="env://", - world_size=world_size, - rank=rank, - ) - torch.cuda.set_device(rank) - device = torch.device(f"cuda:{rank}") - - model = model_class(**init_dict) - model.to(device) - model.eval() - - inputs_on_device = {} - for key, value in inputs_dict.items(): - if isinstance(value, torch.Tensor): - inputs_on_device[key] = value.to(device) - else: - inputs_on_device[key] = value - - cp_config = ContextParallelConfig(**cp_dict) - model.enable_parallelism(config=cp_config) - - with torch.no_grad(): - output = model(**inputs_on_device) - if isinstance(output, dict): - output = output.to_tuple()[0] - - if rank == 0: - result_queue.put(("success", output.shape)) - - except Exception as e: - if rank == 0: - result_queue.put(("error", str(e))) - finally: - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -@is_context_parallel -@require_torch_multi_accelerator -class ContextParallelTesterMixin: - base_precision = 1e-3 - - @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) - def test_context_parallel_inference(self, cp_type): - if not torch.distributed.is_available(): - pytest.skip("torch.distributed is not available.") - - if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: - pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") - - world_size = 2 - 
init_dict = self.get_init_dict() - inputs_dict = self.get_dummy_inputs() - cp_dict = {cp_type: world_size} - - ctx = mp.get_context("spawn") - result_queue = ctx.Queue() - - mp.spawn( - _context_parallel_worker, - args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue), - nprocs=world_size, - join=True, - ) - - status, result = result_queue.get(timeout=60) - assert status == "success", f"Context parallel inference failed: {result}" diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 11c10c4557af..3611eff7ef81 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -259,7 +259,7 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin): pass """ - def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0): + def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) @@ -278,15 +278,8 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0): ) with torch.no_grad(): - image = model(**self.get_dummy_inputs()) - - if isinstance(image, dict): - image = image.to_tuple()[0] - - new_image = new_model(**self.get_dummy_inputs()) - - if isinstance(new_image, dict): - new_image = new_image.to_tuple()[0] + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -308,14 +301,8 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): new_model.to(torch_device) with torch.no_grad(): - image = model(**self.get_dummy_inputs()) - if isinstance(image, dict): - image = image.to_tuple()[0] - - new_image = new_model(**self.get_dummy_inputs()) - - if isinstance(new_image, dict): - new_image = new_image.to_tuple()[0] + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") @@ -343,13 +330,8 @@ def test_determinism(self, atol=1e-5, rtol=0): model.eval() with torch.no_grad(): - first = model(**self.get_dummy_inputs()) - if isinstance(first, dict): - first = first.to_tuple()[0] - - second = model(**self.get_dummy_inputs()) - if isinstance(second, dict): - second = second.to_tuple()[0] + first = model(**self.get_dummy_inputs(), return_dict=False)[0] + second = model(**self.get_dummy_inputs(), return_dict=False)[0] # Filter out NaN values before comparison first_flat = first.flatten() @@ -369,10 +351,7 @@ def test_output(self, expected_output_shape=None): inputs_dict = self.get_dummy_inputs() with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] assert output is not None, "Model output is None" assert output[0].shape == expected_output_shape or self.output_shape, ( @@ -501,13 +480,8 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): assert param.data.dtype == dtype with torch.no_grad(): - output = model(**self.get_dummy_inputs()) - if isinstance(output, dict): - output = output.to_tuple()[0] - - output_loaded = model_loaded(**self.get_dummy_inputs()) - if isinstance(output_loaded, dict): - output_loaded = output_loaded.to_tuple()[0] + output = 
model(**self.get_dummy_inputs(), return_dict=False)[0] + output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}") @@ -519,7 +493,7 @@ def test_sharded_checkpoints(self, tmp_path): model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -539,10 +513,10 @@ def test_sharded_checkpoints(self, tmp_path): torch.manual_seed(0) inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_output = new_model(**inputs_dict_new, return_dict=False)[0] assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load" + base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load" ) @require_accelerator @@ -553,7 +527,7 @@ def test_sharded_checkpoints_with_variant(self, tmp_path): model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -578,10 +552,10 @@ def test_sharded_checkpoints_with_variant(self, tmp_path): torch.manual_seed(0) inputs_dict_new = self.get_dummy_inputs() - new_output = new_model(**inputs_dict_new) + new_output = new_model(**inputs_dict_new, return_dict=False)[0] assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load" + base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load" ) def test_sharded_checkpoints_with_parallel_loading(self, tmp_path): @@ -593,7 +567,7 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path): model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -628,10 +602,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path): torch.manual_seed(0) inputs_dict_parallel = self.get_dummy_inputs() - output_parallel = model_parallel(**inputs_dict_parallel) + output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0] assert_tensors_close( - base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading" + base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading" ) finally: @@ -652,7 +626,7 @@ def test_model_parallelism(self, tmp_path): model = model.to(torch_device) torch.manual_seed(0) - base_output = model(**inputs_dict) + base_output = model(**inputs_dict, return_dict=False)[0] model_size = compute_module_sizes(model)[""] max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] @@ -668,8 +642,8 @@ def test_model_parallelism(self, tmp_path): check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) - new_output = new_model(**inputs_dict) + 
new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( - base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism" + base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism" ) diff --git a/tests/models/testing_utils/hub.py b/tests/models/testing_utils/hub.py deleted file mode 100644 index 40d8777c33b1..000000000000 --- a/tests/models/testing_utils/hub.py +++ /dev/null @@ -1,109 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tempfile -import uuid - -import pytest -import torch -from huggingface_hub.utils import ModelCard, delete_repo, is_jinja_available - -from ...others.test_utils import TOKEN, USER, is_staging_test - - -@is_staging_test -class ModelPushToHubTesterMixin: - """ - Mixin class for testing push_to_hub functionality on models. - - Expected class attributes to be set by subclasses: - - model_class: The model class to test - - Expected methods to be implemented by subclasses: - - get_init_dict(): Returns dict of arguments to initialize the model - """ - - identifier = uuid.uuid4() - repo_id = f"test-model-{identifier}" - org_repo_id = f"valid_org/{repo_id}-org" - - def test_push_to_hub(self): - """Test pushing model to hub and loading it back.""" - init_dict = self.get_init_dict() - model = self.model_class(**init_dict) - model.push_to_hub(self.repo_id, token=TOKEN) - - new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained" - - # Reset repo - delete_repo(token=TOKEN, repo_id=self.repo_id) - - # Push to hub via save_pretrained - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN) - - new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}") - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), ( - "Parameters don't match after save_pretrained with push_to_hub and from_pretrained" - ) - - # Reset repo - delete_repo(self.repo_id, token=TOKEN) - - def test_push_to_hub_in_organization(self): - """Test pushing model to hub in organization namespace.""" - init_dict = self.get_init_dict() - model = self.model_class(**init_dict) - model.push_to_hub(self.org_repo_id, token=TOKEN) - - new_model = self.model_class.from_pretrained(self.org_repo_id) - for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained" - - # Reset repo - delete_repo(token=TOKEN, repo_id=self.org_repo_id) - - # Push to hub via save_pretrained - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id) - - new_model = self.model_class.from_pretrained(self.org_repo_id) - 
for p1, p2 in zip(model.parameters(), new_model.parameters()): - assert torch.equal(p1, p2), ( - "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained" - ) - - # Reset repo - delete_repo(self.org_repo_id, token=TOKEN) - - def test_push_to_hub_library_name(self): - """Test that library_name in model card is set to 'diffusers'.""" - if not is_jinja_available(): - pytest.skip("Model card tests cannot be performed without Jinja installed.") - - init_dict = self.get_init_dict() - model = self.model_class(**init_dict) - model.push_to_hub(self.repo_id, token=TOKEN) - - model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data - assert model_card.library_name == "diffusers", ( - f"Expected library_name 'diffusers', got {model_card.library_name}" - ) - - # Reset repo - delete_repo(self.repo_id, token=TOKEN) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py index 13e141869c3a..891a23d330cc 100644 --- a/tests/models/testing_utils/ip_adapter.py +++ b/tests/models/testing_utils/ip_adapter.py @@ -17,49 +17,9 @@ import pytest import torch -from diffusers.models.attention_processor import IPAdapterAttnProcessor - from ...testing_utils import is_ip_adapter, torch_device -def create_ip_adapter_state_dict(model): - """ - Create a dummy IP Adapter state dict for testing. - - Args: - model: The model to create IP adapter weights for - - Returns: - dict: IP adapter state dict with to_k_ip and to_v_ip weights - """ - ip_state_dict = {} - key_id = 1 - - for name in model.attn_processors.keys(): - # Skip self-attention processors - cross_attention_dim = getattr(model.config, "cross_attention_dim", None) - if cross_attention_dim is None: - continue - - # Get hidden size based on model architecture - hidden_size = getattr(model.config, "hidden_size", cross_attention_dim) - - # Create IP adapter processor to get state dict structure - sd = IPAdapterAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 - ).state_dict() - - ip_state_dict.update( - { - f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"], - f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"], - } - ) - key_id += 2 - - return {"ip_adapter": ip_state_dict} - - def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool: """ Check if IP Adapter processors are correctly set in the model. 
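The hunk's trailing context cuts off before the body of `check_if_ip_adapter_correctly_set`. For orientation only, a check of this shape typically walks the model's attention processors along the lines of the assumed sketch below; it is not taken from the patch.

def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
    # Treat the adapter as correctly set if at least one attention processor
    # is an instance of the expected IP-Adapter processor class.
    return any(isinstance(proc, processor_cls) for proc in model.attn_processors.values())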
diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index b790e3ea263b..5777f992789a 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -79,8 +79,6 @@ class LoraTesterMixin: """ def setup_method(self): - from diffusers.loaders.peft import PeftAdapterMixin - if not issubclass(self.model_class, PeftAdapterMixin): pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index ebd76656f023..5486bbb0cd23 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -455,10 +455,7 @@ def test_fn(storage_dtype, compute_dtype): inputs_dict = self.get_inputs_dict() inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) with torch.amp.autocast(device_type=torch.device(torch_device).type): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] input_tensor = inputs_dict[self.main_input_name] noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 140f6db6d2ff..b7f960a1351c 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -128,9 +128,9 @@ def _test_quantization_num_parameters(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) num_params_quantized = model_quantized.num_parameters() - assert num_params == num_params_quantized, ( - f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" - ) + assert ( + num_params == num_params_quantized + ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): model = self._load_unquantized_model() @@ -140,19 +140,17 @@ def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_red mem_quantized = model_quantized.get_memory_footprint() ratio = mem / mem_quantized - assert ratio >= expected_memory_reduction, ( - f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}" - ) + assert ( + ratio >= expected_memory_reduction + ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model_quantized(**inputs) + output = model_quantized(**inputs, return_dict=False)[0] - if isinstance(output, tuple): - output = output[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" @@ -197,10 +195,8 @@ def _test_quantization_lora_inference(self, config_kwargs): with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model(**inputs) + output = model(**inputs, return_dict=False)[0] - if isinstance(output, tuple): - output = output[0] assert output is not None, "Model output is None with LoRA" assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" @@ -214,9 +210,7 @@ def _test_quantization_serialization(self, config_kwargs): with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model_loaded(**inputs) - if isinstance(output, tuple): - output = output[0] + output = model_loaded(**inputs, return_dict=False)[0] assert not torch.isnan(output).any(), "Loaded model output contains NaN" def _test_quantized_layers(self, config_kwargs): @@ -243,12 +237,12 @@ def _test_quantized_layers(self, config_kwargs): self._verify_if_layer_quantized(name, module, config_kwargs) num_quantized_layers += 1 - assert num_quantized_layers > 0, ( - f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" - ) - assert num_quantized_layers == expected_quantized_layers, ( - f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" - ) + assert ( + num_quantized_layers > 0 + ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + assert ( + num_quantized_layers == expected_quantized_layers + ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): """ @@ -272,9 +266,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no if any(excluded in name for excluded in modules_to_not_convert): found_excluded = True # This module should NOT be quantized - assert not self._is_module_quantized(module), ( - f"Module {name} should not be quantized but was found to be quantized" - ) + assert not self._is_module_quantized( + module + ), f"Module {name} should not be quantized but was found to be quantized" assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" @@ -296,9 +290,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no mem_with_exclusion = model_with_exclusion.get_memory_footprint() mem_fully_quantized = model_fully_quantized.get_memory_footprint() - assert mem_with_exclusion > mem_fully_quantized, ( - f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" - ) + assert ( + mem_with_exclusion > mem_fully_quantized + ), f"Model with exclusions should be larger. 
With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" def _test_quantization_device_map(self, config_kwargs): """ @@ -316,12 +310,38 @@ def _test_quantization_device_map(self, config_kwargs): # Verify inference works with torch.no_grad(): inputs = self.get_dummy_inputs() - output = model(**inputs) - if isinstance(output, tuple): - output = output[0] + output = model(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" assert not torch.isnan(output).any(), "Model output contains NaN" + def _test_dequantize(self, config_kwargs): + """ + Test that dequantize() converts quantized model back to standard linear layers. + + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + + # Verify model has dequantize method + if not hasattr(model, "dequantize"): + pytest.skip("Model does not have dequantize method") + + # Dequantize the model + model.dequantize() + + # Verify no modules are quantized after dequantization + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()" + + # Verify inference still works after dequantization + with torch.no_grad(): + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None after dequantization" + assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" + @is_bitsandbytes @require_accelerator @@ -379,9 +399,9 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs): def _verify_if_layer_quantized(self, name, module, config_kwargs): expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params - assert module.weight.__class__ == expected_weight_class, ( - f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" - ) + assert ( + module.weight.__class__ == expected_weight_class + ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) def test_bnb_quantization_num_parameters(self, config_name): @@ -449,13 +469,13 @@ def test_bnb_keep_modules_in_fp32(self): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert module.weight.dtype == torch.float32, ( - f"Module {name} should be FP32 but is {module.weight.dtype}" - ) + assert ( + module.weight.dtype == torch.float32 + ), f"Module {name} should be FP32 but is {module.weight.dtype}" else: - assert module.weight.dtype == torch.uint8, ( - f"Module {name} should be uint8 but is {module.weight.dtype}" - ) + assert ( + module.weight.dtype == torch.uint8 + ), f"Module {name} should be uint8 but is {module.weight.dtype}" with torch.no_grad(): inputs = self.get_dummy_inputs() @@ -476,6 +496,10 @@ def test_bnb_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"]) + def test_bnb_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.BNB_CONFIGS["4bit_nf4"]) + @is_quanto @require_quanto @@ -563,6 +587,10 @@ def test_quanto_device_map(self): """Test that device_map='auto' works correctly with quantization.""" 
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"]) + def test_quanto_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.QUANTO_WEIGHT_TYPES["int8"]) + @is_torchao @require_accelerator @@ -649,6 +677,10 @@ def test_torchao_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"]) + def test_torchao_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.TORCHAO_QUANT_TYPES["int8wo"]) + @is_gguf @require_accelerate @@ -716,24 +748,9 @@ def test_gguf_quantization_dtype_assignment(self): def test_gguf_quantization_lora_inference(self): self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16}) - def test_gguf_dequantize_model(self): - from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter - - model = self._create_quantized_model() - model.dequantize() - - def _check_for_gguf_linear(model): - has_children = list(model.children()) - if not has_children: - return - - for name, module in model.named_children(): - if isinstance(module, torch.nn.Linear): - assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear" - assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter" - - for name, module in model.named_children(): - _check_for_gguf_linear(module) + def test_gguf_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize({"compute_dtype": torch.bfloat16}) def test_gguf_quantized_layers(self): self._test_quantized_layers({"compute_dtype": torch.bfloat16}) @@ -826,3 +843,7 @@ def test_modelopt_modules_to_not_convert(self): def test_modelopt_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"]) + + def test_modelopt_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(self.MODELOPT_CONFIGS["fp8"]) diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index 7e4193d59e84..f6612dd3be0f 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -50,10 +50,7 @@ def test_training(self): model = self.model_class(**init_dict) model.to(torch_device) model.train() - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -68,10 +65,7 @@ def test_training_with_ema(self): model.train() ema_model = EMAModel(model.parameters()) - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -137,9 +131,7 @@ def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_gra assert not model.is_gradient_checkpointing and model.training - out = model(**inputs_dict) - if isinstance(out, dict): - out = out.sample if hasattr(out, "sample") else out.to_tuple()[0] + out = model(**inputs_dict, return_dict=False)[0] # run the backwards pass on the model model.zero_grad() @@ -158,9 +150,7 @@ def test_gradient_checkpointing_equivalence(self, 
loss_tolerance=1e-5, param_gra assert model_2.is_gradient_checkpointing and model_2.training - out_2 = model_2(**inputs_dict_copy) - if isinstance(out_2, dict): - out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0] + out_2 = model_2(**inputs_dict_copy, return_dict=False)[0] # run the backwards pass on the model model_2.zero_grad() @@ -198,10 +188,7 @@ def test_mixed_precision_training(self): # Test with float16 if torch.device(torch_device).type != "cpu": with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -212,10 +199,7 @@ def test_mixed_precision_training(self): if torch.device(torch_device).type != "cpu": model.zero_grad() with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.to_tuple()[0] + output = model(**inputs_dict, return_dict=False)[0] noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py index ffd600dfdf29..f3860f4b9a90 100644 --- a/utils/generate_model_tests.py +++ b/utils/generate_model_tests.py @@ -43,10 +43,15 @@ ALWAYS_INCLUDE_TESTERS = [ "ModelTesterMixin", "MemoryTesterMixin", - "AttentionTesterMixin", "TorchCompileTesterMixin", ] +# Attention-related class names that indicate the model uses attention +ATTENTION_INDICATORS = { + "AttentionMixin", + "AttentionModuleMixin", +} + OPTIONAL_TESTERS = [ ("BitsAndBytesTesterMixin", "bnb"), ("QuantoTesterMixin", "quanto"), @@ -62,6 +67,17 @@ class ModelAnalyzer(ast.NodeVisitor): def __init__(self): self.model_classes = [] self.current_class = None + self.imports = set() + + def visit_Import(self, node: ast.Import): + for alias in node.names: + self.imports.add(alias.name.split(".")[-1]) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom): + for alias in node.names: + self.imports.add(alias.name) + self.generic_visit(node) def visit_ClassDef(self, node: ast.ClassDef): base_names = [] @@ -164,7 +180,7 @@ def _get_value(self, node): return "" -def analyze_model_file(filepath: str) -> list[dict]: +def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]: with open(filepath) as f: source = f.read() @@ -172,10 +188,10 @@ def analyze_model_file(filepath: str) -> list[dict]: analyzer = ModelAnalyzer() analyzer.visit(tree) - return analyzer.model_classes + return analyzer.model_classes, analyzer.imports -def determine_testers(model_info: dict, include_optional: list[str]) -> list[str]: +def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]: testers = list(ALWAYS_INCLUDE_TESTERS) for base in model_info["bases"]: @@ -195,6 +211,10 @@ def determine_testers(model_info: dict, include_optional: list[str]) -> list[str if "ContextParallelTesterMixin" not in testers: testers.append("ContextParallelTesterMixin") + # Include AttentionTesterMixin if the model imports attention-related classes + if imports & ATTENTION_INDICATORS: + testers.append("AttentionTesterMixin") + for tester, flag in OPTIONAL_TESTERS: if flag in include_optional: if tester not in testers: 
@@ -335,9 +355,9 @@ def generate_test_class(model_name: str, config_class: str, tester: str) -> str: return "\n".join(lines) -def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str]) -> str: +def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str: model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "") - testers = determine_testers(model_info, include_optional) + testers = determine_testers(model_info, include_optional, imports) tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"}) lines = [ @@ -446,7 +466,7 @@ def main(): print(f"Error: File not found: {args.model_filepath}", file=sys.stderr) sys.exit(1) - model_classes = analyze_model_file(args.model_filepath) + model_classes, imports = analyze_model_file(args.model_filepath) if not model_classes: print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr) @@ -468,7 +488,7 @@ def main(): if "all" in include_optional: include_optional = [flag for _, flag in OPTIONAL_TESTERS] - generated_code = generate_test_file(model_info, args.model_filepath, include_optional) + generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports) if args.dry_run: print(generated_code) From dcd6026d172b519fc4f5115e66bcd5a34a0ee59d Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 15 Dec 2025 16:12:15 +0530 Subject: [PATCH 11/12] update --- tests/models/testing_utils/quantization.py | 119 +++++++++++---------- 1 file changed, 63 insertions(+), 56 deletions(-) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index b7f960a1351c..26904e8cf9dc 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -128,9 +128,9 @@ def _test_quantization_num_parameters(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) num_params_quantized = model_quantized.num_parameters() - assert ( - num_params == num_params_quantized - ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" + assert num_params == num_params_quantized, ( + f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" + ) def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): model = self._load_unquantized_model() @@ -140,9 +140,9 @@ def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_red mem_quantized = model_quantized.get_memory_footprint() ratio = mem / mem_quantized - assert ( - ratio >= expected_memory_reduction - ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}" + assert ratio >= expected_memory_reduction, ( + f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" + ) def _test_quantization_inference(self, config_kwargs): model_quantized = self._create_quantized_model(config_kwargs) @@ -237,12 +237,12 @@ def _test_quantized_layers(self, config_kwargs): self._verify_if_layer_quantized(name, module, config_kwargs) num_quantized_layers += 1 - assert ( - num_quantized_layers > 0 - ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" - assert ( - num_quantized_layers == expected_quantized_layers - ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + assert num_quantized_layers > 0, ( + f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + ) + assert num_quantized_layers == expected_quantized_layers, ( + f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + ) def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): """ @@ -266,9 +266,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no if any(excluded in name for excluded in modules_to_not_convert): found_excluded = True # This module should NOT be quantized - assert not self._is_module_quantized( - module - ), f"Module {name} should not be quantized but was found to be quantized" + assert not self._is_module_quantized(module), ( + f"Module {name} should not be quantized but was found to be quantized" + ) assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" @@ -290,9 +290,9 @@ def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_no mem_with_exclusion = model_with_exclusion.get_memory_footprint() mem_fully_quantized = model_fully_quantized.get_memory_footprint() - assert ( - mem_with_exclusion > mem_fully_quantized - ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + assert mem_with_exclusion > mem_fully_quantized, ( + f"Model with exclusions should be larger. 
With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + ) def _test_quantization_device_map(self, config_kwargs): """ @@ -399,40 +399,40 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs): def _verify_if_layer_quantized(self, name, module, config_kwargs): expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params - assert ( - module.weight.__class__ == expected_weight_class - ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + assert module.weight.__class__ == expected_weight_class, ( + f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + ) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) def test_bnb_quantization_num_parameters(self, config_name): self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) def test_bnb_quantization_memory_footprint(self, config_name): expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) def test_bnb_quantization_inference(self, config_name): self._test_quantization_inference(self.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", ["4bit_nf4"]) + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) def test_bnb_quantization_dtype_assignment(self, config_name): self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", ["4bit_nf4"]) + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) def test_bnb_quantization_lora_inference(self, config_name): self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", ["4bit_nf4"]) + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) def test_bnb_quantization_serialization(self, config_name): self._test_quantization_serialization(self.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) def test_bnb_quantized_layers(self, config_name): self._test_quantized_layers(self.BNB_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()), ids=list(BNB_CONFIGS.keys())) def test_bnb_quantization_config_serialization(self, config_name): model = self._create_quantized_model(self.BNB_CONFIGS[config_name]) @@ -469,13 +469,13 @@ def test_bnb_keep_modules_in_fp32(self): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear): if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): - assert ( - module.weight.dtype == torch.float32 - ), f"Module {name} should be FP32 but is {module.weight.dtype}" + assert module.weight.dtype == torch.float32, ( + f"Module {name} should be FP32 but is {module.weight.dtype}" + ) else: - assert ( - 
module.weight.dtype == torch.uint8 - ), f"Module {name} should be uint8 but is {module.weight.dtype}" + assert module.weight.dtype == torch.uint8, ( + f"Module {name} should be uint8 but is {module.weight.dtype}" + ) with torch.no_grad(): inputs = self.get_dummy_inputs() @@ -492,9 +492,10 @@ def test_bnb_modules_to_not_convert(self): self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude) - def test_bnb_device_map(self): + @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"]) + def test_bnb_device_map(self, config_name): """Test that device_map='auto' works correctly with quantization.""" - self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"]) + self._test_quantization_device_map(self.BNB_CONFIGS[config_name]) def test_bnb_dequantize(self): """Test that dequantize() works correctly.""" @@ -548,30 +549,36 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs): def _verify_if_layer_quantized(self, name, module, config_kwargs): assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}" - @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys())) + @pytest.mark.parametrize( + "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys()) + ) def test_quanto_quantization_num_parameters(self, weight_type_name): self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys())) + @pytest.mark.parametrize( + "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys()) + ) def test_quanto_quantization_memory_footprint(self, weight_type_name): expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2) self._test_quantization_memory_footprint( self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected ) - @pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys())) + @pytest.mark.parametrize( + "weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()), ids=list(QUANTO_WEIGHT_TYPES.keys()) + ) def test_quanto_quantization_inference(self, weight_type_name): self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - @pytest.mark.parametrize("weight_type_name", ["int8"]) + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) def test_quanto_quantized_layers(self, weight_type_name): self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - @pytest.mark.parametrize("weight_type_name", ["int8"]) + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) def test_quanto_quantization_lora_inference(self, weight_type_name): self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name]) - @pytest.mark.parametrize("weight_type_name", ["int8"]) + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) def test_quanto_quantization_serialization(self, weight_type_name): self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name]) @@ -636,30 +643,30 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs): def _verify_if_layer_quantized(self, name, module, config_kwargs): assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}" - @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys())) + @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), 
ids=list(TORCHAO_QUANT_TYPES.keys())) def test_torchao_quantization_num_parameters(self, quant_type): self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys())) + @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys())) def test_torchao_quantization_memory_footprint(self, quant_type): expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2) self._test_quantization_memory_footprint( self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected ) - @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys())) + @pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()), ids=list(TORCHAO_QUANT_TYPES.keys())) def test_torchao_quantization_inference(self, quant_type): self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", ["int8wo"]) + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) def test_torchao_quantized_layers(self, quant_type): self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", ["int8wo"]) + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) def test_torchao_quantization_lora_inference(self, quant_type): self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type]) - @pytest.mark.parametrize("quant_type", ["int8wo"]) + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) def test_torchao_quantization_serialization(self, quant_type): self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type]) @@ -801,34 +808,34 @@ def _create_quantized_model(self, config_kwargs, **extra_kwargs): def _verify_if_layer_quantized(self, name, module, config_kwargs): assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)" - @pytest.mark.parametrize("config_name", ["fp8"]) + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantization_num_parameters(self, config_name): self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()), ids=list(MODELOPT_CONFIGS.keys())) def test_modelopt_quantization_memory_footprint(self, config_name): expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) self._test_quantization_memory_footprint( self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected ) - @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys())) + @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()), ids=list(MODELOPT_CONFIGS.keys())) def test_modelopt_quantization_inference(self, config_name): self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", ["fp8"]) + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantization_dtype_assignment(self, config_name): self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", ["fp8"]) + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) def test_modelopt_quantization_lora_inference(self, config_name): self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name]) - @pytest.mark.parametrize("config_name", ["fp8"]) + 
@pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantization_serialization(self, config_name):
         self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])
 
-    @pytest.mark.parametrize("config_name", ["fp8"])
+    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
     def test_modelopt_quantized_layers(self, config_name):
         self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])

From d9b73ffd5155507c7990aa792b439b886c35dbb9 Mon Sep 17 00:00:00 2001
From: DN6
Date: Mon, 15 Dec 2025 16:12:50 +0530
Subject: [PATCH 12/12] update

---
 tests/models/testing_utils/parallelism.py | 102 ++++++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 tests/models/testing_utils/parallelism.py

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
new file mode 100644
index 000000000000..3bbbfe91bbd3
--- /dev/null
+++ b/tests/models/testing_utils/parallelism.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from diffusers.models._modeling_parallel import ContextParallelConfig
+
+from ...testing_utils import (
+    is_context_parallel,
+    require_torch_multi_accelerator,
+)
+
+
+def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
+    try:
+        # Setup distributed environment
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        torch.distributed.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=world_size,
+            rank=rank,
+        )
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
+        inputs_on_device = {}
+        for key, value in inputs_dict.items():
+            if isinstance(value, torch.Tensor):
+                inputs_on_device[key] = value.to(device)
+            else:
+                inputs_on_device[key] = value
+
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
+        with torch.no_grad():
+            output = model(**inputs_on_device, return_dict=False)[0]
+
+        if rank == 0:
+            result_queue.put(("success", output.shape))
+
+    except Exception as e:
+        if rank == 0:
+            result_queue.put(("error", str(e)))
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+
+
+@is_context_parallel
+@require_torch_multi_accelerator
+class ContextParallelTesterMixin:
+    base_precision = 1e-3
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_inference(self, cp_type):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        cp_dict = {cp_type: world_size}
+
+        ctx = mp.get_context("spawn")
+        result_queue = ctx.Queue()
+
+        mp.spawn(
+            _context_parallel_worker,
+            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
+            nprocs=world_size,
+            join=True,
+        )
+
+        status, result = result_queue.get(timeout=60)
+        assert status == "success", f"Context parallel inference failed: {result}"
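For reference, a concrete tester built on these mixins only needs to provide model_class, get_init_dict, and get_dummy_inputs; ContextParallelTesterMixin picks everything else up from those hooks (and skips itself when the model has no _cp_plan). A minimal sketch, assuming a placeholder model class, illustrative config values, and import paths that are not part of this patch:

    # Hypothetical usage sketch -- MyTransformer2DModel, its config values, and the
    # input shapes are placeholders; the import paths are assumed, not confirmed here.
    import torch

    from diffusers import MyTransformer2DModel  # placeholder for a model that defines a _cp_plan

    from ..testing_utils import ModelTesterMixin
    from ..testing_utils.parallelism import ContextParallelTesterMixin
    from ...testing_utils import torch_device


    class MyTransformerModelTests(ModelTesterMixin, ContextParallelTesterMixin):
        model_class = MyTransformer2DModel
        main_input_name = "hidden_states"

        def get_init_dict(self):
            # Keep the config tiny so the distributed and quantization tests stay fast.
            return {"num_layers": 1, "num_attention_heads": 2, "attention_head_dim": 8}

        def get_dummy_inputs(self):
            return {
                "hidden_states": torch.randn(1, 4, 16, 16, device=torch_device),
                "timestep": torch.tensor([1], device=torch_device),
            }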