15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -32,6 +32,21 @@

def pytest_configure(config):
    config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
    config.addinivalue_line("markers", "training: marks tests for training functionality")
    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
Member: Do we want to club this with attention backends?

Collaborator Author: We can. But I think it's a bit out of scope for this PR because we need to create a container with the relevant backends available.

Member: We don't have to do it in this PR then.

But just as a note: for the attention backends with complex installation processes, we should encourage users to rely on their kernels variants (FA2, FA3, SAGE) instead. That way, we won't have to build any containers.

    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
    config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")


def pytest_addoption(parser):
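Usage note (not part of the diff): once a marker is registered here, it can be attached to tests and used for selection on the command line, e.g. `pytest -m attention` to run only attention tests or `pytest -m "not attention"` to skip them. In this PR the markers are applied through decorators such as `@is_attention` from the testing utilities; a minimal sketch of the underlying pytest mechanism, assuming a plain module-level marker, would look like this:

```python
import pytest

# Mark every test in this module with the `attention` marker registered in conftest.py,
# so it is selected by `pytest -m attention` and deselected by `pytest -m "not attention"`.
pytestmark = pytest.mark.attention
```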
41 changes: 41 additions & 0 deletions tests/models/testing_utils/__init__.py
@@ -0,0 +1,41 @@
from .attention import AttentionTesterMixin
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
    BitsAndBytesTesterMixin,
    GGUFTesterMixin,
    ModelOptTesterMixin,
    QuantizationTesterMixin,
    QuantoTesterMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
    "AttentionTesterMixin",
    "BaseModelTesterConfig",
    "BitsAndBytesTesterMixin",
    "ContextParallelTesterMixin",
    "CPUOffloadTesterMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraHotSwappingForModelTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "QuantizationTesterMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
185 changes: 185 additions & 0 deletions tests/models/testing_utils/attention.py
@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import (
    assert_tensors_close,
    is_attention,
    torch_device,
)


@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether the model uses custom attention processors (default: False)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns a dict of arguments to initialize the model
    - get_dummy_inputs(): Returns a dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
Member: How do we implement it in an individual model testing class? For example, say we want to skip it for a model X whose attention class doesn't inherit from AttentionModuleMixin?

Collaborator Author: Ideally, any model using attention also uses AttentionModuleMixin. The options here are:

  1. Do not add the attention mixin's tests to that model's test module.
  2. Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Member: But there are important classes like Autoencoders that don't use the Attention mixins.

Let's do this?

"Add a decorator that makes AttentionModuleMixin a requirement for running attention tests."

Collaborator Author: "But there are important classes like Autoencoders that don't use the Attention mixins."

The check is for AttentionModuleMixin, not AttentionMixin, and Autoencoders do use it:

ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin

Member (sayakpaul), Dec 15, 2025: Oh, I see your point!

So, as long as there is an Attention module, this class should apply.

So maybe, at the beginning of each test, we could check whether the attention classes inherit from AttentionModuleMixin, e.g. `if isinstance(module, AttentionModuleMixin):`, and skip if that's not the case.

Otherwise, I think it could be cumbersome to track which model tests should and shouldn't use this class, because attention is such a common component.
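A minimal sketch of that check as a shared helper (the helper name is illustrative, not part of this PR):

```python
import pytest

from diffusers.models.attention import AttentionModuleMixin


def require_attention_module_mixin(model):
    # Skip the calling test when the model has no AttentionModuleMixin-based attention modules.
    if not any(isinstance(m, AttentionModuleMixin) for m in model.modules()):
        pytest.skip("Model has no attention modules based on AttentionModuleMixin.")
```

Each test in the mixin could call this right after instantiating the model, so models without AttentionModuleMixin-based attention are skipped instead of failing.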

    """

    base_precision = 1e-3

    def test_fuse_unfuse_qkv_projections(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Get output before fusion
        with torch.no_grad():
            output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs match
            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should not change after fusing projections",
            )

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs still match
            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should match original after unfusing projections",
            )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
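For context, a hedged sketch of how a model test module might adopt this mixin, following the contract described in the AttentionTesterMixin docstring. MyTinyTransformer, its init arguments, and the dummy input shapes are hypothetical placeholders, not part of this PR; a real test would substitute a diffusers model whose attention modules inherit from AttentionModuleMixin.

```python
import torch

from tests.models.testing_utils import AttentionTesterMixin

# Hypothetical model under test: replace with a real diffusers model class whose
# attention modules inherit from AttentionModuleMixin.
from my_models import MyTinyTransformer  # placeholder import


class MyTinyTransformerAttentionTests(AttentionTesterMixin):
    model_class = MyTinyTransformer
    base_precision = 1e-3  # tolerance used by the output-matching assertions above

    def get_init_dict(self):
        # Keep the configuration as small as possible so tests stay fast (values are illustrative).
        return {"num_layers": 1, "num_attention_heads": 2, "attention_head_dim": 8}

    def get_dummy_inputs(self):
        # Shapes are illustrative; they must match the placeholder model's forward signature.
        return {"hidden_states": torch.randn(1, 4, 16), "timestep": torch.tensor([1])}
```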