From 7587a50f87719b18811ee1ca5a76f9b9aaf9419b Mon Sep 17 00:00:00 2001 From: js1234567 Date: Thu, 11 Dec 2025 15:09:22 +0800 Subject: [PATCH] Bugfix for dreambooth flux2 img2img --- .../train_dreambooth_lora_flux2_img2img.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 32bce9531b71..0de575eeec7f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -339,7 +339,7 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, + required=False, help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( @@ -827,15 +827,28 @@ def __init__( dest_image = self.cond_images[i] image_width, image_height = dest_image.size if image_width * image_height > 1024 * 1024: - dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024) + dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024) image_width, image_height = dest_image.size multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp! 
image_width = (image_width // multiple_of) * multiple_of image_height = (image_height // multiple_of) * multiple_of - dest_image = Flux2ImageProcessor.image_processor.preprocess( + image_processor = Flux2ImageProcessor() + dest_image = image_processor.preprocess( dest_image, height=image_height, width=image_width, resize_mode="crop" ) + # Convert back to PIL + dest_image = dest_image.squeeze(0) + if dest_image.min() < 0: + dest_image = (dest_image + 1) / 2 + dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu() + + if dest_image.shape[0] == 1: + # Gray scale image + dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L") + else: + # RGB scale image: (C, H, W) -> (H, W, C) + dest_image = TF.to_pil_image(dest_image) dest_image = exif_transpose(dest_image) if not dest_image.mode == "RGB": @@ -1419,9 +1432,9 @@ def _encode_single(prompt: str): args.instance_prompt, text_encoding_pipeline ) - validation_image = load_image(args.validation_image_path).convert("RGB") - validation_kwargs = {"image": validation_image} if args.validation_prompt is not None: + validation_image = load_image(args.validation_image_path).convert("RGB") + validation_kwargs = {"image": validation_image} if args.remote_text_encoder: validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt) else: