From 0c5fc854ea0ae8631847700a67624c5033a84a4b Mon Sep 17 00:00:00 2001 From: Hina Jajoo Date: Wed, 3 Dec 2025 08:52:30 +0000 Subject: [PATCH 1/4] Update requirements to add inference for ltx-video --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0516b9f2..298fbbb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ pytest==8.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 -git+https://github.com/Lightricks/LTX-Video +git+https://github.com/Lightricks/LTX-Video.git#egg=ltx-video[inference] git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax opencv-python-headless==4.10.0.84 orbax-checkpoint From cdec1b4cb97f5e45b97ae66d8882c53f08285e33 Mon Sep 17 00:00:00 2001 From: Hina Jajoo Date: Sat, 6 Dec 2025 10:22:47 +0000 Subject: [PATCH 2/4] Add missing attention config --- src/maxdiffusion/configs/ltx_video.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 71316ea1..20a7b127 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -20,6 +20,10 @@ frame_rate: 30 max_sequence_length: 512 sampler: "from_checkpoint" +# Attention settings +attention: "standard" +attention_sharding_uniform: False + # Generation parameters pipeline_type: multi-scale prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. 
The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie." From 8758e293be4848e49e4bfbc947a7da25a2c6dccf Mon Sep 17 00:00:00 2001 From: Hina Jajoo Date: Sat, 6 Dec 2025 10:24:59 +0000 Subject: [PATCH 3/4] Fix latents device placement error by converting inputs to jax arrays --- src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 1b8f4deb..49dc25f2 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -1079,8 +1079,8 @@ def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, fa jax.Array: The transformed latent tensor. """ with default_env(): - latents = jax.device_put(latents, jax.devices("tpu")[0]) - reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0]) + latents = jax.device_put(jax.numpy.array(latents), jax.devices("tpu")[0]) + reference_latents = jax.device_put(jax.numpy.array(reference_latents), jax.devices("tpu")[0]) # Define the core AdaIN operation for a single (F, H, W) slice. # This function will be vmapped over batch (B) and channel (C) dimensions. 
From e5a9df2cd3b06a1e2d447bea4007ab0b9b38e1b0 Mon Sep 17 00:00:00 2001 From: Hina Jajoo Date: Sat, 6 Dec 2025 10:25:31 +0000 Subject: [PATCH 4/4] Start profiling after warmup step --- src/maxdiffusion/configs/ltx_video.yml | 3 +- src/maxdiffusion/generate_ltx_video.py | 48 +++++++++++++++++++------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 20a7b127..eaa6ffbb 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -103,4 +103,5 @@ compile_topology_num_slices: -1 quantization_local_shard_count: -1 use_qwix_quantization: False jit_initializers: True -enable_single_replica_ckpt_restoring: False \ No newline at end of file +enable_single_replica_ckpt_restoring: False +enable_profiler: True diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6ecc6666..97f9a239 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -20,7 +20,7 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor -from maxdiffusion import pyconfig, max_logging +from maxdiffusion import pyconfig, max_logging, max_utils import torchvision.transforms.functional as TVF import imageio from datetime import datetime @@ -29,6 +29,7 @@ from pathlib import Path from PIL import Image import torch +import jax def calculate_padding( @@ -206,19 +207,40 @@ def run(config): else None ) + pipeline_args = { + "height": height_padded, + "width": width_padded, + "num_frames": num_frames_padded, + "is_video": True, + "output_type": "pt", + "config": config, + "enhance_prompt": enhance_prompt, + "conditioning_items": conditioning_items, + "seed": config.seed, + } + + + # Warm-up call 
s0 = time.perf_counter() - images = pipeline( - height=height_padded, - width=width_padded, - num_frames=num_frames_padded, - is_video=True, - output_type="pt", - config=config, - enhance_prompt=enhance_prompt, - conditioning_items=conditioning_items, - seed=config.seed, - ) - max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.") + images = pipeline(**pipeline_args) + max_logging.log(f"Warmup time: {time.perf_counter() - s0:.1f}s.") + + # Normal call + s0 = time.perf_counter() + images = pipeline(**pipeline_args) + max_logging.log(f"Generation time: {time.perf_counter() - s0:.1f}s.") + + # Profiled call + if config.enable_profiler: + profile_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + profiler_output_path = f"gs://hjajoo-ai-ninja-bucket/ltx-video/profiler_traces/{profile_timestamp}" + jax.profiler.start_trace(profiler_output_path) + max_logging.log(f"JAX profiler started. Traces will be saved to: {profiler_output_path}") + s0 = time.perf_counter() + images = pipeline(**pipeline_args) + jax.profiler.stop_trace() + max_logging.log(f"JAX profiler stopped.") + max_logging.log(f"Generation time with profiler: {time.perf_counter() - s0:.1f}s.") (pad_left, pad_right, pad_top, pad_bottom) = padding pad_bottom = -pad_bottom