
Conversation

@DN6 (Collaborator) commented on Dec 11, 2025

What does this PR do?

Following the plan outlined for Diffusers 1.0.0, this PR introduces changes to our model testing approach in order to reduce the overhead involved in adding comprehensive tests for new models and standardize tests across all models.

Changes include:

  1. Introduces feature-specific tester mixins and marks for models (breaking up the very large ModelTesterMixin class)
  2. Introduces a new test file structure using a Config + Mixin pattern
  3. Adds new markers for selective test execution
  4. Adds a utility script (generate_model_tests.py) to automatically generate tests based on the model file. It also provides a flag that lets us include any optional features to test (we can turn this into a bot down the line).

I've only made changes to Flux to make this PR easy to review. I'll open follow-ups in phases for the other models once this is approved.

python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_qwenimage.py 

This will generate a template test file that can be populated with the necessary config information:

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    ContextParallelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)


enable_full_determinism()


class QwenImageTransformerTesterConfig:
    model_class = QwenImageTransformer2DModel
    pretrained_model_name_or_path = ""
    pretrained_model_kwargs = {"subfolder": "transformer"}

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int]]:
        # __init__ parameters:
        #   patch_size: int = 2
        #   in_channels: int = 64
        #   out_channels: Optional[int] = 16
        #   num_layers: int = 60
        #   attention_head_dim: int = 128
        #   num_attention_heads: int = 24
        #   joint_attention_dim: int = 3584
        #   guidance_embeds: bool = False
        #   axes_dims_rope: Tuple[int, int, int] = <complex>
        return {}

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        # forward() parameters:
        #   hidden_states: torch.Tensor
        #   encoder_hidden_states: torch.Tensor
        #   encoder_hidden_states_mask: torch.Tensor
        #   timestep: torch.LongTensor
        #   img_shapes: Optional[List[Tuple[int, int, int]]]
        #   txt_seq_lens: Optional[List[int]]
        #   guidance: torch.Tensor
        #   attention_kwargs: Optional[Dict[str, Any]]
        #   controlnet_block_samples
        #   return_dict: bool = True
        # TODO: Fill in dummy inputs
        return {}

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (1, 1)

    @property
    def output_shape(self) -> tuple[int, ...]:
        return (1, 1)


class TestQwenImageTransformerModel(QwenImageTransformerTesterConfig, ModelTesterMixin):
    pass


class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
    pass


class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
    pass


class TestQwenImageTransformerTorchCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}


class TestQwenImageTransformerLora(QwenImageTransformerTesterConfig, LoraTesterMixin):
    pass


class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
    pass


class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
    pass


class TestQwenImageTransformerLoraHotSwappingForModel(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}
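For reference, each feature mixin consumes the config through the model_class, get_init_dict(), and get_dummy_inputs() hooks. The snippet below is a simplified, hypothetical illustration of that contract; it is not the actual ModelTesterMixin from this PR and assumes the config hooks have been filled in.

import torch


class SimplifiedModelTesterMixin:
    # Hypothetical sketch of what a feature mixin does with the config hooks:
    # build the model from get_init_dict(), run a forward pass with
    # get_dummy_inputs(), and sanity-check the output.
    def test_forward_pass(self):
        model = self.model_class(**self.get_init_dict())
        model.eval()
        with torch.no_grad():
            output = model(**self.get_dummy_inputs(), return_dict=False)[0]
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"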

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 requested review from dg845, sayakpaul and yiyixuxu on December 11, 2025 06:16
@sayakpaul (Member) left a comment

Excellent stuff! Some general comments:

  • Normalize the model outputs to a common format before they go to torch.allclose().
  • Initialize the input dict afresh before passing it to a newly initialized model with torch.manual_seed(0), because some autoencoder models take a generator input.
  • Use fixtures wherever possible to reduce boilerplate and take advantage of pytest features (see the sketch after this list).
    • One particular session-level fixture could be base_output. It should help reduce test time quite a bit.
  • Use pytest.mark.parametrize where possible.
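A rough sketch of the fixture and parametrize suggestions above (hypothetical names, not this PR's implementation; qwen_model_and_inputs is an assumed upstream fixture yielding a model and its dummy inputs):

import pytest
import torch


@pytest.fixture(scope="session")
def base_output(qwen_model_and_inputs):
    # Compute the reference output once per session and reuse it across tests.
    model, inputs = qwen_model_and_inputs
    with torch.no_grad():
        return model(**inputs, return_dict=False)[0]


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_forward_in_dtype(qwen_model_and_inputs, dtype):
    # Parametrize over dtypes instead of duplicating near-identical test bodies.
    model, inputs = qwen_model_and_inputs
    model = model.to(dtype)
    inputs = {
        k: v.to(dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
        for k, v in inputs.items()
    }
    with torch.no_grad():
        output = model(**inputs, return_dict=False)[0]
    assert output.dtype == dtype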

Okay for me if these are done in a future PR, but:

  • We should also account for the attention backends.
  • Should we also test the cross product of CP (context parallel) and attention backends?
  • How about the caching mixins?

Some nits:

  • Use torch.no_grad() as a decorator on the whole test function, as opposed to using it as a context manager inside the functions.

config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
Member

Do we want to club this with attention backends?

Collaborator Author

We can, but I think it's a bit out of scope for this PR because we would need to create a container with the relevant backends available.

Member

We don't have to do it in this PR then.

But just as a note: for most of the attention backends that have complex installation processes, we should actually encourage users to rely on their kernels-based variants (FA2, FA3, SAGE). This way, we won't have to build any containers.

- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
Member

How do we implement this in an individual model testing class? For example, say we want to skip it for a model X whose attention class doesn't inherit from AttentionModuleMixin.

Collaborator Author

Ideally, any model using attention also uses AttentionModuleMixin. The options here are:

  1. Do not add the tests from the attention mixin to a model's test file.
  2. Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Member

But there are important classes like Autoencoders that don't use the Attention mixins.

Let's do this?

Add a decorator that makes AttentionModuleMixin a requirement for running attention tests.

Collaborator Author

But there are important classes like Autoencoders that don't use the Attention mixins.
The check is for AttentionModuleMixin, not AttentionMixin, and Autoencoders do use it:

ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin

@sayakpaul (Member), Dec 15, 2025

Oh, I see your point!

So, as long as there is an attention module, this class should apply.

So, maybe at the beginning of each test we could check whether the model's attention modules inherit from AttentionModuleMixin:

if isinstance(module, AttentionModuleMixin):

and skip if that's not the case.

Otherwise, I think it could be cumbersome to check which model tests should and shouldn't use this class because attention is a common component.
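A minimal sketch of that skip check (the import path for AttentionModuleMixin is assumed, and the helper name is hypothetical; the PR may implement this differently):

import pytest

from diffusers.models.attention import AttentionModuleMixin


def skip_if_no_attention_module_mixin(model):
    # Skip attention tests for models whose attention layers do not use AttentionModuleMixin.
    if not any(isinstance(module, AttentionModuleMixin) for module in model.modules()):
        pytest.skip("Model has no AttentionModuleMixin-based attention modules")

Each test in AttentionTesterMixin could then call this right after instantiating the model, so individual model test files would not need to opt in or out manually.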


model.to(torch_device)

def _test_quantization_lora_inference(self, config_kwargs):
Member

Very cool rewrite!


model.to(torch_device)

def _test_quantization_lora_inference(self, config_kwargs):
Member

There should also be a corresponding training test similar to:

def test_training(self):
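A rough sketch of what that could look like for quantized models (the _create_quantized_model helper and target module names are hypothetical; with quantized base weights, only LoRA parameters would be trainable):

def _test_quantization_lora_training(self, config_kwargs):
    from peft import LoraConfig

    model = self._create_quantized_model(config_kwargs)  # hypothetical helper
    model.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v"]))
    model.train()

    output = model(**self.get_dummy_inputs(), return_dict=False)[0]
    loss = output.float().mean()
    loss.backward()

    # Only the LoRA parameters should receive gradients; the quantized base weights stay frozen.
    assert any(p.grad is not None for n, p in model.named_parameters() if "lora" in n)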

if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
Member

dequantize() tests are missing for bitsandbytes, but they exist in the current test suite:

def test_generate_quality_dequantize(self):

def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})


Member

Where do we include:

@@ -0,0 +1,489 @@
#!/usr/bin/env python
Member

I guess it cannot currently generate the dummy input and init dicts? I would understand if so because inferring those is quite non-trivial.
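For context, the parameter comments in the generated template suggest the script introspects the model's signatures; a minimal sketch of that kind of introspection (stdlib inspect, not the script's actual code):

import inspect

from diffusers import QwenImageTransformer2DModel


def list_init_params(model_class):
    # Emit "name: annotation = default" lines for the model's __init__ parameters,
    # roughly what the generated template's comments look like.
    sig = inspect.signature(model_class.__init__)
    for name, param in sig.parameters.items():
        if name == "self":
            continue
        default = "" if param.default is inspect.Parameter.empty else f" = {param.default}"
        print(f"{name}: {param.annotation}{default}")


list_init_params(QwenImageTransformer2DModel)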

model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)

Member

Should we include a training test too?

@sayakpaul (Member)

One thing I think we should do is get a coverage report for tests/models with `main` and with this PR, and confirm we are not skipping anything truly critical.

If we are, then we should likely be able to explain why that's the case.
