[WIP] Refactor Model Tests #12822
New file (diff hunk @@ -0,0 +1,41 @@):
from .attention import AttentionTesterMixin
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
    BitsAndBytesTesterMixin,
    GGUFTesterMixin,
    ModelOptTesterMixin,
    QuantizationTesterMixin,
    QuantoTesterMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin

__all__ = [
    "AttentionTesterMixin",
    "BaseModelTesterConfig",
    "BitsAndBytesTesterMixin",
    "ContextParallelTesterMixin",
    "CPUOffloadTesterMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraHotSwappingForModelTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "QuantizationTesterMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
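To illustrate how these mixins are meant to be composed, here is a hypothetical sketch (not part of the diff). The model choice, config values, import path, and the `device` stand-in are assumptions; the `model_class` / `get_init_dict()` / `get_dummy_inputs()` contract comes from the `AttentionTesterMixin` docstring below.

```python
# Hypothetical example of composing the mixins above; not part of this PR.
import torch

from diffusers import UNet2DModel  # example model, chosen only for illustration

# Placeholder import path; the real package location is whatever this PR defines.
from testing_mixins import AttentionTesterMixin, ModelTesterMixin, TrainingTesterMixin

device = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for the repo's torch_device helper


class UNet2DModelTests(ModelTesterMixin, AttentionTesterMixin, TrainingTesterMixin):
    model_class = UNet2DModel
    base_precision = 1e-3

    def get_init_dict(self):
        # Tiny configuration so the tests stay fast.
        return {
            "sample_size": 16,
            "in_channels": 3,
            "out_channels": 3,
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "layers_per_block": 1,
        }

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 16, 16, device=device),
            "timestep": torch.tensor([10], device=device),
        }
```

With a composition like this, `pytest -m attention` runs only the attention tests for the model and `pytest -m "not attention"` excludes them, the mark presumably being applied by the `is_attention` decorator used in the next file.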
New file (diff hunk @@ -0,0 +1,185 @@):
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import (
    assert_tensors_close,
    is_attention,
    torch_device,
)

@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating-point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether the model uses custom attention processors (default: False)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns a dict of arguments to initialize the model
    - get_dummy_inputs(): Returns a dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests.
    """

[Review thread on the docstring above]

Member: How do we implement this in an individual model testing class? For example, say we want to skip it for model X where its attention class doesn't inherit from AttentionModuleMixin?

Author: Ideally, any model using attention also uses AttentionModuleMixin. The options here …

Member: But there are important classes like Autoencoders that don't use the Attention mixins. Let's do this?

Member: Oh, I see your point! So, maybe for each of the tests, at the beginning, we could check if their attention classes inherit from AttentionModuleMixin, and if that's not the case, we skip. Otherwise, I think it could be cumbersome to check which model tests should and shouldn't use this class, because attention is a common component.

    base_precision = 1e-3

    def test_fuse_unfuse_qkv_projections(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Get output before fusion
        with torch.no_grad():
            output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs match
            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should not change after fusing projections",
            )

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs still match
            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should match original after unfusing projections",
            )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) is AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

[Review thread]

Member: Do we want to club this with attention backends?

Author: We can, but I think it's a bit out of scope for this PR because we need to create a container with the relevant backends available.

Member: We don't have to do it in this PR then. But just as a note: for most of the attention backends that have complex installation processes, we should actually encourage users to rely on their kernels-variants (FA2, FA3, SAGE). This way, we won't have to build any containers.
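Not needed for this PR, but as a sketch of the direction above: backend coverage could probe for the optional dependency at runtime and skip when it is missing, instead of requiring a dedicated CI container. Everything here is assumed for illustration: the `model_and_inputs` fixture, the backend-to-package mapping, and the `set_attention_backend` / `"native"` switching calls.

```python
import importlib.util

import pytest
import torch

# Backend name -> pip package that provides it. Both columns are assumptions
# for illustration; the real identifiers come from diffusers' attention dispatcher.
BACKEND_REQUIREMENTS = {
    "flash": "flash_attn",
    "sage": "sageattention",
    "xformers": "xformers",
}


@pytest.mark.parametrize("backend", sorted(BACKEND_REQUIREMENTS))
def test_attention_backend_matches_native(backend, model_and_inputs):
    # `model_and_inputs` is a hypothetical fixture returning (model, inputs_dict).
    if importlib.util.find_spec(BACKEND_REQUIREMENTS[backend]) is None:
        pytest.skip(f"{BACKEND_REQUIREMENTS[backend]} is not installed in this environment.")

    model, inputs_dict = model_and_inputs
    with torch.no_grad():
        reference = model(**inputs_dict, return_dict=False)[0]

    # `set_attention_backend(...)` and the "native" default are assumptions here;
    # swap in whatever backend-switching API the models actually expose.
    model.set_attention_backend(backend)
    try:
        with torch.no_grad():
            output = model(**inputs_dict, return_dict=False)[0]
    finally:
        model.set_attention_backend("native")

    # Loose tolerance, since fused kernels typically run in reduced precision.
    torch.testing.assert_close(output, reference, atol=1e-2, rtol=0)
```

Because unavailable backends simply skip, the same test file runs on a plain CPU runner and on a GPU runner that happens to have flash-attn or SageAttention installed, which is the point made in the thread above.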