Skip to content

Commit 7587a50

Browse files
committed
Bugfix for dreambooth flux2 img2img
1 parent 8d415a6 commit 7587a50

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,7 @@ def parse_args(input_args=None):
339339
"--instance_prompt",
340340
type=str,
341341
default=None,
342-
required=True,
342+
required=False,
343343
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
344344
)
345345
parser.add_argument(
@@ -827,15 +827,28 @@ def __init__(
827827
dest_image = self.cond_images[i]
828828
image_width, image_height = dest_image.size
829829
if image_width * image_height > 1024 * 1024:
830-
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
830+
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
831831
image_width, image_height = dest_image.size
832832

833833
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
834834
image_width = (image_width // multiple_of) * multiple_of
835835
image_height = (image_height // multiple_of) * multiple_of
836-
dest_image = Flux2ImageProcessor.image_processor.preprocess(
836+
image_processor = Flux2ImageProcessor()
837+
dest_image = image_processor.preprocess(
837838
dest_image, height=image_height, width=image_width, resize_mode="crop"
838839
)
840+
# Convert back to PIL
841+
dest_image = dest_image.squeeze(0)
842+
if dest_image.min() < 0:
843+
dest_image = (dest_image + 1) / 2
844+
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
845+
846+
if dest_image.shape[0] == 1:
847+
# Gray scale image
848+
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
849+
else:
850+
# RGB scale image: (C, H, W) -> (H, W, C)
851+
dest_image = TF.to_pil_image(dest_image)
839852

840853
dest_image = exif_transpose(dest_image)
841854
if not dest_image.mode == "RGB":
@@ -1419,9 +1432,9 @@ def _encode_single(prompt: str):
14191432
args.instance_prompt, text_encoding_pipeline
14201433
)
14211434

1422-
validation_image = load_image(args.validation_image_path).convert("RGB")
1423-
validation_kwargs = {"image": validation_image}
14241435
if args.validation_prompt is not None:
1436+
validation_image = load_image(args.validation_image_path).convert("RGB")
1437+
validation_kwargs = {"image": validation_image}
14251438
if args.remote_text_encoder:
14261439
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
14271440
else:

0 commit comments

Comments
 (0)