
Conversation

@tcaimm tcaimm commented Jan 21, 2026

What does this PR do?

Expand LoRA support for the FLUX.2 series single-stream blocks and update the docs.

1. Architectural Evolution

Compared to the original FLUX framework, the FLUX.2 architecture has changed significantly. First, the number of single-stream layers is far greater than the number of double-stream layers. Second, in the single-stream transformer blocks, the q, k, and v projections are fused with the MLP into a single linear layer: attn.to_qkv_mlp_proj.

Therefore, reusing FLUX's LoRA configuration to train FLUX.2 is insufficient.
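
As a quick sanity check (a minimal sketch; the Flux2Transformer2DModel class name and the checkpoint id below are assumptions, not details from this PR), listing the linear layers of one single-stream block shows the fused projection:

```python
# Sketch only: the class name and checkpoint id are assumptions, not taken from this PR.
import torch
from diffusers import Flux2Transformer2DModel  # assumed FLUX.2 transformer class

transformer = Flux2Transformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.2-dev",  # hypothetical checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

block = transformer.single_transformer_blocks[0]
for name, module in block.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name)  # expected to include "attn.to_qkv_mlp_proj" (no separate to_q/to_k/to_v)
```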

2. Implementation Updates

To address these changes, I have updated the LoRA configuration in the following training scripts and added notes to the README:

  • examples/dreambooth/README_flux2.md
  • examples/dreambooth/train_dreambooth_lora_flux2.py
  • examples/dreambooth/train_dreambooth_lora_flux2_klein.py

The target_modules logic has been modified so that the LoRA adapter correctly trains the main attention layers of both the double-stream and single-stream blocks. FLUX.2 has 48 single-stream layers, while Klein has 24.

# Updated LoRA targeting logic for the FLUX.2 series
nums = 48  # number of single-stream layers: 48 for FLUX.2, 24 for Klein

target_modules = [
    # attention projections in the double-stream blocks
    "to_k", "to_q", "to_v", "to_out.0",
] + [
    # fused qkv + MLP projection and attention output in the single-stream blocks
    "to_qkv_mlp_proj",
    *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(nums)],
]
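
For reference, this list is then passed to PEFT's LoraConfig and attached to the transformer, as in the existing FLUX training scripts (a minimal sketch; the rank and alpha values below are placeholders, not the scripts' defaults):

```python
# Minimal sketch: wire the target_modules list above into a PEFT LoRA adapter.
# r and lora_alpha are placeholder values.
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)  # transformer is the loaded FLUX.2 model
```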

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@tcaimm tcaimm changed the title Add train flux2 lora config Add train flux2 series lora config Jan 21, 2026

tcaimm commented Jan 22, 2026

@sayakpaul Please take a look at this PR. Thank you for your help!

@sayakpaul sayakpaul requested a review from linoytsaban January 22, 2026 03:13

tcaimm commented Jan 24, 2026

@linoytsaban Please take a look at this PR. Thank you for your help!
