From 8103ca22cbbbc01795e5fc1617b5f60e0c80ccd6 Mon Sep 17 00:00:00 2001 From: wujunqiang03 <1540970049@qq.com> Date: Thu, 11 Dec 2025 23:24:34 +0800 Subject: [PATCH 01/20] Add LongCat-Image --- src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 3 +- .../transformers/transformer_longcat_image.py | 568 ++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/longcat_image/__init__.py | 51 ++ .../longcat_image/pipeline_longcat_image.py | 693 +++++++++++++++++ .../pipeline_longcat_image_edit.py | 734 ++++++++++++++++++ .../longcat_image/pipeline_output.py | 20 + .../longcat_image/system_messages.py | 67 ++ 10 files changed, 2145 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/transformers/transformer_longcat_image.py create mode 100644 src/diffusers/pipelines/longcat_image/__init__.py create mode 100644 src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py create mode 100644 src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py create mode 100644 src/diffusers/pipelines/longcat_image/pipeline_output.py create mode 100644 src/diffusers/pipelines/longcat_image/system_messages.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fdd27f2464e4..9a7ccf9be151 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -238,6 +238,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", + "LongCatImageTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", @@ -541,6 +542,8 @@ "Lumina2Text2ImgPipeline", "LuminaPipeline", "LuminaText2ImgPipeline", + "LongCatImagePipeline", + "LongCatImageEditPipeline", "MarigoldDepthPipeline", "MarigoldIntrinsicsPipeline", "MarigoldNormalsPipeline", @@ -973,6 +976,7 @@ LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, + LongCatImageTransformer2DModel, MochiTransformer3DModel, ModelMixin, MotionAdapter, @@ -1241,6 +1245,8 @@ LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, + LongCatImagePipeline, + LongCatImageEditPipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..bd9bd01343bb 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -103,6 +103,7 @@ _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] @@ -211,6 +212,7 @@ LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, + LongCatImageTransformer2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, OvisImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716e1..d1d7d0eb99f2 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -35,6 +35,7 @@ from .transformer_kandinsky import 
Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel + from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel @@ -47,4 +48,4 @@ from .transformer_wan import WanTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel - from .transformer_z_image import ZImageTransformer2DModel + from .transformer_z_image import ZImageTransformer2DModel \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py new file mode 100644 index 000000000000..2adf8a7333c9 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -0,0 +1,568 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import is_torch_npu_available, logging +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and 
hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class LongCatImageAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "LongCatImageAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class LongCatImageAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = LongCatImageAttnProcessor + _available_processors = [ + LongCatImageAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim 
if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class LongCatImageSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=LongCatImageAttnProcessor(), + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class LongCatImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=LongCatImageAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # 
Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class LongCatImagePosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + + +class TimestepEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + return timesteps_emb + + +class LongCatImageTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, ): + """ + The Transformer model introduced in Longcat-Image. 
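+
+    Args:
+        patch_size (`int`, defaults to `1`):
+            The patch size used by the output projection. Latents are expected to be pre-packed by the pipeline.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the packed latent input.
+        num_layers (`int`, defaults to `19`):
+            The number of dual-stream `LongCatImageTransformerBlock` blocks.
+        num_single_layers (`int`, defaults to `38`):
+            The number of single-stream `LongCatImageSingleTransformerBlock` blocks.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads.
+        joint_attention_dim (`int`, defaults to `3584`):
+            The dimension of the text encoder hidden states projected by `context_embedder`.
+        pooled_projection_dim (`int`, defaults to `3584`):
+            The dimension of the pooled text projections.
+        axes_dims_rope (`List[int]`, defaults to `[16, 56, 56]`):
+            The dimension of each axis in the rotary positional embeddings.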
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 3584, + axes_dims_rope: List[int] = [16, 56, 56], + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.pooled_projection_dim = pooled_projection_dim + + self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_embed = TimestepEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + LongCatImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + LongCatImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + self.use_checkpoint = [True] * num_layers + self.use_single_checkpoint = [True] * num_single_layers + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = self.time_embed( timestep, hidden_states.dtype ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + if is_torch_npu_available(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]: + encoder_hidden_states,hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 388551f812f8..4df047f67b3f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -291,6 +291,7 @@ _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] + _import_structure["longcat_image"] = ["LongCatImagePipeline","LongCatImageEditPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -722,6 +723,7 @@ from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .longcat_image import LongCatImagePipeline,LongCatImageEditPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/pipelines/longcat_image/__init__.py b/src/diffusers/pipelines/longcat_image/__init__.py new file mode 100644 index 000000000000..b68a72f21d4d --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["LongCatImagePipelineOutput"] + _import_structure["pipeline_longcat_image"] = ["LongCatImagePipeline"] + _import_structure["pipeline_longcat_image_edit"] = ["LongCatImageEditPipeline"] + +if 
TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import LongCatImagePipelineOutput + from .pipeline_longcat_image import LongCatImagePipeline + from .pipeline_longcat_image_edit import LongCatImageEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py new file mode 100644 index 000000000000..16631e7ccadf --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -0,0 +1,693 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import json +import numpy as np +import torch +import re + +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...pipelines.pipeline_utils import DiffusionPipeline + +from ...models.transformers import LongCatImageTransformer2DModel +from .pipeline_output import LongCatImagePipelineOutput +from .system_messages import SYSTEM_PROMPT_EN,SYSTEM_PROMPT_ZH + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LongCatImagePipeline + + >>> pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。' + >>> image = pipe( + ... prompt, + ... height=768, + ... width=1344, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... enable_cfg_renorm=True + ... 
).images[0] + >>> image.save("longcat_image.png") + ``` +""" + + +def get_prompt_language(prompt): + pattern = re.compile(r'[\u4e00-\u9fff]') + if bool(pattern.search(prompt)): + return 'zh' + return 'en' + + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. + + Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." + >>> print(split_quotation(prompt_en)) + >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = 'longcat_$##$_longcat' * (i+1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ('‘', '’'), ('“', '”')] + + pattern = '|'.join([re.escape(q1) + r'[^' + re.escape(q1+q2) + + r']*?' + re.escape(q2) for q1, q2 in quote_pairs]) + + parts = re.split(f'({pattern})', prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids( + modality_id=0, + type='text', + start=(0, 0), + num_token=None, + height=None, + width=None): + if type == 'text': + assert num_token + if height or width: + print( + 'Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == 'image': + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = ( + pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + ) + pos_ids[..., 2] = ( + pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + ) + pos_ids = pos_ids.reshape(height*width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LongCatImagePipeline( + DiffusionPipeline,FromSingleFileMixin +): + r""" + The pipeline for text-to-image generation. 
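+
+    Args:
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.Qwen2_5_VLForConditionalGeneration`]):
+            The Qwen2.5-VL model used to encode the prompt and, when prompt rewriting is enabled, to rewrite it.
+        tokenizer ([`~transformers.Qwen2Tokenizer`]):
+            The tokenizer associated with the text encoder.
+        text_processor ([`~transformers.Qwen2VLProcessor`]):
+            The processor used to prepare chat-template inputs for the prompt-rewriting step.
+        transformer ([`LongCatImageTransformer2DModel`]):
+            Conditional transformer architecture to denoise the encoded image latents.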
+ """ + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [ ] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.prompt_template_encode_prefix = '<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n' + self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n' + self.prompt_template_encode_start_idx = 36 + self.prompt_template_encode_end_idx = 5 + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + + @torch.inference_mode() + def rewire_prompt(self, prompt, device): + language = get_prompt_language(prompt) + if language == 'zh': + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {prompt}\nRewritten prompt:" + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + # Preparation for inference + text = self.text_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = self.text_processor(text=[text],padding=True,return_tensors="pt").to(device) + + self.text_encoder.to(device) + generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = self.text_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + rewrite_prompt= output_text + return rewrite_prompt + + def _encode_prompt( self, prompt, num_images_per_prompt ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(prompt[0]): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)['input_ids'] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids'] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" + ) + all_tokens = all_tokens[:self.tokenizer_max_length] + + text_tokens_and_mask = self.tokenizer.pad( + {'input_ids': [all_tokens]}, + max_length=self.tokenizer_max_length, + padding='max_length', + return_attention_mask=True, + return_tensors='pt') + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)['input_ids'] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)['input_ids'] + prefix_tokens_mask = 
torch.tensor( [1]*len(prefix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) + suffix_tokens_mask = torch.tensor( [1]*len(suffix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) + + prefix_tokens = torch.tensor(prefix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) + + input_ids = torch.cat( (prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1 ) + attention_mask = torch.cat( (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 ) + + input_ids = input_ids.unsqueeze(0).to(self.device) + attention_mask = attention_mask.unsqueeze(0).to(self.device) + + text_output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True + ) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + @torch.inference_mode() + def encode_prompt(self, + prompt : List[str] = None, + num_images_per_prompt: Optional[int] = 1, + prompt_embeds: Optional[torch.Tensor] = None ): + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt ) + + text_ids = prepare_pos_ids(modality_id=0, + type='text', + start=(0, 0), + num_token=prompt_embeds.shape[1]).to(self.device) + return prompt_embeds.to(self.device), text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = prepare_pos_ids(modality_id=1, + type='image', + start=(self.tokenizer_max_length, + self.tokenizer_max_length), + height=height//2, + width=width//2).to(device) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator,device=device) + latents = latents.to(dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + enable_cfg_renorm: Optional[bool] = True, + cfg_renorm_min: Optional[float] = 0.0, + enable_prompt_rewrite: Optional[bool] = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + enable_cfg_renorm: Whether to enable cfg_renorm. Enabling cfg_renorm will improve image quality, + but it may lead to a decrease in the stability of some image outputs.. + cfg_renorm_min: The minimum value of the cfg_renorm_scale range (0-1). 
+ cfg_renorm_min = 1.0, renorm has no effect, while cfg_renorm_min=0.0, the renorm range is larger. + enable_prompt_rewrite: whether to enable prompt rewrite. + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + pixel_step= self.vae_scale_factor * 2 + height = int( height/pixel_step )*pixel_step + width = int( width/pixel_step )*pixel_step + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # from IPython import embed; embed(); + if enable_prompt_rewrite: + prompt = self.rewire_prompt(prompt, device ) + logger.info(f'Rewrite prompt {prompt}!') + print( f'Rewrite prompt {prompt}!') + + negative_prompt = '' if negative_prompt is None else negative_prompt + ( + prompt_embeds, + text_ids + ) = self.encode_prompt( prompt=prompt, + prompt_embeds=prompt_embeds, + num_images_per_prompt= num_images_per_prompt + ) + if self.do_classifier_free_guidance: + ( + negative_prompt_embeds, + negative_text_ids + ) = self.encode_prompt( prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps,num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # if self.do_classifier_free_guidance: + # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if enable_cfg_renorm: + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min= cfg_renorm_min , max=1.0) + noise_pred = noise_pred * scale + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py new file mode 100644 index 000000000000..683b102b76f4 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -0,0 +1,734 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union +import json +import numpy as np +import PIL +import torch +import re + +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...pipelines.pipeline_utils import DiffusionPipeline + +from ...models.transformers import LongCatImageTransformer2DModel +from .pipeline_output import LongCatImagePipelineOutput + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import LongCatImageEditPipeline + + >>> pipe = LongCatImageEditPipeline.from_pretrained("meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = 'change the cat to dog.' + >>> input_image = Image.open('test.jpg').convert("RGB") + >>> image = pipe( + ... input_image, + ... prompt, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... ).images[0] + >>> image.save("longcat_image_edit.png") + ``` +""" + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. + + Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." + >>> print(split_quotation(prompt_en)) + >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = 'longcat_$##$_longcat' * (i+1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ('‘', '’'), ('“', '”')] + + pattern = '|'.join([re.escape(q1) + r'[^' + re.escape(q1+q2) + + r']*?' 
+ re.escape(q2) for q1, q2 in quote_pairs]) + + parts = re.split(f'({pattern})', prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids( + modality_id=0, + type='text', + start=(0, 0), + num_token=None, + height=None, + width=None): + if type == 'text': + assert num_token + if height or width: + print( + 'Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == 'image': + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = ( + pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + ) + pos_ids[..., 2] = ( + pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + ) + pos_ids = pos_ids.reshape(height*width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = width if width % 16 == 0 else (width // 16 + 1) * 16 + height = height if height % 16 == 0 else (height // 16 + 1) * 16 + + width = int(width) + height = int(height) + + return width, height + +class LongCatImageEditPipeline( + DiffusionPipeline,FromSingleFileMixin +): + r""" + The LongCat-Image-Edit pipeline for image editing. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = [ ] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor_vl = text_processor.image_processor + + self.image_token = "<|image_pad|>" + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). 
Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n' + self.prompt_template_encode_start_idx = 67 + self.prompt_template_encode_end_idx = 5 + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def _encode_prompt( self, prompt, image, num_images_per_prompt ): + + raw_vl_input = self.image_processor_vl(images=image,return_tensors="pt") + pixel_values = raw_vl_input['pixel_values'] + image_grid_thw = raw_vl_input['image_grid_thw'] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(prompt[0]): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)['input_ids'] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids'] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" + ) + all_tokens = all_tokens[:self.tokenizer_max_length] + + text_tokens_and_mask = self.tokenizer.pad( + {'input_ids': [all_tokens]}, + max_length=self.tokenizer_max_length, + padding='max_length', + return_attention_mask=True, + return_tensors='pt') + + text = self.prompt_template_encode_prefix + + merge_length = self.image_processor_vl.merge_size**2 + while self.image_token in text: + num_image_tokens = image_grid_thw.prod() // merge_length + text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + text = text.replace("<|placeholder|>", self.image_token) + + prefix_tokens = self.tokenizer(text, add_special_tokens=False)['input_ids'] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)['input_ids'] + prefix_tokens_mask = torch.tensor( [1]*len(prefix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) + suffix_tokens_mask = torch.tensor( [1]*len(suffix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) + + prefix_tokens = torch.tensor(prefix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) + + input_ids = torch.cat( (prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1 ) + attention_mask = torch.cat( (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 ) + + input_ids = input_ids.unsqueeze(0).to(self.device) + attention_mask = attention_mask.unsqueeze(0).to(self.device) + + pixel_values = pixel_values.to(self.device) + image_grid_thw = image_grid_thw.to(self.device) + + text_output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw =image_grid_thw, + output_hidden_states=True + ) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] + + 
_, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + @torch.inference_mode() + def encode_prompt(self, + prompt : List[str] = None, + image: Optional[torch.Tensor] = None, + num_images_per_prompt: Optional[int] = 1, + prompt_embeds: Optional[torch.Tensor] = None,): + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt( prompt, image, num_images_per_prompt ) + + text_ids = prepare_pos_ids(modality_id=0, + type='text', + start=(0, 0), + num_token=prompt_embeds.shape[1]).to(self.device) + return prompt_embeds, text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + prompt_embeds_length, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + image_latents, image_latents_ids = None, None + + if image is not None: + image = image.to(device=self.device, dtype=dtype) + + if image.shape[1] != self.vae.config.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, height, width + ) + + image_latents_ids = prepare_pos_ids(modality_id=2, + type='image', + start=(prompt_embeds_length, + prompt_embeds_length), + height=height//2, + width=width//2).to(device, dtype=torch.float64) + + shape = (batch_size, num_channels_latents, height, width) + latents_ids = prepare_pos_ids(modality_id=1, + type='image', + start=(prompt_embeds_length, + prompt_embeds_length), + height=height//2, + width=width//2).to(device) + + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latents_ids, image_latents_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + image: Optional[PIL.Image.Image] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + image_size = image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0]*1.0/image_size[1]) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = self.image_processor.resize(image, calculated_height//2, calculated_width//2) + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + + negative_prompt = '' if negative_prompt is None else negative_prompt + ( + prompt_embeds, + text_ids + ) = self.encode_prompt( prompt=prompt, + image=prompt_image, + prompt_embeds=prompt_embeds, + num_images_per_prompt= num_images_per_prompt + ) + if self.do_classifier_free_guidance: + ( + negative_prompt_embeds, + negative_text_ids + ) = self.encode_prompt( prompt=negative_prompt, + image=prompt_image, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + calculated_height, + calculated_width, + prompt_embeds.dtype, + prompt_embeds.shape[1], + device, + generator, + latents, + ) + + # 5. 
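+        # The dynamic shift below grows linearly with the packed image sequence length:
+        # e.g. a 1024x1024 edit packs into (1024 // 16) ** 2 = 4096 latent tokens, which
+        # maps `mu` to `max_shift` (1.15), while smaller images move it toward `base_shift` (0.5).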
Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps,num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + if image is not None: + latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0) + else: + latent_image_ids = latents_ids + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + #latent_model_input = torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_text = noise_pred_text[:, :image_seq_len] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_text + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + #Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_output.py b/src/diffusers/pipelines/longcat_image/pipeline_output.py new file mode 100644 index 000000000000..b600f9dd1fe1 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +from diffusers.utils import BaseOutput + + +@dataclass +class LongCatImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/longcat_image/system_messages.py b/src/diffusers/pipelines/longcat_image/system_messages.py new file mode 100644 index 000000000000..77c017cfffdc --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/system_messages.py @@ -0,0 +1,67 @@ + + +SYSTEM_PROMPT_EN = """ +You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all information from the user's original prompt without deleting or distorting any details. +Specific requirements are as follows: +1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as concise as possible. +2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields English output. The rewritten token count should not exceed 512. +3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the original prompt, such as lighting and textures. +4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography style**. If the user specifies a style, retain the user's style. +5. 
When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a giraffe"). +6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% OFF"`). +7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer with the image title 'Grassland'." +8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. +9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**. +Here are examples of rewrites for different types of prompts: +# Examples (Few-Shot Learning) + 1. User Input: An animal with nine lives. + Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home environment with light from the window filtering through curtains, creating a warm light and shadow effect. The shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. + 2. User Input: Create an anime-style tourism flyer with a grassland theme. + Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. + 3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer. + Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. 
The image has high clarity, accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing strong visual appeal. + 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident posture. + Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. The overall style is natural, elegant, and artistic. + 5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. + Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the apple's life cycle. + 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate a four-color rainbow based on this rule. The color order from top to bottom is 3142. 
+ Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing the numerical information. The image is high definition, with accurate colors and a consistent style, offering strong visual appeal. + 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a Chinese garden. + Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the stone tablet and the classical beauty of the garden. +# Output Format +Please directly output the rewritten and optimized Prompt content. Do not include any explanatory language or JSON formatting, and do not add opening or closing quotes yourself.""" + + +SYSTEM_PROMPT_ZH = """ +你是一名文生图模型的prompt engineering专家。由于文生图模型对用户prompt的理解能力有限,你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 +具体要求如下: +1. 改写不能影响用户原始prompt里表达的任何信息,改写后的prompt应该使用连贯的自然语言表达,不要出现低信息量的冗余描述,尽可能保持改写后prompt长度精简。 +2. 请确保输入和输出的语言类型一致,中文输入中文输出,英文输入英文输出,改写后的token数量不要超过512个; +3. 改写后的描述应当进一步完善原始prompt中出现的主体特征、美学技巧,如打光、纹理等; +4. 如果原始prompt没有指定图片风格时,确保改写后的prompt使用真实摄影风格,如果用户指定了图片风格,则保留用户风格; +5. 当原始prompt需要推理才能明确用户意图时,根据世界知识进行适当逻辑推理,将模糊抽象描述转化为具体指向事物(例:将"最高的动物"转化为"一头长颈鹿")。 +6. 当原始prompt需要生成文字时,请使用双引号圈定文字部分,例:`"限时5折"`)。 +7. 当原始prompt需要生成网页、logo、ui、海报等文字场景时,且没有指定具体的文字内容时,需要推断出合适的文字内容,并使用双引号圈定,如用户输入:一个旅游宣传单,以草原为主题。应该改写成:一个旅游宣传单,图片标题为“草原”。 +8. 当原始prompt中存在否定词时,需要确保改写后的prompt不存在否定词,如没有船的湖边,改写后的prompt不能出现船这个词汇。 +9. 除非用户指定生成品牌logo,否则不要增加额外的品牌logo. +10. 除了用户明确要求书写的文字内容外,**禁止增加任何额外的文字内容**。 +以下是针对不同类型prompt改写的示例: + +# Examples (Few-Shot Learning) + 1. 用户输入: 九条命的动物。 + 改写输出: 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 + 2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。 + 改写输出: 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地,草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore Nature's Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛围。 + 3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。 + 改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节 暖心回馈”居中排列,和红色粗体字“买1送1”。海报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的视觉吸引力。 + 4. 
用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。 + 改写输出: 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 + 5. 用户输入:创作一系列图片,展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段:1. 播种,2. 幼苗生长,3. 植物成熟,4. 果实采摘。 + 改写输出:一个4宫格的精美插图,描绘苹果的生长过程,精确清晰地捕捉每个阶段。1.“播种”:特写镜头,一只手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中,土壤的纹理和种子光滑的表面清晰可见。背景是花园的柔焦画面,点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”:一棵幼小的苹果树苗破土而出,嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中,温暖的金光照亮了它。幼苗的纤细结构。3.“植物的成熟”:一棵成熟的苹果树,枝繁叶茂,挂满了嫩绿的叶子和正在萌发的小苹果。背景是一片生机勃勃的果园,湛蓝的天空下,斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象,背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和谐,展现了苹果生命周期的自然之美和发展过程。 + 6. 用户输入: 如果1代表红色,2代表绿色,3代表紫色,4代表黄色,请按照此规则生成四色彩虹。它的颜色顺序从上到下是3142 + 改写输出:图片由四个水平排列的彩色条纹组成,从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放置一个白色数字。最上方的紫色条纹上是数字“3”,其下方红色条纹上是数字“1”,再下方黄色条纹上是数字“4”,最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体,颜色为纯白色,与背景色形成鲜明对比,确保了良好的可读性。条纹的颜色饱和度高,且带有轻微的纹理感,整体排版简洁明了,视觉效果清晰,没有多余的装饰元素,强调了数字信息本身。图片整体清晰度高,色彩准确,风格一致,具有较强的视觉吸引力。 + 7. 用户输入:石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林 + 改写输出:一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。 +# 输出格式 +请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。 +""" From 792746b49ae97c9e7f245aa038d8eff3bfc15049 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:30:41 +0800 Subject: [PATCH 02/20] Update src/diffusers/models/transformers/transformer_longcat_image.py Co-authored-by: YiYi Xu --- .../models/transformers/transformer_longcat_image.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 2adf8a7333c9..617dcc94a477 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -497,10 +497,6 @@ def forward( hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None temb = self.time_embed( timestep, hidden_states.dtype ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) From 898902c509f9c6ffcdb5ba690c91966de73d4a81 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:31:15 +0800 Subject: [PATCH 03/20] Update src/diffusers/models/transformers/transformer_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_longcat_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 617dcc94a477..50e5ff223541 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -382,7 +382,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: -class TimestepEmbeddings(nn.Module): +class 
LongCatImageTimestepEmbeddings(nn.Module): def __init__(self, embedding_dim): super().__init__() From 87dfd6808d7457b8f0ec01454d0d25e38d71615d Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:32:55 +0800 Subject: [PATCH 04/20] Update src/diffusers/models/transformers/transformer_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_longcat_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 50e5ff223541..66c50fcec049 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -428,7 +428,7 @@ def __init__( self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) - self.time_embed = TimestepEmbeddings(embedding_dim=self.inner_dim) + self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) From 600cd5ce41c1ba209b20fe6a386a5cc89b9732d6 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:33:12 +0800 Subject: [PATCH 05/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- .../longcat_image/pipeline_longcat_image.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 16631e7ccadf..ffda84663d53 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -390,34 +390,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. 
- """ - self.vae.disable_tiling() @property def guidance_scale(self): From a8d35f4d6eb0e77d23eebe11036888a8314650b0 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:34:05 +0800 Subject: [PATCH 06/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index ffda84663d53..aadef7477e7b 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -350,7 +350,6 @@ def _encode_prompt( self, prompt, num_images_per_prompt ): return prompt_embeds - @torch.inference_mode() def encode_prompt(self, prompt : List[str] = None, num_images_per_prompt: Optional[int] = 1, From d95fb05fc5a53e3ff34fb63cb25aeccaf479fca6 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:34:13 +0800 Subject: [PATCH 07/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index aadef7477e7b..dad33b98434d 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -260,7 +260,6 @@ def __init__( self.tokenizer_max_length = 512 - @torch.inference_mode() def rewire_prompt(self, prompt, device): language = get_prompt_language(prompt) if language == 'zh': From 40714c95313d20d4dd5dd10184dfa562e4e346d6 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:34:59 +0800 Subject: [PATCH 08/20] Update src/diffusers/models/transformers/transformer_longcat_image.py Co-authored-by: YiYi Xu --- .../models/transformers/transformer_longcat_image.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 66c50fcec049..165808aa2c21 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -501,18 +501,6 @@ def forward( temb = self.time_embed( timestep, hidden_states.dtype ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." 
- "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) if is_torch_npu_available(): From 0c6ee770e2e508c375d553c1a5a40019b74e8b60 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:36:33 +0800 Subject: [PATCH 09/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index dad33b98434d..f78b05da4744 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -290,7 +290,7 @@ def rewire_prompt(self, prompt, device): rewrite_prompt= output_text return rewrite_prompt - def _encode_prompt( self, prompt, num_images_per_prompt ): + def _encode_prompt( self, prompt ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) all_tokens = [] From 32032b305ee097d99444e2b0eeaa0d12076c1bfe Mon Sep 17 00:00:00 2001 From: wujunqiang03 <1540970049@qq.com> Date: Fri, 12 Dec 2025 21:03:59 +0800 Subject: [PATCH 10/20] fix code --- .../longcat_image/pipeline_longcat_image.py | 20 +++++++--------- .../pipeline_longcat_image_edit.py | 24 ++++++++----------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index f78b05da4744..2f783f00954d 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -291,8 +291,6 @@ def rewire_prompt(self, prompt, device): return rewrite_prompt def _encode_prompt( self, prompt ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) all_tokens = [] for clean_prompt_sub, matched in split_quotation(prompt[0]): if matched: @@ -341,23 +339,23 @@ def _encode_prompt( self, prompt ): prompt_embeds = text_output.hidden_states[-1].detach() prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - return prompt_embeds def encode_prompt(self, prompt : List[str] = None, num_images_per_prompt: Optional[int] = 1, prompt_embeds: Optional[torch.Tensor] = None ): - + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) # If prompt_embeds is provided and prompt is None, skip encoding if prompt_embeds is None: - prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt ) - + prompt_embeds = self._encode_prompt( prompt ) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + text_ids = prepare_pos_ids(modality_id=0, type='text', start=(0, 0), 
diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py index 683b102b76f4..a4df9a8fb2cf 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -279,14 +279,10 @@ def __init__( self.default_sample_size = 128 self.tokenizer_max_length = 512 - def _encode_prompt( self, prompt, image, num_images_per_prompt ): - + def _encode_prompt( self, prompt, image ): raw_vl_input = self.image_processor_vl(images=image,return_tensors="pt") pixel_values = raw_vl_input['pixel_values'] image_grid_thw = raw_vl_input['image_grid_thw'] - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) all_tokens = [] for clean_prompt_sub, matched in split_quotation(prompt[0]): if matched: @@ -348,12 +344,6 @@ def _encode_prompt( self, prompt, image, num_images_per_prompt ): prompt_embeds = text_output.hidden_states[-1].detach() prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - return prompt_embeds @torch.inference_mode() @@ -361,12 +351,18 @@ def encode_prompt(self, prompt : List[str] = None, image: Optional[torch.Tensor] = None, num_images_per_prompt: Optional[int] = 1, - prompt_embeds: Optional[torch.Tensor] = None,): - + prompt_embeds: Optional[torch.Tensor] = None): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) # If prompt_embeds is provided and prompt is None, skip encoding if prompt_embeds is None: - prompt_embeds = self._encode_prompt( prompt, image, num_images_per_prompt ) + prompt_embeds = self._encode_prompt( prompt, image ) + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + text_ids = prepare_pos_ids(modality_id=0, type='text', start=(0, 0), From dbdbd01db868c9639aefdc17b93766e9ead81998 Mon Sep 17 00:00:00 2001 From: wujunqiang03 <1540970049@qq.com> Date: Fri, 12 Dec 2025 21:45:50 +0800 Subject: [PATCH 11/20] add doc --- .../api/models/longcat_image_transformer2d.md | 25 ++++ docs/source/en/api/pipelines/longcat_image.md | 114 ++++++++++++++++++ tests/pipelines/longcat_image/__init__.py | 0 3 files changed, 139 insertions(+) create mode 100644 docs/source/en/api/models/longcat_image_transformer2d.md create mode 100644 docs/source/en/api/pipelines/longcat_image.md create mode 100644 tests/pipelines/longcat_image/__init__.py diff --git a/docs/source/en/api/models/longcat_image_transformer2d.md b/docs/source/en/api/models/longcat_image_transformer2d.md new file mode 100644 index 000000000000..f40b2583e68b --- /dev/null +++ b/docs/source/en/api/models/longcat_image_transformer2d.md @@ -0,0 +1,25 @@ + + +# LongCatImageTransformer2DModel + +The model can be loaded with the following code snippet. 
+ +```python +from diffusers import LongCatImageTransformer2DModel + +transformer = LongCatImageTransformer2DModel.from_pretrained("meituan-longcat/LongCat-Image ", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## LongCatImageTransformer2DModel + +[[autodoc]] LongCatImageTransformer2DModel \ No newline at end of file diff --git a/docs/source/en/api/pipelines/longcat_image.md b/docs/source/en/api/pipelines/longcat_image.md new file mode 100644 index 000000000000..a7e8a7a3712e --- /dev/null +++ b/docs/source/en/api/pipelines/longcat_image.md @@ -0,0 +1,114 @@ + + +# LongCat-Image + +
+ LoRA +
+ + +We introduce LongCat-Image, a pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models. + + +### Key Features +- 🌟 **Exceptional Efficiency and Performance**: With only **6B parameters**, LongCat-Image surpasses numerous open-source models that are several times larger across multiple benchmarks, demonstrating the immense potential of efficient model design. +- 🌟 **Superior Editing Performance**: LongCat-Image-Edit model achieves state-of-the-art performance among open-source models, delivering leading instruction-following and image quality with superior visual consistency. +- 🌟 **Powerful Chinese Text Rendering**: LongCat-Image demonstrates superior accuracy and stability in rendering common Chinese characters compared to existing SOTA open-source models and achieves industry-leading coverage of the Chinese dictionary. +- 🌟 **Remarkable Photorealism**: Through an innovative data strategy and training framework, LongCat-Image achieves remarkable photorealism in generated images. +- 🌟 **Comprehensive Open-Source Ecosystem**: We provide a complete toolchain, from intermediate checkpoints to full training code, significantly lowering the barrier for further research and development. + +For more details, please refer to the comprehensive [***LongCat-Image Technical Report***](https://arxiv.org/abs/2412.11963) + + +## Usage Example + +```py +import torch +import diffusers +from diffusers import LongCatImagePipeline + +weight_dtype = torch.bfloat16 +pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16 ) +pipe.to('cuda') +# pipe.enable_model_cpu_offload() + +prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。' +image = pipe( + prompt, + height=768, + width=1344, + guidance_scale=4.0, + num_inference_steps=50, + num_images_per_prompt=1, + generator=torch.Generator("cpu").manual_seed(43), + enable_cfg_renorm=True, + enable_prompt_rewrite=True, +).images[0] +image.save(f'./longcat_image_t2i_example.png') +``` + + +This pipeline was contributed by LongCat-Image Team. The original codebase can be found [here](https://github.com/meituan-longcat/LongCat-Image). + +Available models: +
+
+| Models | Type | Description | Download Link |
+| --- | --- | --- | --- |
+| LongCat‑Image | Text‑to‑Image | Final Release. The standard model for out‑of‑the‑box inference. | 🤗 Huggingface |
+| LongCat‑Image‑Dev | Text‑to‑Image | Development. Mid-training checkpoint, suitable for fine-tuning. | 🤗 Huggingface |
+| LongCat‑Image‑Edit | Image Editing | Specialized model for image editing. | 🤗 Huggingface |
+
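+The editing checkpoint is used through `LongCatImageEditPipeline`; the snippet below mirrors the example in the pipeline's docstring:
+
+```py
+import torch
+from PIL import Image
+from diffusers import LongCatImageEditPipeline
+
+pipe = LongCatImageEditPipeline.from_pretrained("meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "change the cat to dog."
+input_image = Image.open("test.jpg").convert("RGB")
+image = pipe(
+    input_image,
+    prompt,
+    num_inference_steps=50,
+    guidance_scale=4.5,
+    generator=torch.Generator("cpu").manual_seed(43),
+).images[0]
+image.save("longcat_image_edit.png")
+```
+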
+ +## LongCatImagePipeline + +[[autodoc]] LongCatImagePipeline +- all +- __call__ + +## LongCatImagePipelineOutput + +[[autodoc]] pipelines.longcat_image.pipeline_output.LongCatImagePipelineOutput + + + diff --git a/tests/pipelines/longcat_image/__init__.py b/tests/pipelines/longcat_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 110698bc17f765951d8cda55e3bf4a23a2ef8562 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Sat, 13 Dec 2025 03:41:07 +0800 Subject: [PATCH 12/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py Co-authored-by: YiYi Xu --- .../pipeline_longcat_image_edit.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py index a4df9a8fb2cf..e47f0b28d53c 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -406,34 +406,6 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. 
- """ - self.vae.disable_tiling() @property def guidance_scale(self): From 833f27530f7c86dc3caff00186f02ea6f3dd8ce5 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Sat, 13 Dec 2025 03:41:24 +0800 Subject: [PATCH 13/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py Co-authored-by: YiYi Xu --- .../pipelines/longcat_image/pipeline_longcat_image_edit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py index e47f0b28d53c..50083f6ef71f 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -346,7 +346,6 @@ def _encode_prompt( self, prompt, image ): return prompt_embeds - @torch.inference_mode() def encode_prompt(self, prompt : List[str] = None, image: Optional[torch.Tensor] = None, From b5858c464e9655d8440a9e46ae1be7559275d1c4 Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Sat, 13 Dec 2025 03:45:22 +0800 Subject: [PATCH 14/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 2f783f00954d..09658dd58e13 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -519,7 +519,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # from IPython import embed; embed(); if enable_prompt_rewrite: prompt = self.rewire_prompt(prompt, device ) logger.info(f'Rewrite prompt {prompt}!') From fcfdda951aa9443619d9ce74c31b5105d358399c Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Sat, 13 Dec 2025 03:45:57 +0800 Subject: [PATCH 15/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 09658dd58e13..9e783cd66531 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -254,8 +254,6 @@ def __init__( self.prompt_template_encode_prefix = '<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n' self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n' - self.prompt_template_encode_start_idx = 36 - self.prompt_template_encode_end_idx = 5 self.default_sample_size = 128 self.tokenizer_max_length = 512 From 5e2f36c5a22108a814c118bc0503262e88aec4cd Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Sat, 13 Dec 2025 03:46:59 +0800 Subject: [PATCH 16/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 2 ++ 1 
file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 9e783cd66531..1a224fa19400 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -334,6 +334,8 @@ def _encode_prompt( self, prompt ): ) # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] # clone to have a contiguous tensor + self.prompt_template_encode_start_idx = len(prefix_tokens) + self.prompt_template_encode_end_idx = len(suffix_tokens) prompt_embeds = text_output.hidden_states[-1].detach() prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] From 9f8103570d5072dcd61cdb6f3a2abac718e8c1ed Mon Sep 17 00:00:00 2001 From: junqiangwu <32137981+junqiangwu@users.noreply.github.com> Date: Sat, 13 Dec 2025 03:48:43 +0800 Subject: [PATCH 17/20] Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 1a224fa19400..b64c09197afb 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -302,7 +302,7 @@ def _encode_prompt( self, prompt ): if len(all_tokens) > self.tokenizer_max_length: logger.warning( "Your input was truncated because `max_sequence_length` is set to " - f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" + f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}" ) all_tokens = all_tokens[:self.tokenizer_max_length] From 4e92dc6bbe9c3734a7c5a5422e82547d6741aec0 Mon Sep 17 00:00:00 2001 From: wujunqiang03 <1540970049@qq.com> Date: Sat, 13 Dec 2025 06:06:06 +0800 Subject: [PATCH 18/20] fix code & mask style & fix-copies --- docs/source/en/_toctree.yml | 4 + src/diffusers/__init__.py | 6 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/transformers/__init__.py | 4 +- .../transformers/transformer_longcat_image.py | 24 +- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/longcat_image/__init__.py | 2 +- .../longcat_image/pipeline_longcat_image.py | 352 +++++++++--------- .../pipeline_longcat_image_edit.py | 292 ++++++++------- .../longcat_image/pipeline_output.py | 1 + .../longcat_image/system_messages.py | 2 - src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 30 ++ 13 files changed, 406 insertions(+), 332 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c176f8786f38..3012d253b51b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -373,6 +373,8 @@ title: LuminaNextDiT2DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel + - local: api/models/longcat_image_transformer2d + title: LongCatImageTransformer2DModel - local: api/models/omnigen_transformer title: OmniGenTransformer2DModel - local: api/models/ovisimage_transformer2d @@ -567,6 +569,8 @@ title: Lumina 2.0 - local: api/pipelines/lumina title: Lumina-T2X + - local: api/pipelines/longcat_image + title: LongCat-Image - local: api/pipelines/marigold title: Marigold - local: 
api/pipelines/panorama diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 02c82edc6a67..9754737b540f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -973,10 +973,10 @@ Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LongCatImageTransformer2DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, - LongCatImageTransformer2DModel, MochiTransformer3DModel, ModelMixin, MotionAdapter, @@ -1241,12 +1241,12 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LongCatImageEditPipeline, + LongCatImagePipeline, LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, - LongCatImagePipeline, - LongCatImageEditPipeline, LucyEditPipeline, Lumina2Pipeline, Lumina2Text2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index bd9bd01343bb..b4eed931fc0c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -209,10 +209,10 @@ HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LongCatImageTransformer2DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, - LongCatImageTransformer2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, OvisImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d1d7d0eb99f2..40b5d4a0dfc9 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -33,9 +33,9 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel + from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel - from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel @@ -48,4 +48,4 @@ from .transformer_wan import WanTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel - from .transformer_z_image import ZImageTransformer2DModel \ No newline at end of file + from .transformer_z_image import ZImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 165808aa2c21..1ef9ebd4a102 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -24,14 +23,14 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import is_torch_npu_available, logging from ...utils.torch_utils import maybe_allow_in_graph -from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention import AttentionModuleMixin, 
FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -381,7 +380,6 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin - class LongCatImageTimestepEmbeddings(nn.Module): def __init__(self, embedding_dim): super().__init__() @@ -394,14 +392,15 @@ def forward(self, timestep, hidden_dtype): timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) return timesteps_emb - + class LongCatImageTransformer2DModel( ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, - CacheMixin, ): + CacheMixin, +): """ The Transformer model introduced in Longcat-Image. """ @@ -455,10 +454,8 @@ def __init__( ] ) - self.norm_out = AdaLayerNormContinuous( - self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear( - self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False self.use_checkpoint = [True] * num_layers @@ -498,10 +495,9 @@ def forward( timestep = timestep.to(hidden_states.dtype) * 1000 - temb = self.time_embed( timestep, hidden_states.dtype ) + temb = self.time_embed(timestep, hidden_states.dtype) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - ids = torch.cat((txt_ids, img_ids), dim=0) if is_torch_npu_available(): freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) @@ -528,7 +524,7 @@ def forward( for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]: - encoder_hidden_states,hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 4df047f67b3f..ff5cd829ce8b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -291,7 +291,7 @@ _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] - _import_structure["longcat_image"] = ["LongCatImagePipeline","LongCatImageEditPipeline"] + _import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -719,11 +719,11 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import 
Lumina2Pipeline, Lumina2Text2ImgPipeline - from .longcat_image import LongCatImagePipeline,LongCatImageEditPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/pipelines/longcat_image/__init__.py b/src/diffusers/pipelines/longcat_image/__init__.py index b68a72f21d4d..cd95e96ece2f 100644 --- a/src/diffusers/pipelines/longcat_image/__init__.py +++ b/src/diffusers/pipelines/longcat_image/__init__.py @@ -33,9 +33,9 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_output import LongCatImagePipelineOutput from .pipeline_longcat_image import LongCatImagePipeline from .pipeline_longcat_image_edit import LongCatImageEditPipeline + from .pipeline_output import LongCatImagePipelineOutput else: import sys diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index b64c09197afb..31391cddf076 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -12,28 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Any, Callable, Dict, List, Optional, Union -import json -import numpy as np -import torch import re +from typing import Any, Dict, List, Optional, Union +import numpy as np +import torch from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from ...pipelines.pipeline_utils import DiffusionPipeline - -from ...models.transformers import LongCatImageTransformer2DModel from .pipeline_output import LongCatImagePipelineOutput -from .system_messages import SYSTEM_PROMPT_EN,SYSTEM_PROMPT_ZH +from .system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH + if is_torch_xla_available(): import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -66,17 +66,15 @@ def get_prompt_language(prompt): - pattern = re.compile(r'[\u4e00-\u9fff]') + pattern = re.compile(r"[\u4e00-\u9fff]") if bool(pattern.search(prompt)): - return 'zh' - return 'en' - + return "zh" + return "en" def split_quotation(prompt, quote_pairs=None): """ - Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. - + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. Examples:: >>> prompt_en = "Please write 'Hello' on the blackboard for me." 
>>> print(split_quotation(prompt_en)) @@ -87,17 +85,14 @@ def split_quotation(prompt, quote_pairs=None): mapping_word_internal_quote = [] for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): - word_tgt = 'longcat_$##$_longcat' * (i+1) + word_tgt = "longcat_$##$_longcat" * (i + 1) prompt = prompt.replace(word_src, word_tgt) mapping_word_internal_quote.append([word_src, word_tgt]) if quote_pairs is None: - quote_pairs = [("'", "'"), ('"', '"'), ('‘', '’'), ('“', '”')] - - pattern = '|'.join([re.escape(q1) + r'[^' + re.escape(q1+q2) + - r']*?' + re.escape(q2) for q1, q2 in quote_pairs]) - - parts = re.split(f'({pattern})', prompt) + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) result = [] for part in parts: @@ -112,35 +107,24 @@ def split_quotation(prompt, quote_pairs=None): return result -def prepare_pos_ids( - modality_id=0, - type='text', - start=(0, 0), - num_token=None, - height=None, - width=None): - if type == 'text': +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": assert num_token if height or width: - print( - 'Warning: The parameters of height and width will be ignored in "text" type.') + print('Warning: The parameters of height and width will be ignored in "text" type.') pos_ids = torch.zeros(num_token, 3) pos_ids[..., 0] = modality_id pos_ids[..., 1] = torch.arange(num_token) + start[0] pos_ids[..., 2] = torch.arange(num_token) + start[1] - elif type == 'image': + elif type == "image": assert height and width if num_token: print('Warning: The parameter of num_token will be ignored in "image" type.') pos_ids = torch.zeros(height, width, 3) pos_ids[..., 0] = modality_id - pos_ids[..., 1] = ( - pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] - ) - pos_ids[..., 2] = ( - pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] - ) - pos_ids = pos_ids.reshape(height*width, 3) + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) else: raise KeyError(f'Unknow type {type}, only support "text" or "image".') return pos_ids @@ -219,14 +203,13 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class LongCatImagePipeline( - DiffusionPipeline,FromSingleFileMixin -): +class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin): r""" The pipeline for text-to-image generation. 
""" + model_cpu_offload_seq = "text_encoder->transformer->vae" - _optional_components = [ ] + _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -246,121 +229,130 @@ def __init__( tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, - text_processor=text_processor + text_processor=text_processor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) - - self.prompt_template_encode_prefix = '<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n' - self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n' + + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" self.default_sample_size = 128 self.tokenizer_max_length = 512 - def rewire_prompt(self, prompt, device): - language = get_prompt_language(prompt) - if language == 'zh': - question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{prompt}\n改写后的prompt为:" - else: - question = SYSTEM_PROMPT_EN + f"\nUser Input: {prompt}\nRewritten prompt:" - - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": question}, - ], - } - ] - # Preparation for inference - text = self.text_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - inputs = self.text_processor(text=[text],padding=True,return_tensors="pt").to(device) + all_text = [] + for each_prompt in prompt: + language = get_prompt_language(each_prompt) + if language == "zh": + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:" + message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + # Preparation for inference + text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + all_text.append(text) + + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device) self.text_encoder.to(device) generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) - generated_ids_trimmed = [ - out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) - ] + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] output_text = self.text_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0] - rewrite_prompt= output_text + ) + rewrite_prompt = output_text return rewrite_prompt - - def _encode_prompt( self, prompt ): - all_tokens = [] - for clean_prompt_sub, matched in split_quotation(prompt[0]): - if matched: - for sub_word in clean_prompt_sub: - tokens = self.tokenizer(sub_word, add_special_tokens=False)['input_ids'] + + def _encode_prompt(self, prompt: List[str]): + batch_all_tokens = [] + + for each_prompt in prompt: + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(each_prompt): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, 
add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] all_tokens.extend(tokens) - else: - tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids'] - all_tokens.extend(tokens) - - if len(all_tokens) > self.tokenizer_max_length: - logger.warning( - "Your input was truncated because `max_sequence_length` is set to " - f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}" - ) - all_tokens = all_tokens[:self.tokenizer_max_length] - + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + batch_all_tokens.append(all_tokens) + text_tokens_and_mask = self.tokenizer.pad( - {'input_ids': [all_tokens]}, + {"input_ids": batch_all_tokens}, max_length=self.tokenizer_max_length, - padding='max_length', + padding="max_length", return_attention_mask=True, - return_tensors='pt') - - prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)['input_ids'] - suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)['input_ids'] - prefix_tokens_mask = torch.tensor( [1]*len(prefix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) - suffix_tokens_mask = torch.tensor( [1]*len(suffix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) - - prefix_tokens = torch.tensor(prefix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) - suffix_tokens = torch.tensor(suffix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) - - input_ids = torch.cat( (prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1 ) - attention_mask = torch.cat( (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 ) - - input_ids = input_ids.unsqueeze(0).to(self.device) - attention_mask = attention_mask.unsqueeze(0).to(self.device) - - text_output = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True + return_tensors="pt", ) + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + batch_size = text_tokens_and_mask.input_ids.size(0) + + prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1) + suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + + input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1) + attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1) + + input_ids = 
input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] # clone to have a contiguous tensor - self.prompt_template_encode_start_idx = len(prefix_tokens) - self.prompt_template_encode_end_idx = len(suffix_tokens) prompt_embeds = text_output.hidden_states[-1].detach() - prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] - + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] return prompt_embeds - def encode_prompt(self, - prompt : List[str] = None, - num_images_per_prompt: Optional[int] = 1, - prompt_embeds: Optional[torch.Tensor] = None ): + def encode_prompt( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: Optional[int] = 1, + prompt_embeds: Optional[torch.Tensor] = None, + ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) # If prompt_embeds is provided and prompt is None, skip encoding if prompt_embeds is None: - prompt_embeds = self._encode_prompt( prompt ) + prompt_embeds = self._encode_prompt(prompt) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - text_ids = prepare_pos_ids(modality_id=0, - type='text', - start=(0, 0), - num_token=prompt_embeds.shape[1]).to(self.device) - return prompt_embeds.to(self.device), text_ids + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds.to(self.device), text_ids @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -386,11 +378,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents - - @property - def guidance_scale(self): - return self._guidance_scale - @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @@ -412,13 +399,14 @@ def prepare_latents( width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = prepare_pos_ids(modality_id=1, - type='image', - start=(self.tokenizer_max_length, - self.tokenizer_max_length), - height=height//2, - width=width//2).to(device) - + latent_image_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(self.tokenizer_max_length, self.tokenizer_max_length), + height=height // 2, + width=width // 2, + ).to(device) + if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids @@ -428,7 +416,7 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - latents = randn_tensor(shape, generator=generator,device=device) + latents = randn_tensor(shape, generator=generator, device=device) latents = latents.to(dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) @@ -454,6 +442,32 @@ def current_timestep(self): def interrupt(self): return self._interrupt + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + @replace_example_docstring(EXAMPLE_DOC_STRING) @torch.no_grad() def __call__( @@ -466,8 +480,7 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, - List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -497,14 +510,17 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor - if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: - logger.warning( - f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" - ) - pixel_step= self.vae_scale_factor * 2 - height = int( height/pixel_step )*pixel_step - width = int( width/pixel_step )*pixel_step - + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._current_timestep = None @@ -517,30 +533,23 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - + device = self._execution_device if enable_prompt_rewrite: - prompt = self.rewire_prompt(prompt, device ) - logger.info(f'Rewrite prompt {prompt}!') - print( f'Rewrite prompt {prompt}!') - - negative_prompt = '' if negative_prompt is None else negative_prompt - ( - prompt_embeds, - text_ids - ) = self.encode_prompt( prompt=prompt, - prompt_embeds=prompt_embeds, - num_images_per_prompt= num_images_per_prompt + prompt = self.rewire_prompt(prompt, device) + logger.info(f"Rewrite prompt {prompt}!") + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt ) if self.do_classifier_free_guidance: - ( - negative_prompt_embeds, - negative_text_ids - ) = self.encode_prompt( prompt=negative_prompt, - prompt_embeds=negative_prompt_embeds, - num_images_per_prompt=num_images_per_prompt + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, ) - + # 4. Prepare latent variables num_channels_latents = 16 latents, latent_image_ids = self.prepare_latents( @@ -555,7 +564,7 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1.0 / num_inference_steps,num_inference_steps) if sigmas is None else sigmas + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -580,9 +589,6 @@ def __call__( if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} - # if self.do_classifier_free_guidance: - # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -612,15 +618,15 @@ def __call__( return_dict=False, )[0] noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + if enable_cfg_renorm: cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - scale = (cond_norm / (noise_norm + 1e-8)).clamp(min= cfg_renorm_min , max=1.0) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) noise_pred = noise_pred * scale else: noise_pred = noise_pred_text - + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py index 50083f6ef71f..616cb1eb15fc 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -13,28 +13,28 @@ # limitations under the License. 
import inspect import math -from typing import Any, Callable, Dict, List, Optional, Union -import json +import re +from typing import Any, Dict, List, Optional, Union + import numpy as np import PIL import torch -import re - from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from ...pipelines.pipeline_utils import DiffusionPipeline - -from ...models.transformers import LongCatImageTransformer2DModel from .pipeline_output import LongCatImagePipelineOutput + if is_torch_xla_available(): import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -65,10 +65,11 @@ ``` """ + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation def split_quotation(prompt, quote_pairs=None): """ - Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. - + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. Examples:: >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) @@ -79,17 +80,14 @@ def split_quotation(prompt, quote_pairs=None): mapping_word_internal_quote = [] for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): - word_tgt = 'longcat_$##$_longcat' * (i+1) + word_tgt = "longcat_$##$_longcat" * (i + 1) prompt = prompt.replace(word_src, word_tgt) mapping_word_internal_quote.append([word_src, word_tgt]) if quote_pairs is None: - quote_pairs = [("'", "'"), ('"', '"'), ('‘', '’'), ('“', '”')] - - pattern = '|'.join([re.escape(q1) + r'[^' + re.escape(q1+q2) + - r']*?' + re.escape(q2) for q1, q2 in quote_pairs]) - - parts = re.split(f'({pattern})', prompt) + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" 
+ re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) result = [] for part in parts: @@ -104,41 +102,31 @@ def split_quotation(prompt, quote_pairs=None): return result -def prepare_pos_ids( - modality_id=0, - type='text', - start=(0, 0), - num_token=None, - height=None, - width=None): - if type == 'text': +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.prepare_pos_ids +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": assert num_token if height or width: - print( - 'Warning: The parameters of height and width will be ignored in "text" type.') + print('Warning: The parameters of height and width will be ignored in "text" type.') pos_ids = torch.zeros(num_token, 3) pos_ids[..., 0] = modality_id pos_ids[..., 1] = torch.arange(num_token) + start[0] pos_ids[..., 2] = torch.arange(num_token) + start[1] - elif type == 'image': + elif type == "image": assert height and width if num_token: print('Warning: The parameter of num_token will be ignored in "image" type.') pos_ids = torch.zeros(height, width, 3) pos_ids[..., 0] = modality_id - pos_ids[..., 1] = ( - pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] - ) - pos_ids[..., 2] = ( - pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] - ) - pos_ids = pos_ids.reshape(height*width, 3) + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) else: raise KeyError(f'Unknow type {type}, only support "text" or "image".') return pos_ids - +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -211,6 +199,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -224,6 +213,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + def calculate_dimensions(target_area, ratio): width = math.sqrt(target_area * ratio) height = width / ratio @@ -236,15 +226,14 @@ def calculate_dimensions(target_area, ratio): return width, height -class LongCatImageEditPipeline( - DiffusionPipeline,FromSingleFileMixin -): + +class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin): r""" The LongCat-Image-Edit pipeline for image editing. """ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" - _optional_components = [ ] + _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -264,67 +253,71 @@ def __init__( tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, - text_processor=text_processor + text_processor=text_processor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.image_processor_vl = text_processor.image_processor - + self.image_token = "<|image_pad|>" self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). 
Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" - self.prompt_template_encode_suffix = '<|im_end|>\n<|im_start|>assistant\n' - self.prompt_template_encode_start_idx = 67 - self.prompt_template_encode_end_idx = 5 + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" self.default_sample_size = 128 self.tokenizer_max_length = 512 - def _encode_prompt( self, prompt, image ): - raw_vl_input = self.image_processor_vl(images=image,return_tensors="pt") - pixel_values = raw_vl_input['pixel_values'] - image_grid_thw = raw_vl_input['image_grid_thw'] + def _encode_prompt(self, prompt, image): + raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt") + pixel_values = raw_vl_input["pixel_values"] + image_grid_thw = raw_vl_input["image_grid_thw"] all_tokens = [] for clean_prompt_sub, matched in split_quotation(prompt[0]): if matched: for sub_word in clean_prompt_sub: - tokens = self.tokenizer(sub_word, add_special_tokens=False)['input_ids'] + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] all_tokens.extend(tokens) else: - tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)['input_ids'] + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] all_tokens.extend(tokens) - + if len(all_tokens) > self.tokenizer_max_length: logger.warning( "Your input was truncated because `max_sequence_length` is set to " f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" ) - all_tokens = all_tokens[:self.tokenizer_max_length] - + all_tokens = all_tokens[: self.tokenizer_max_length] + text_tokens_and_mask = self.tokenizer.pad( - {'input_ids': [all_tokens]}, + {"input_ids": [all_tokens]}, max_length=self.tokenizer_max_length, - padding='max_length', + padding="max_length", return_attention_mask=True, - return_tensors='pt') - + return_tensors="pt", + ) + text = self.prompt_template_encode_prefix - + merge_length = self.image_processor_vl.merge_size**2 while self.image_token in text: num_image_tokens = image_grid_thw.prod() // merge_length text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text = text.replace("<|placeholder|>", self.image_token) - prefix_tokens = self.tokenizer(text, add_special_tokens=False)['input_ids'] - suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)['input_ids'] - prefix_tokens_mask = torch.tensor( [1]*len(prefix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) - suffix_tokens_mask = torch.tensor( [1]*len(suffix_tokens), dtype = text_tokens_and_mask.attention_mask[0].dtype ) + prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) - prefix_tokens = torch.tensor(prefix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) - suffix_tokens = torch.tensor(suffix_tokens,dtype=text_tokens_and_mask.input_ids.dtype) - - input_ids = 
torch.cat( (prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1 ) - attention_mask = torch.cat( (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 ) + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1) + attention_mask = torch.cat( + (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 + ) input_ids = input_ids.unsqueeze(0).to(self.device) attention_mask = attention_mask.unsqueeze(0).to(self.device) @@ -336,37 +329,37 @@ def _encode_prompt( self, prompt, image ): input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, - image_grid_thw =image_grid_thw, - output_hidden_states=True + image_grid_thw=image_grid_thw, + output_hidden_states=True, ) # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] # clone to have a contiguous tensor prompt_embeds = text_output.hidden_states[-1].detach() - prompt_embeds = prompt_embeds[:,self.prompt_template_encode_start_idx: -self.prompt_template_encode_end_idx ,:] - + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] return prompt_embeds - def encode_prompt(self, - prompt : List[str] = None, - image: Optional[torch.Tensor] = None, - num_images_per_prompt: Optional[int] = 1, - prompt_embeds: Optional[torch.Tensor] = None): + def encode_prompt( + self, + prompt: List[str] = None, + image: Optional[torch.Tensor] = None, + num_images_per_prompt: Optional[int] = 1, + prompt_embeds: Optional[torch.Tensor] = None, + ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) # If prompt_embeds is provided and prompt is None, skip encoding if prompt_embeds is None: - prompt_embeds = self._encode_prompt( prompt, image ) - + prompt_embeds = self._encode_prompt(prompt, image) + _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - text_ids = prepare_pos_ids(modality_id=0, - type='text', - start=(0, 0), - num_token=prompt_embeds.shape[1]).to(self.device) - return prompt_embeds, text_ids + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds, text_ids @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -391,7 +384,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents - + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ @@ -405,11 +398,6 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): return image_latents - - @property - def guidance_scale(self): - return self._guidance_scale - @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @@ -451,25 +439,24 @@ def prepare_latents( else: image_latents = torch.cat([image_latents], dim=0) - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, height, width - ) + image_latents = 
self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) - image_latents_ids = prepare_pos_ids(modality_id=2, - type='image', - start=(prompt_embeds_length, - prompt_embeds_length), - height=height//2, - width=width//2).to(device, dtype=torch.float64) + image_latents_ids = prepare_pos_ids( + modality_id=2, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device, dtype=torch.float64) shape = (batch_size, num_channels_latents, height, width) - latents_ids = prepare_pos_ids(modality_id=1, - type='image', - start=(prompt_embeds_length, - prompt_embeds_length), - height=height//2, - width=width//2).to(device) - + latents_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -505,6 +492,39 @@ def current_timestep(self): def interrupt(self): return self._interrupt + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None: + if isinstance(prompt, str): + pass + elif isinstance(prompt, list) and len(prompt) == 1: + pass + else: + raise ValueError( + f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + @replace_example_docstring(EXAMPLE_DOC_STRING) @torch.no_grad() def __call__( @@ -516,8 +536,7 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, - List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, @@ -536,9 +555,19 @@ def __call__( images. """ - image_size = image.size - calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0]*1.0/image_size[1]) - + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + calculated_height, + calculated_width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._current_timestep = None @@ -551,30 +580,25 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - - device = self._execution_device - image = self.image_processor.resize(image, calculated_height, calculated_width) - prompt_image = self.image_processor.resize(image, calculated_height//2, calculated_width//2) - image = self.image_processor.preprocess(image, calculated_height, calculated_width) - - negative_prompt = '' if negative_prompt is None else negative_prompt - ( - prompt_embeds, - text_ids - ) = self.encode_prompt( prompt=prompt, - image=prompt_image, - prompt_embeds=prompt_embeds, - num_images_per_prompt= num_images_per_prompt + device = self._execution_device + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2) + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt ) if self.do_classifier_free_guidance: - ( - negative_prompt_embeds, - negative_text_ids - ) = self.encode_prompt( prompt=negative_prompt, - image=prompt_image, - prompt_embeds=negative_prompt_embeds, - num_images_per_prompt=num_images_per_prompt + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + image=prompt_image, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, ) # 4. Prepare latent variables @@ -593,7 +617,7 @@ def __call__( ) # 5. 
Prepare timesteps - sigmas = np.linspace(1.0, 1.0 / num_inference_steps,num_inference_steps) if sigmas is None else sigmas + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -635,7 +659,7 @@ def __call__( if image_latents is not None: latent_model_input = torch.cat([latents, image_latents], dim=1) - #latent_model_input = torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input + # latent_model_input = torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) with self.transformer.cache_context("cond"): noise_pred_text = self.transformer( @@ -692,7 +716,7 @@ def __call__( image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) - #Offload all models + # Offload all models self.maybe_free_model_hooks() if not return_dict: diff --git a/src/diffusers/pipelines/longcat_image/pipeline_output.py b/src/diffusers/pipelines/longcat_image/pipeline_output.py index b600f9dd1fe1..e3c25f1cbfa7 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_output.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_output.py @@ -3,6 +3,7 @@ import numpy as np import PIL.Image + from diffusers.utils import BaseOutput diff --git a/src/diffusers/pipelines/longcat_image/system_messages.py b/src/diffusers/pipelines/longcat_image/system_messages.py index 77c017cfffdc..48ca605db279 100644 --- a/src/diffusers/pipelines/longcat_image/system_messages.py +++ b/src/diffusers/pipelines/longcat_image/system_messages.py @@ -1,5 +1,3 @@ - - SYSTEM_PROMPT_EN = """ You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all information from the user's original prompt without deleting or distorting any details. 
Specific requirements are as follows: diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8628893200fe..606fea18c750 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1132,6 +1132,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LongCatImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ff65372f3c89..be7f8f8ce4a3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1832,6 +1832,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LongCatImageEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LongCatImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXConditionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From b43ba6b1439d13e183bcd0d273b3d08a01d8c93c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 15 Dec 2025 02:17:57 +0000 Subject: [PATCH 19/20] Apply style fixes --- docs/source/en/_toctree.yml | 10 +- src/diffusers/__init__.py | 6 +- src/diffusers/models/__init__.py | 2 +- .../transformers/transformer_longcat_image.py | 2 +- .../pipelines/longcat_image/__init__.py | 2 +- .../longcat_image/pipeline_longcat_image.py | 19 ++- .../pipeline_longcat_image_edit.py | 29 ++-- .../longcat_image/system_messages.py | 141 ++++++++++++++---- 8 files changed, 144 insertions(+), 67 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3012d253b51b..f0cb0164436e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -365,6 +365,8 @@ title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel + - local: api/models/longcat_image_transformer2d + title: LongCatImageTransformer2DModel - local: api/models/ltx_video_transformer3d title: LTXVideoTransformer3DModel - local: api/models/lumina2_transformer2d @@ -373,8 +375,6 @@ title: LuminaNextDiT2DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel - - local: api/models/longcat_image_transformer2d - title: LongCatImageTransformer2DModel - local: api/models/omnigen_transformer title: OmniGenTransformer2DModel - 
local: api/models/ovisimage_transformer2d @@ -404,7 +404,7 @@ - local: api/models/wan_transformer_3d title: WanTransformer3DModel - local: api/models/z_image_transformer2d - title: ZImageTransformer2DModel + title: ZImageTransformer2DModel title: Transformers - sections: - local: api/models/stable_cascade_unet @@ -565,12 +565,12 @@ title: Latent Diffusion - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/longcat_image + title: LongCat-Image - local: api/pipelines/lumina2 title: Lumina 2.0 - local: api/pipelines/lumina title: Lumina-T2X - - local: api/pipelines/longcat_image - title: LongCat-Image - local: api/pipelines/marigold title: Marigold - local: api/pipelines/panorama diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9754737b540f..86f196f9be49 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -235,10 +235,10 @@ "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", + "LongCatImageTransformer2DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", - "LongCatImageTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", @@ -533,6 +533,8 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LongCatImageEditPipeline", + "LongCatImagePipeline", "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXLatentUpsamplePipeline", @@ -542,8 +544,6 @@ "Lumina2Text2ImgPipeline", "LuminaPipeline", "LuminaText2ImgPipeline", - "LongCatImagePipeline", - "LongCatImageEditPipeline", "MarigoldDepthPipeline", "MarigoldIntrinsicsPipeline", "MarigoldNormalsPipeline", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b4eed931fc0c..4c1b397bdf8c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,9 +101,9 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] + _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] - _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 1ef9ebd4a102..7fbaaa3fee85 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -472,7 +472,7 @@ def forward( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The forward method. + The forward method. 
Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): diff --git a/src/diffusers/pipelines/longcat_image/__init__.py b/src/diffusers/pipelines/longcat_image/__init__.py index cd95e96ece2f..e4bb0e5819c8 100644 --- a/src/diffusers/pipelines/longcat_image/__init__.py +++ b/src/diffusers/pipelines/longcat_image/__init__.py @@ -21,9 +21,9 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_output"] = ["LongCatImagePipelineOutput"] _import_structure["pipeline_longcat_image"] = ["LongCatImagePipeline"] _import_structure["pipeline_longcat_image_edit"] = ["LongCatImageEditPipeline"] + _import_structure["pipeline_output"] = ["LongCatImagePipelineOutput"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 31391cddf076..1c5854b091ab 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -50,7 +50,7 @@ >>> pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。' + >>> prompt = "一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。" >>> image = pipe( ... prompt, ... height=768, @@ -58,7 +58,7 @@ ... num_inference_steps=50, ... guidance_scale=4.5, ... generator=torch.Generator("cpu").manual_seed(43), - ... enable_cfg_renorm=True + ... enable_cfg_renorm=True, ... ).images[0] >>> image.save("longcat_image.png") ``` @@ -74,11 +74,10 @@ def get_prompt_language(prompt): def split_quotation(prompt, quote_pairs=None): """ - Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. - Examples:: - >>> prompt_en = "Please write 'Hello' on the blackboard for me." - >>> print(split_quotation(prompt_en)) - >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] """ word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) @@ -503,9 +502,9 @@ def __call__( Examples: Returns: - [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. 
""" height = height or self.default_sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py index 616cb1eb15fc..988aef212507 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -49,11 +49,13 @@ >>> import torch >>> from diffusers import LongCatImageEditPipeline - >>> pipe = LongCatImageEditPipeline.from_pretrained("meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe = LongCatImageEditPipeline.from_pretrained( + ... "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") - >>> prompt = 'change the cat to dog.' - >>> input_image = Image.open('test.jpg').convert("RGB") + >>> prompt = "change the cat to dog." + >>> input_image = Image.open("test.jpg").convert("RGB") >>> image = pipe( ... input_image, ... prompt, @@ -69,11 +71,10 @@ # Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation def split_quotation(prompt, quote_pairs=None): """ - Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote pairs. - Examples:: - >>> prompt_en = "Please write 'Hello' on the blackboard for me." - >>> print(split_quotation(prompt_en)) - >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] """ word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) @@ -518,7 +519,7 @@ def check_inputs( raise ValueError( f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})" ) - + if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" @@ -550,9 +551,9 @@ def __call__( Examples: Returns: - [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. """ image_size = image[0].size if isinstance(image, list) else image.size @@ -582,13 +583,13 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - + # 3. 
Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): image = self.image_processor.resize(image, calculated_height, calculated_width) prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2) image = self.image_processor.preprocess(image, calculated_height, calculated_width) - + negative_prompt = "" if negative_prompt is None else negative_prompt (prompt_embeds, text_ids) = self.encode_prompt( prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt diff --git a/src/diffusers/pipelines/longcat_image/system_messages.py b/src/diffusers/pipelines/longcat_image/system_messages.py index 48ca605db279..b8b2318e4e81 100644 --- a/src/diffusers/pipelines/longcat_image/system_messages.py +++ b/src/diffusers/pipelines/longcat_image/system_messages.py @@ -1,37 +1,110 @@ SYSTEM_PROMPT_EN = """ -You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all information from the user's original prompt without deleting or distorting any details. -Specific requirements are as follows: -1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as concise as possible. -2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields English output. The rewritten token count should not exceed 512. -3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the original prompt, such as lighting and textures. -4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography style**. If the user specifies a style, retain the user's style. -5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a giraffe"). -6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% OFF"`). -7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer with the image title 'Grassland'." -8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. +You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in +understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's +understanding accuracy and generation quality through optimization and rewriting. 
The rewrite must strictly retain all +information from the user's original prompt without deleting or distorting any details. Specific requirements are as +follows: +1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use + coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as + concise as possible. +2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields + English output. The rewritten token count should not exceed 512. +3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the + original prompt, such as lighting and textures. +4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography + style**. If the user specifies a style, retain the user's style. +5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge + to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a + giraffe"). +6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% + OFF"`). +7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no + specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For + example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer + with the image title 'Grassland'." +8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For + example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. 9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**. -Here are examples of rewrites for different types of prompts: -# Examples (Few-Shot Learning) +Here are examples of rewrites for different types of prompts: # Examples (Few-Shot Learning) 1. User Input: An animal with nine lives. - Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home environment with light from the window filtering through curtains, creating a warm light and shadow effect. The shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. + Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home + environment with light from the window filtering through curtains, creating a warm light and shadow effect. The + shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits + the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. 2. User Input: Create an anime-style tourism flyer with a grassland theme. - Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her left hand, smiling with her legs hanging naturally. 
The girl has dark brown shoulder-length hair with bangs covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. + Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped + rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her + left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs + covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To + the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The + grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies + the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is + a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, + and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. 3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer. - Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing strong visual appeal. - 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident posture. - Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and metal bars glinting in the light. 
Her outer layer is a beige knitted cardigan, soft in texture with visible knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. The overall style is natural, elegant, and artistic. - 5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. - Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the apple's life cycle. - 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate a four-color rainbow based on this rule. The color order from top to bottom is 3142. - Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing the numerical information. The image is high definition, with accurate colors and a consistent style, offering strong visual appeal. - 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a Chinese garden. 
- Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the stone tablet and the classical beauty of the garden. -# Output Format -Please directly output the rewritten and optimized Prompt content. Do not include any explanatory language or JSON formatting, and do not add opening or closing quotes yourself.""" + Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and + left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, + golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two + transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls + scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, + and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, + accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing + strong visual appeal. + 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident + posture. + Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her + shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long + eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She + has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. + Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a + black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and + metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible + knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a + relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute + focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots + on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional + and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are + dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. + The overall style is natural, elegant, and artistic. + 5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should + include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. 
+ Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage + precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark + soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with + green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, + stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden + light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches + and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under + a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into + the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the + orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a + realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the + apple's life cycle. + 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate + a four-color rainbow based on this rule. The color order from top to bottom is 3142. + Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as + purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the + number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the + bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast + with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. + The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing + the numerical information. The image is high definition, with accurate colors and a consistent style, offering + strong visual appeal. + 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a + Chinese garden. + Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with + traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the + stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo + forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a + realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the + stone tablet and the classical beauty of the garden. +# Output Format Please directly output the rewritten and optimized Prompt content. 
Do not include any explanatory +language or JSON formatting, and do not add opening or closing quotes yourself.""" SYSTEM_PROMPT_ZH = """ -你是一名文生图模型的prompt engineering专家。由于文生图模型对用户prompt的理解能力有限,你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 +你是一名文生图模型的prompt +engineering专家。由于文生图模型对用户prompt的理解能力有限,你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 具体要求如下: 1. 改写不能影响用户原始prompt里表达的任何信息,改写后的prompt应该使用连贯的自然语言表达,不要出现低信息量的冗余描述,尽可能保持改写后prompt长度精简。 2. 请确保输入和输出的语言类型一致,中文输入中文输出,英文输入英文输出,改写后的token数量不要超过512个; @@ -47,19 +120,23 @@ # Examples (Few-Shot Learning) 1. 用户输入: 九条命的动物。 - 改写输出: 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 + 改写输出: + 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。 - 改写输出: 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地,草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore Nature's Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛围。 + 改写输出: + 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地,草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore + Nature's Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛围。 3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。 - 改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节 暖心回馈”居中排列,和红色粗体字“买1送1”。海报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的视觉吸引力。 + 改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节 + 暖心回馈”居中排列,和红色粗体字“买1送1”。海报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的视觉吸引力。 4. 用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。 - 改写输出: 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 + 改写输出: + 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 5. 用户输入:创作一系列图片,展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段:1. 播种,2. 幼苗生长,3. 植物成熟,4. 
果实采摘。 改写输出:一个4宫格的精美插图,描绘苹果的生长过程,精确清晰地捕捉每个阶段。1.“播种”:特写镜头,一只手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中,土壤的纹理和种子光滑的表面清晰可见。背景是花园的柔焦画面,点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”:一棵幼小的苹果树苗破土而出,嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中,温暖的金光照亮了它。幼苗的纤细结构。3.“植物的成熟”:一棵成熟的苹果树,枝繁叶茂,挂满了嫩绿的叶子和正在萌发的小苹果。背景是一片生机勃勃的果园,湛蓝的天空下,斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象,背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和谐,展现了苹果生命周期的自然之美和发展过程。 6. 用户输入: 如果1代表红色,2代表绿色,3代表紫色,4代表黄色,请按照此规则生成四色彩虹。它的颜色顺序从上到下是3142 改写输出:图片由四个水平排列的彩色条纹组成,从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放置一个白色数字。最上方的紫色条纹上是数字“3”,其下方红色条纹上是数字“1”,再下方黄色条纹上是数字“4”,最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体,颜色为纯白色,与背景色形成鲜明对比,确保了良好的可读性。条纹的颜色饱和度高,且带有轻微的纹理感,整体排版简洁明了,视觉效果清晰,没有多余的装饰元素,强调了数字信息本身。图片整体清晰度高,色彩准确,风格一致,具有较强的视觉吸引力。 7. 用户输入:石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林 改写输出:一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。 -# 输出格式 -请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。 +# 输出格式 请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。 """ From 666941527d816f93009d9b9ee09994c9a0f126ef Mon Sep 17 00:00:00 2001 From: hadoop-imagen Date: Mon, 15 Dec 2025 20:25:17 +0800 Subject: [PATCH 20/20] fix single input rewrite error --- src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py index 1c5854b091ab..a758d545fa4a 100644 --- a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -240,6 +240,7 @@ def __init__( self.tokenizer_max_length = 512 def rewire_prompt(self, prompt, device): + prompt = [prompt] if isinstance(prompt, str) else prompt all_text = [] for each_prompt in prompt: language = get_prompt_language(each_prompt)
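For context on the final commit: before the added guard, `rewire_prompt` assumed `prompt` was already a list, so passing a bare string made `for each_prompt in prompt:` iterate over individual characters instead of whole prompts. The one-line fix normalizes the input up front. Below is a minimal, standalone sketch of that normalization pattern; `normalize_prompts` is a hypothetical helper used only for illustration and is not part of the pipeline API.

```py
from typing import List, Union


def normalize_prompts(prompt: Union[str, List[str]]) -> List[str]:
    # Mirror the guard added in the patch: wrap a bare string in a
    # single-element list so downstream code always iterates over
    # whole prompts rather than characters.
    return [prompt] if isinstance(prompt, str) else list(prompt)


# Without the guard, iterating over "a cat" would yield 'a', ' ', 'c', ...
# With it, both call styles behave the same:
print(normalize_prompts("a cat on a windowsill"))  # ['a cat on a windowsill']
print(normalize_prompts(["a cat", "a dog"]))       # ['a cat', 'a dog']
```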