
Commit 5851928

Support for control-lora (#10686)
* run control-lora on diffusers
* cannot load lora adapter
* test
* 1
* add control-lora
* 1
* 1
* 1
* fix PeftAdapterMixin
* fix module_to_save bug
* delete json print
* resolve conflicts
* merged but bug
* change peft.py
* 1
* delete state_dict print
* fix alpha
* Create control_lora.py
* Add files via upload
* rename
* no need modify as peft updated
* add doc
* fix code style
* styling isn't that hard 😉
* empty

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 0c1ccc0 commit 5851928

File tree

7 files changed: +312 -1 lines changed

docs/source/en/api/models/controlnet.md

Lines changed: 15 additions & 0 deletions
@@ -33,6 +33,21 @@ url = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/m
 pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
 ```
+
+## Loading from Control LoRA
+
+Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora); it adds low-rank, parameter-efficient fine-tuning to ControlNet. This approach offers a more efficient and compact way to bring model control to a wider range of consumer GPUs.
+
+```py
+import torch
+
+from diffusers import ControlNetModel, UNet2DConditionModel
+
+lora_id = "stabilityai/control-lora"
+lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
+
+unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
+controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
+controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
+```
+
 ## ControlNetModel
 
 [[autodoc]] ControlNetModel
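
Editor's note: the doc snippet stops after loading the adapter. The example script added later in this commit continues it into a full SDXL pipeline; a condensed sketch (reusing the `unet` and `controlnet` built in the snippet above, and omitting the VAE override from the full script):

```py
import torch

from diffusers import StableDiffusionXLControlNetPipeline

# The full version of this flow is in the commit's control_lora.py example.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
).to("cuda")
```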
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Control-LoRA inference example
+
+Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora); it adds low-rank, parameter-efficient fine-tuning to ControlNet. This approach offers a more efficient and compact way to bring model control to a wider range of consumer GPUs.
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+## Inference on SDXL
+
+[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image.
+
+```bash
+python control_lora.py
+```
+
+## Acknowledgements
+
+- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora)
+- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors)
+- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2)
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import (
+    AutoencoderKL,
+    ControlNetModel,
+    StableDiffusionXLControlNetPipeline,
+    UNet2DConditionModel,
+)
+from diffusers.utils import load_image, make_image_grid
+
+
+pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
+lora_id = "stabilityai/control-lora"
+lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
+
+unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
+controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
+controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
+
+prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+negative_prompt = "low quality, bad quality, sketches"
+
+image = load_image(
+    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+)
+
+controlnet_conditioning_scale = 1.0  # recommended for good generalization
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+    pipe_id,
+    unet=unet,
+    controlnet=controlnet,
+    vae=vae,
+    torch_dtype=torch.bfloat16,
+    safety_checker=None,
+).to("cuda")
+
+# Extract Canny edges and stack them into a 3-channel conditioning image
+image = np.array(image)
+image = cv2.Canny(image, 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+image = Image.fromarray(image)
+
+images = pipe(
+    prompt,
+    negative_prompt=negative_prompt,
+    image=image,
+    controlnet_conditioning_scale=controlnet_conditioning_scale,
+    num_images_per_prompt=4,
+).images
+
+# Put the conditioning image next to the four generations in a 1x5 grid
+final_image = [image] + images
+grid = make_image_grid(final_image, 1, 5)
+grid.save("hf-logo_canny.png")

src/diffusers/loaders/peft.py

Lines changed: 16 additions & 0 deletions
@@ -27,6 +27,7 @@
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
+    convert_sai_sd_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
@@ -232,6 +233,13 @@ def load_lora_adapter(
         if "lora_A" not in first_key:
             state_dict = convert_unet_state_dict_to_peft(state_dict)
 
+        # Control LoRA from SAI is different from BFL Control LoRA
+        # https://huggingface.co/stabilityai/control-lora
+        # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
+        is_sai_sd_control_lora = "lora_controlnet" in state_dict
+        if is_sai_sd_control_lora:
+            state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
+
         rank = {}
         for key, val in state_dict.items():
             # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
@@ -263,6 +271,14 @@ def load_lora_adapter(
             adapter_name=adapter_name,
         )
 
+        # Adjust LoRA config for Control LoRA
+        if is_sai_sd_control_lora:
+            lora_config.lora_alpha = lora_config.r
+            lora_config.alpha_pattern = lora_config.rank_pattern
+            lora_config.bias = "all"
+            lora_config.modules_to_save = lora_config.exclude_modules
+            lora_config.exclude_modules = None
+
         # <Unsafe code
         # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
         # Now we remove any existing hooks to `_pipeline`.
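
Editor's note: the detection above keys off a marker entry in the checkpoint itself. A minimal sketch of inspecting that marker outside the loader (assumes `huggingface_hub` and `safetensors` are installed; this snippet is illustrative, not part of the diff):

```py
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

path = hf_hub_download(
    "stabilityai/control-lora", "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
)
state_dict = load_file(path)

# SAI Control-LoRA checkpoints carry a "lora_controlnet" marker key; this is
# the same check `load_lora_adapter` performs above.
print("lora_controlnet" in state_dict)
```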

src/diffusers/models/controlnets/controlnet.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import BaseOutput, logging
 from ..attention import AttentionMixin
@@ -106,7 +107,7 @@ def forward(self, conditioning):
         return embedding
 
 
-class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
+class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     """
     A ControlNet model.

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -143,6 +143,7 @@
 from .remote_utils import remote_decode
 from .state_dict_utils import (
     convert_all_state_dict_to_peft,
+    convert_sai_sd_control_lora_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_state_dict_to_peft,

src/diffusers/utils/state_dict_utils.py

Lines changed: 179 additions & 0 deletions
@@ -56,6 +56,36 @@ class StateDictType(enum.Enum):
     ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
 }
 
+CONTROL_LORA_TO_DIFFUSERS = {
+    ".to_q.down": ".to_q.lora_A.weight",
+    ".to_q.up": ".to_q.lora_B.weight",
+    ".to_k.down": ".to_k.lora_A.weight",
+    ".to_k.up": ".to_k.lora_B.weight",
+    ".to_v.down": ".to_v.lora_A.weight",
+    ".to_v.up": ".to_v.lora_B.weight",
+    ".to_out.0.down": ".to_out.0.lora_A.weight",
+    ".to_out.0.up": ".to_out.0.lora_B.weight",
+    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
+    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
+    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
+    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
+    ".proj_in.down": ".proj_in.lora_A.weight",
+    ".proj_in.up": ".proj_in.lora_B.weight",
+    ".proj_out.down": ".proj_out.lora_A.weight",
+    ".proj_out.up": ".proj_out.lora_B.weight",
+    ".conv.down": ".conv.lora_A.weight",
+    ".conv.up": ".conv.lora_B.weight",
+    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
+    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
+    "conv_in.down": "conv_in.lora_A.weight",
+    "conv_in.up": "conv_in.lora_B.weight",
+    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
+    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
+    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
+    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
+    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
+    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
+}
 
 DIFFUSERS_TO_PEFT = {
     ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
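
Editor's note: the new mapping pairs each Control-LoRA `.down`/`.up` suffix with its PEFT `lora_A`/`lora_B` counterpart. A toy sketch of the kind of suffix substitution `convert_state_dict` applies with this mapping (the `rename_key` helper is hypothetical, for illustration only):

```py
CONTROL_LORA_TO_DIFFUSERS = {".to_q.down": ".to_q.lora_A.weight", ".to_q.up": ".to_q.lora_B.weight"}

def rename_key(key: str, mapping: dict) -> str:
    # Replace the first matching suffix pattern; unmatched keys pass through.
    for pattern, replacement in mapping.items():
        if pattern in key:
            return key.replace(pattern, replacement)
    return key

key = "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.down"
print(rename_key(key, CONTROL_LORA_TO_DIFFUSERS))
# down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight
```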
@@ -259,6 +289,155 @@ def convert_unet_state_dict_to_peft(state_dict):
     return convert_state_dict(state_dict, mapping)
 
 
+def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
+    def _convert_controlnet_to_diffusers(state_dict):
+        is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
+        logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")
+
+        # Retrieves the keys for the input blocks only
+        num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer})
+        input_blocks = {
+            layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
+            for layer_id in range(num_input_blocks)
+        }
+        layers_per_block = 2
+
+        # op blocks
+        op_blocks = [key for key in state_dict if "0.op" in key]
+
+        converted_state_dict = {}
+        # Conv in layers
+        for key in input_blocks[0]:
+            diffusers_key = key.replace("input_blocks.0.0", "conv_in")
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # controlnet time embedding blocks
+        time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
+        for key in time_embedding_blocks:
+            diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace(
+                "time_embed.2", "time_embedding.linear_2"
+            )
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # controlnet label embedding blocks
+        label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
+        for key in label_embedding_blocks:
+            diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace(
+                "label_emb.0.2", "add_embedding.linear_2"
+            )
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # Down blocks
+        for i in range(1, num_input_blocks):
+            block_id = (i - 1) // (layers_per_block + 1)
+            layer_in_block_id = (i - 1) % (layers_per_block + 1)
+
+            resnets = [
+                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+            ]
+            for key in resnets:
+                diffusers_key = (
+                    key.replace("in_layers.0", "norm1")
+                    .replace("in_layers.2", "conv1")
+                    .replace("out_layers.0", "norm2")
+                    .replace("out_layers.3", "conv2")
+                    .replace("emb_layers.1", "time_emb_proj")
+                    .replace("skip_connection", "conv_shortcut")
+                )
+                diffusers_key = diffusers_key.replace(
+                    f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
+                )
+                converted_state_dict[diffusers_key] = state_dict.get(key)
+
+            if f"input_blocks.{i}.0.op.bias" in state_dict:
+                for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
+                    diffusers_key = key.replace(
+                        f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
+                    )
+                    converted_state_dict[diffusers_key] = state_dict.get(key)
+
+            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+            if attentions:
+                for key in attentions:
+                    diffusers_key = key.replace(
+                        f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+                    )
+                    converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # controlnet down blocks
+        for i in range(num_input_blocks):
+            converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
+            converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")
+
+        # Retrieves the keys for the middle blocks only
+        num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer})
+        middle_blocks = {
+            layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
+            for layer_id in range(num_middle_blocks)
+        }
+
+        # Mid blocks
+        for key in middle_blocks.keys():
+            diffusers_key = max(key - 1, 0)
+            if key % 2 == 0:
+                for k in middle_blocks[key]:
+                    diffusers_key_hf = (
+                        k.replace("in_layers.0", "norm1")
+                        .replace("in_layers.2", "conv1")
+                        .replace("out_layers.0", "norm2")
+                        .replace("out_layers.3", "conv2")
+                        .replace("emb_layers.1", "time_emb_proj")
+                        .replace("skip_connection", "conv_shortcut")
+                    )
+                    diffusers_key_hf = diffusers_key_hf.replace(
+                        f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
+                    )
+                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
+            else:
+                for k in middle_blocks[key]:
+                    diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}")
+                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
+
+        # mid block
+        converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
+        converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")
+
+        # controlnet cond embedding blocks
+        cond_embedding_blocks = {
+            ".".join(layer.split(".")[:2])
+            for layer in state_dict
+            if "input_hint_block" in layer
+            and ("input_hint_block.0" not in layer)
+            and ("input_hint_block.14" not in layer)
+        }
+        num_cond_embedding_blocks = len(cond_embedding_blocks)
+
+        for idx in range(1, num_cond_embedding_blocks + 1):
+            diffusers_idx = idx - 1
+            cond_block_id = 2 * idx
+
+            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
+                f"input_hint_block.{cond_block_id}.weight"
+            )
+            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
+                f"input_hint_block.{cond_block_id}.bias"
+            )
+
+        for key in [key for key in state_dict if "input_hint_block.0" in key]:
+            diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        for key in [key for key in state_dict if "input_hint_block.14" in key]:
+            diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        return converted_state_dict
+
+    state_dict = _convert_controlnet_to_diffusers(state_dict)
+    mapping = CONTROL_LORA_TO_DIFFUSERS
+    return convert_state_dict(state_dict, mapping)
+
+
 def convert_all_state_dict_to_peft(state_dict):
     r"""
     Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
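
Editor's note: end to end, the new converter first renames the LDM-style ControlNet keys to diffusers names, then applies the LoRA suffix mapping. A hedged usage sketch (the printed key names are illustrative, not verified output):

```py
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers.utils import convert_sai_sd_control_lora_state_dict_to_peft

path = hf_hub_download(
    "stabilityai/control-lora", "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
)
peft_state_dict = convert_sai_sd_control_lora_state_dict_to_peft(load_file(path))

# After conversion, LoRA weights use PEFT-style lora_A/lora_B suffixes (e.g. keys
# ending in ".attn1.to_q.lora_A.weight"), while full weights such as biases and
# norms keep their diffusers names.
print(sorted(k for k in peft_state_dict if "lora_A" in k)[:3])
```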
