@@ -339,7 +339,7 @@ def parse_args(input_args=None):
339339 "--instance_prompt" ,
340340 type = str ,
341341 default = None ,
342- required = True ,
342+ required = False ,
343343 help = "The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'" ,
344344 )
345345 parser .add_argument (
@@ -827,15 +827,28 @@ def __init__(
827827 dest_image = self .cond_images [i ]
828828 image_width , image_height = dest_image .size
829829 if image_width * image_height > 1024 * 1024 :
830- dest_image = Flux2ImageProcessor .image_processor . _resize_to_target_area (dest_image , 1024 * 1024 )
830+ dest_image = Flux2ImageProcessor ._resize_to_target_area (dest_image , 1024 * 1024 )
831831 image_width , image_height = dest_image .size
832832
833833 multiple_of = 2 ** (4 - 1 ) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
834834 image_width = (image_width // multiple_of ) * multiple_of
835835 image_height = (image_height // multiple_of ) * multiple_of
836- dest_image = Flux2ImageProcessor .image_processor .preprocess (
836+ image_processor = Flux2ImageProcessor ()
837+ dest_image = image_processor .preprocess (
837838 dest_image , height = image_height , width = image_width , resize_mode = "crop"
838839 )
840+ # Convert back to PIL
841+ dest_image = dest_image .squeeze (0 )
842+ if dest_image .min () < 0 :
843+ dest_image = (dest_image + 1 ) / 2
844+ dest_image = (torch .clamp (dest_image , 0 , 1 ) * 255 ).byte ().cpu ()
845+
846+ if dest_image .shape [0 ] == 1 :
847+ # Gray scale image
848+ dest_image = Image .fromarray (dest_image .squeeze ().numpy (), mode = "L" )
849+ else :
850+ # RGB scale image: (C, H, W) -> (H, W, C)
851+ dest_image = TF .to_pil_image (dest_image )
839852
840853 dest_image = exif_transpose (dest_image )
841854 if not dest_image .mode == "RGB" :
@@ -1419,9 +1432,9 @@ def _encode_single(prompt: str):
14191432 args .instance_prompt , text_encoding_pipeline
14201433 )
14211434
1422- validation_image = load_image (args .validation_image_path ).convert ("RGB" )
1423- validation_kwargs = {"image" : validation_image }
14241435 if args .validation_prompt is not None :
1436+ validation_image = load_image (args .validation_image_path ).convert ("RGB" )
1437+ validation_kwargs = {"image" : validation_image }
14251438 if args .remote_text_encoder :
14261439 validation_kwargs ["prompt_embeds" ] = compute_remote_text_embeddings (args .validation_prompt )
14271440 else :
0 commit comments