
Commit 1567243

[lora] Remove lora docs unneeded and add " # Copied from ..." (#12824)
* remove unneeded docs on load_lora_weights().
* remove more.
* up
* up
* up
1 parent 0eac64c commit 1567243

src/diffusers/loaders/lora_pipeline.py

Lines changed: 20 additions & 128 deletions
@@ -1487,7 +1487,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`FluxTransformer2DModel`],
     [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

-    Specific to [`StableDiffusion3Pipeline`].
+    Specific to [`FluxPipeline`].
     """

     _lora_loadable_modules = ["transformer", "text_encoder"]
@@ -1628,30 +1628,7 @@ def load_lora_weights(
         **kwargs,
     ):
         """
-        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
-        `self.text_encoder`.
-
-        All kwargs are forwarded to `self.lora_state_dict`.
-
-        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
-        loaded.
-
-        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
-        dict is loaded into `self.transformer`.
-
-        Parameters:
-            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
-                `default_{i}` where i is the total number of adapters being loaded.
-            low_cpu_mem_usage (`bool`, *optional*):
-                `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
-                weights.
-            hotswap (`bool`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -3651,44 +3628,17 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
         **kwargs,
     ):
         r"""
-        Return state dict for lora weights and the network alphas.
-
-        Parameters:
-            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted on the Hub.
-                    - A path to a *directory* containing the model weights.
-                    - A [torch state
-                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
-            cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory where a downloaded pretrained model configuration is cached.
-            force_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to force the (re-)download of the model weights.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint.
-            local_files_only (`bool`, *optional*, defaults to `False`):
-                Whether to only load local model weights and configuration files.
-            token (`str` or *bool*, *optional*):
-                The token to use as HTTP bearer authorization for remote files.
-            revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use.
-            subfolder (`str`, *optional*, defaults to `""`):
-                The subfolder location of a model file within a larger model repository.
-            weight_name (`str`, *optional*, defaults to None):
-                Name of the serialized state dict file.
-            use_safetensors (`bool`, *optional*):
-                Whether to use safetensors for loading.
-            return_lora_metadata (`bool`, *optional*, defaults to False):
-                When enabled, additionally return the LoRA adapter metadata.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
         """
-        # Load the main state dict first which has the LoRA layers
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
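Although the docstring now points at the SD3 mixin, the keyword arguments handled here (`weight_name`, `return_lora_metadata`, and the Hub download options popped from `kwargs` above) are unchanged. A sketch of calling the classmethod directly; the repo id and file name are placeholders:

```py
from diffusers import Kandinsky5T2VPipeline

# Inspect a LoRA checkpoint without loading it into a pipeline instance.
state_dict, metadata = Kandinsky5T2VPipeline.lora_state_dict(
    "some-user/some-kandinsky-lora",   # placeholder Hub repo id or local directory
    weight_name="lora.safetensors",    # placeholder name of the serialized state dict file
    return_lora_metadata=True,         # also return the LoRA adapter metadata
)
print(len(state_dict), metadata)
```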
@@ -3731,6 +3681,7 @@ def lora_state_dict(
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3739,26 +3690,13 @@ def load_lora_weights(
         **kwargs,
     ):
         """
-        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
-
-        Parameters:
-            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
-                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model.
-            hotswap (`bool`, *optional*):
-                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
-            low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
-                weights.
-            kwargs (`dict`, *optional*):
-                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
-        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
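Note the rewritten guard: the old check rejected any `peft` below 0.13.1, while the copied-in version only rejects `peft` below 0.13.0, so 0.13.0 itself now passes. A quick illustration of that boundary, using `packaging` directly rather than the library's `is_peft_version` helper:

```py
from packaging import version

installed = version.parse("0.13.0")

old_rejects = not (installed >= version.parse("0.13.1"))  # True: peft 0.13.0 used to trigger the error
new_rejects = installed < version.parse("0.13.0")         # False: peft 0.13.0 is now accepted
print(old_rejects, new_rejects)  # True False
```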
@@ -3775,7 +3713,6 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        # Load LoRA into transformer
         self.load_lora_into_transformer(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
@@ -3787,6 +3724,7 @@ def load_lora_weights(
         )

     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
     def load_lora_into_transformer(
         cls,
         state_dict,
@@ -3798,23 +3736,9 @@ def load_lora_into_transformer(
         metadata=None,
     ):
         """
-        Load the LoRA layers specified in `state_dict` into `transformer`.
-
-        Parameters:
-            state_dict (`dict`):
-                A standard state dict containing the lora layer parameters.
-            transformer (`Kandinsky5Transformer3DModel`):
-                The transformer model to load the LoRA layers into.
-            adapter_name (`str`, *optional*):
-                Adapter name to be used for referencing the loaded adapter model.
-            low_cpu_mem_usage (`bool`, *optional*):
-                Speed up model loading by only loading the pretrained LoRA weights.
-            hotswap (`bool`, *optional*):
-                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
-            metadata (`dict`):
-                Optional LoRA adapter metadata.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
         """
-        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
@@ -3832,6 +3756,7 @@ def load_lora_into_transformer(
         )

     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
@@ -3840,24 +3765,10 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
-        transformer_lora_adapter_metadata=None,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the transformer and text encoders.
-
-        Arguments:
-            save_directory (`str` or `os.PathLike`):
-                Directory to save LoRA parameters to.
-            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
-                State dict of the LoRA layers corresponding to the `transformer`.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process.
-            save_function (`Callable`):
-                The function to use to save the state dictionary.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way.
-            transformer_lora_adapter_metadata:
-                LoRA adapter metadata associated with the transformer.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
         """
         lora_layers = {}
         lora_metadata = {}
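For reference, a hedged sketch of the save path whose docstring is condensed here. The dummy key and tensor exist only to make the call runnable; the output directory, file name, and LoRA key are placeholders, and a real state dict would come from a trained adapter:

```py
import torch
from diffusers import Kandinsky5T2VPipeline

# Stand-in LoRA state dict for the transformer (placeholder key and weight).
transformer_lora_layers = {"blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 16)}

Kandinsky5T2VPipeline.save_lora_weights(
    save_directory="./kandinsky-lora",           # placeholder output directory
    transformer_lora_layers=transformer_lora_layers,
    weight_name="lora.safetensors",              # placeholder file name
    safe_serialization=True,
)
```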
@@ -3867,7 +3778,7 @@ def save_lora_weights(
             lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

         if not lora_layers:
-            raise ValueError("You must pass at least one of `transformer_lora_layers`")
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

         cls._save_lora_weights(
             save_directory=save_directory,
@@ -3879,6 +3790,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3888,25 +3800,7 @@ def fuse_lora(
         **kwargs,
     ):
         r"""
-        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-        Args:
-            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
-            lora_scale (`float`, defaults to 1.0):
-                Controls how much to influence the outputs with the LoRA parameters.
-            safe_fusing (`bool`, defaults to `False`):
-                Whether to check fused weights for NaN values before fusing.
-            adapter_names (`List[str]`, *optional*):
-                Adapter names to be used for fusing.
-
-        Example:
-        ```py
-        from diffusers import Kandinsky5T2VPipeline
-
-        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
-        pipeline.load_lora_weights("path/to/lora.safetensors")
-        pipeline.fuse_lora(lora_scale=0.7)
-        ```
+        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
         """
         super().fuse_lora(
             components=components,
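The usage example dropped from this docstring still describes how fusing works. Reproduced here from the removed lines, with an `unfuse_lora()` call added to show how to undo the merge (the LoRA path is the same placeholder used in the original example):

```py
from diffusers import Kandinsky5T2VPipeline

pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")

# Merge the LoRA deltas into the transformer weights at 0.7 strength.
pipeline.fuse_lora(lora_scale=0.7)

# ... run inference with the fused weights ...

# Restore the original, unfused transformer weights.
pipeline.unfuse_lora(components=["transformer"])
```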
@@ -3916,12 +3810,10 @@ def fuse_lora(
             **kwargs,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
-        Reverses the effect of [`pipe.fuse_lora()`].
-
-        Args:
-            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
         """
         super().unfuse_lora(components=components, **kwargs)
