Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pytest==8.2.2
tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
git+https://github.com/Lightricks/LTX-Video
git+https://github.com/Lightricks/LTX-Video.git#egg=ltx-video[inference]
git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
opencv-python-headless==4.10.0.84
orbax-checkpoint
Expand Down
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ frame_rate: 30
max_sequence_length: 512
sampler: "from_checkpoint"

# Attention settings
attention: "standard"
attention_sharding_uniform: False

# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
Expand Down Expand Up @@ -99,4 +103,5 @@ compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
jit_initializers: True
enable_single_replica_ckpt_restoring: False
enable_single_replica_ckpt_restoring: False
enable_profiler: True
48 changes: 35 additions & 13 deletions src/maxdiffusion/generate_ltx_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
from maxdiffusion import pyconfig, max_logging
from maxdiffusion import pyconfig, max_logging, max_utils
import torchvision.transforms.functional as TVF
import imageio
from datetime import datetime
Expand All @@ -29,6 +29,7 @@
from pathlib import Path
from PIL import Image
import torch
import jax


def calculate_padding(
Expand Down Expand Up @@ -206,19 +207,40 @@ def run(config):
else None
)

pipeline_args = {
"height": height_padded,
"width": width_padded,
"num_frames": num_frames_padded,
"is_video": True,
"output_type": "pt",
"config": config,
"enhance_prompt": enhance_prompt,
"conditioning_items": conditioning_items,
"seed": config.seed,
}


# Warm-up call
s0 = time.perf_counter()
images = pipeline(
height=height_padded,
width=width_padded,
num_frames=num_frames_padded,
is_video=True,
output_type="pt",
config=config,
enhance_prompt=enhance_prompt,
conditioning_items=conditioning_items,
seed=config.seed,
)
max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")
images = pipeline(**pipeline_args)
max_logging.log(f"Warmup time: {time.perf_counter() - s0:.1f}s.")

# Normal call
s0 = time.perf_counter()
images = pipeline(**pipeline_args)
max_logging.log(f"Generation time: {time.perf_counter() - s0:.1f}s.")

# Profiled call
if config.enable_profiler:
profile_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
profiler_output_path = f"gs://hjajoo-ai-ninja-bucket/ltx-video/profiler_traces/{profile_timestamp}"
jax.profiler.start_trace(profiler_output_path)
max_logging.log(f"JAX profiler started. Traces will be saved to: {profiler_output_path}")
s0 = time.perf_counter()
images = pipeline(**pipeline_args)
jax.profiler.stop_trace()
max_logging.log(f"JAX profiler stopped.")
max_logging.log(f"Generation time with profiler: {time.perf_counter() - s0:.1f}s.")

(pad_left, pad_right, pad_top, pad_bottom) = padding
pad_bottom = -pad_bottom
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,8 +1079,8 @@ def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, fa
jax.Array: The transformed latent tensor.
"""
with default_env():
latents = jax.device_put(latents, jax.devices("tpu")[0])
reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0])
latents = jax.device_put(jax.numpy.array(latents), jax.devices("tpu")[0])
reference_latents = jax.device_put(jax.numpy.array(reference_latents), jax.devices("tpu")[0])

# Define the core AdaIN operation for a single (F, H, W) slice.
# This function will be vmapped over batch (B) and channel (C) dimensions.
Expand Down
Loading