diff --git a/requirements.txt b/requirements.txt index 0516b9f2..298fbbb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ pytest==8.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 -git+https://github.com/Lightricks/LTX-Video +ltx-video[inference] @ git+https://github.com/Lightricks/LTX-Video.git git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax opencv-python-headless==4.10.0.84 orbax-checkpoint diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 71316ea1..eaa6ffbb 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -20,6 +20,10 @@ frame_rate: 30 max_sequence_length: 512 sampler: "from_checkpoint" +# Attention settings +attention: "standard" +attention_sharding_uniform: False + # Generation parameters pipeline_type: multi-scale prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." 
@@ -99,4 +103,5 @@ compile_topology_num_slices: -1 quantization_local_shard_count: -1 use_qwix_quantization: False jit_initializers: True -enable_single_replica_ckpt_restoring: False \ No newline at end of file +enable_single_replica_ckpt_restoring: False +enable_profiler: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6ecc6666..97f9a239 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -20,7 +20,7 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor -from maxdiffusion import pyconfig, max_logging +from maxdiffusion import pyconfig, max_logging import torchvision.transforms.functional as TVF import imageio from datetime import datetime @@ -29,6 +29,7 @@ from pathlib import Path from PIL import Image import torch +import jax def calculate_padding( @@ -206,19 +207,40 @@ def run(config): else None ) + pipeline_args = { + "height": height_padded, + "width": width_padded, + "num_frames": num_frames_padded, + "is_video": True, + "output_type": "pt", + "config": config, + "enhance_prompt": enhance_prompt, + "conditioning_items": conditioning_items, + "seed": config.seed, + } + + + # Warm-up call s0 = time.perf_counter() - images = pipeline( - height=height_padded, - width=width_padded, - num_frames=num_frames_padded, - is_video=True, - output_type="pt", - config=config, - enhance_prompt=enhance_prompt, - conditioning_items=conditioning_items, - seed=config.seed, - ) - max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.") + images = pipeline(**pipeline_args) + max_logging.log(f"Warmup time: {time.perf_counter() - s0:.1f}s.") + + # Normal call + s0 = time.perf_counter() + images = pipeline(**pipeline_args) + max_logging.log(f"Generation 
time: {time.perf_counter() - s0:.1f}s.") + + # Profiled call + if config.enable_profiler: + profile_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + profiler_output_path = f"gs://hjajoo-ai-ninja-bucket/ltx-video/profiler_traces/{profile_timestamp}"  # TODO(review): hardcoded personal bucket — move this base path into config + jax.profiler.start_trace(profiler_output_path) + max_logging.log(f"JAX profiler started. Traces will be saved to: {profiler_output_path}") + s0 = time.perf_counter() + images = pipeline(**pipeline_args) + jax.profiler.stop_trace() + max_logging.log("JAX profiler stopped.") + max_logging.log(f"Generation time with profiler: {time.perf_counter() - s0:.1f}s.") (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 1b8f4deb..49dc25f2 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -1079,8 +1079,8 @@ def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, fa jax.Array: The transformed latent tensor. """ with default_env(): - latents = jax.device_put(latents, jax.devices("tpu")[0]) - reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0]) + latents = jax.device_put(jax.numpy.array(latents), jax.devices("tpu")[0]) + reference_latents = jax.device_put(jax.numpy.array(reference_latents), jax.devices("tpu")[0]) # Define the core AdaIN operation for a single (F, H, W) slice. # This function will be vmapped over batch (B) and channel (C) dimensions.