diff --git a/.gitignore b/.gitignore
index 8e4e723f..bd4a64b8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,6 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
 # C extensions
 *.so
@@ -98,6 +97,7 @@ celerybeat-schedule
 # Environments
 .env
+.history
 .venv
 env/
 venv/
diff --git a/docs/attention_blocks_flowchart.md b/docs/attention_blocks_flowchart.md
new file mode 100644
index 00000000..69816ac7
--- /dev/null
+++ b/docs/attention_blocks_flowchart.md
@@ -0,0 +1,30 @@
+# Attention block sizes
+
+## Description
+- "block_q": Block size (HBM to VMEM and VREG) used to tile along the Q sequence in the forward pass.
+- "block_kv_compute": Sub-block size (VMEM to VREG) of "block_kv" on which compute is performed in the forward pass. It must be a factor of, or equal to, "block_kv".
+- "block_kv": Block size (HBM to VMEM) used to tile along the KV sequence in the forward pass.
+- "block_q_dkv": Block size along the Q sequence in the backward pass with the fused kernel that computes the gradients of q, k, and v. It must be a factor of, or equal to, "block_q".
+- "block_kv_dkv": Block size along the KV sequence in the backward pass. It must be a factor of, or equal to, "block_kv".
+- "block_kv_dkv_compute": Sub-block size of "block_kv_dkv". It must be a factor of, or equal to, "block_kv_dkv".
+- "block_q_dq": Block size along the Q sequence in the backward pass with the unfused kernel that computes only the gradient of q. It must be a factor of, or equal to, "block_q".
+- "block_kv_dq": Block size used to tile along the KV sequence in the backward pass with the unfused kernel that computes only the gradient of q. It must be a factor of, or equal to, "block_kv".
+- "use_fused_bwd_kernel": When True, the fused backward kernel is used, in which dQ, dK, and dV are computed in a single kernel. It is usually more performant but comes with a slight HBM memory overhead.
+
+## Flowchart
+
+MaxDiffusion automatically adheres to this flowchart to ensure the block sizes are valid, and it logs any modifications it makes to the specified block sizes.
+
+![alt text](attention_blocks_flowchart.png)
+
+> "tokamax_flash" uses the splash attention implementation in the [tokamax repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py). This kernel only supports a fused backward pass, where the gradients of q, k, and v are computed in a single kernel, so "block_q_dq" and "block_kv_dq" are not used.
+
+## How block sizes matter for performance and accuracy
+
+Block sizes are key to saturating HBM bandwidth and to maximizing the overlap of core compute with HBM-to-VMEM and VMEM-to-VREG transfers. It is highly recommended to tune them.
+
+Block sizes also interact with the sequence length. The sequence length is determined by the resolution and the number of frames (for video), along with the VAE downscaling factors and patchifying ratios. This sequence length, or each shard of it, needs to be a multiple of the specified block sizes, so MaxDiffusion pads the sequence length to the nearest multiple of the block sizes. It is advisable to choose block sizes that are factors of the sequence length, at least for the Q block sizes.
+
+> In cross attention, image or video tokens attend to text tokens. The text sequence length is very small and potentially smaller than the specified block size, so the KV block sizes are overwritten to safe values.
+
+> KV block sizes must be a multiple of 128, since a vector register is 8x128 and, in attention, the KV sequence dimension lies on the 128-lane dimension of the matrix multiplications because K is transposed.
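+
+### Example block-size check
+
+A minimal sketch, not MaxDiffusion code, of how the divisibility rules above can be checked for a candidate `flash_block_sizes` dict and how a sequence length gets padded to the Q block size. The helper names are hypothetical; the block-size values mirror the defaults in `base_wan_14b.yml`, and the padded length matches the 75600-token example used in the tests.
+
+```python
+def check_block_sizes(sizes: dict) -> None:
+  """Assert the divisibility constraints described in the Description section."""
+  assert sizes["block_kv"] % sizes["block_kv_compute"] == 0, "block_kv_compute must divide block_kv"
+  assert sizes["block_q"] % sizes["block_q_dkv"] == 0, "block_q_dkv must divide block_q"
+  assert sizes["block_kv"] % sizes["block_kv_dkv"] == 0, "block_kv_dkv must divide block_kv"
+  assert sizes["block_kv_dkv"] % sizes["block_kv_dkv_compute"] == 0, "block_kv_dkv_compute must divide block_kv_dkv"
+  if not sizes.get("use_fused_bwd_kernel", False):
+    # The unfused backward kernel additionally uses block_q_dq and block_kv_dq.
+    assert sizes["block_q"] % sizes["block_q_dq"] == 0, "block_q_dq must divide block_q"
+    assert sizes["block_kv"] % sizes["block_kv_dq"] == 0, "block_kv_dq must divide block_kv"
+  for name in ("block_kv", "block_kv_compute", "block_kv_dkv", "block_kv_dkv_compute"):
+    assert sizes[name] % 128 == 0, f"{name} must be a multiple of 128"
+
+def pad_to_block(seq_len: int, block: int) -> int:
+  """Pad a sequence length up to the nearest multiple of a block size."""
+  return -(-seq_len // block) * block
+
+check_block_sizes({
+    "block_q": 2048, "block_kv_compute": 512, "block_kv": 2048,
+    "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 512,
+    "use_fused_bwd_kernel": True,
+})
+print(pad_to_block(75600, 2048))  # 75776
+```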
\ No newline at end of file diff --git a/docs/attention_blocks_flowchart.png b/docs/attention_blocks_flowchart.png new file mode 100644 index 00000000..bed28e63 Binary files /dev/null and b/docs/attention_blocks_flowchart.png differ diff --git a/preview-xpk.sh b/preview-xpk.sh deleted file mode 100755 index 25a76aa0..00000000 --- a/preview-xpk.sh +++ /dev/null @@ -1,93 +0,0 @@ -#!/bin/bash -bash docker_build_dependency_image.sh -docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest -docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest -CLUSTER_NAME=bodaborg-tpu7x-128 -DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256 -PROJECT=cloud-tpu-multipod-dev -ZONE=us-central1 - -# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path. -export RUN_NAME=sanbao-wan-v7x-20k-${RANDOM} -OUTPUT_DIR=gs://sanbao-bucket/wan/${RUN_NAME} -# OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-wan-train-test -DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/ -EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/ -SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/ -RANDOM=123456789 -IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest -# IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly@sha256:fd27d49a3be7f743f08e3b6b03e5ae00196794944310e3fee2a7795b99d81195 -LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl - -xpk workload create \ ---cluster=$CLUSTER_NAME \ ---project=$PROJECT \ ---zone=$ZONE \ ---device-type=$DEVICE_TYPE \ ---num-slices=1 \ ---command=" \ -pip install . && \ -gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \ -python -m pip install ${LIBTPU_VERSION} && \ -export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \ ---xla_tpu_enable_async_collective_fusion=true \ ---xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ ---xla_enable_async_all_reduce=true \ ---xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \ ---xla_max_concurrent_async_all_gathers=4 \ ---xla_tpu_enable_async_all_to_all=true \ ---xla_latency_hiding_scheduler_rerun=5 \ ---xla_tpu_rwb_fusion=false \ ---xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \ ---xla_tpu_impure_enable_packed_bf16_math_ops=false \ ---xla_tpu_enable_sparse_core_reduce_scatter_v2=true \ ---xla_tpu_enable_sparse_core_collective_offload_all_gather=true \ ---xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \ ---xla_tpu_enable_all_gather_offload_tracing=true \ ---xla_tpu_use_tc_device_shape_on_sc=true \ ---xla_tpu_prefer_async_allgather_to_allreduce=true \ ---xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \ ---xla_tpu_scoped_vmem_limit_kib=65536 \ ---xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \ ---xla_enable_transpose_trace=false' && \ -echo 'Starting WAN training ...' && \ -HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \ - src/maxdiffusion/configs/base_wan_14b.yml \ - attention='flash' \ - weights_dtype=bfloat16 \ - activations_dtype=bfloat16 \ - guidance_scale=5.0 \ - flow_shift=5.0 \ - fps=16 \ - skip_jax_distributed_system=False \ - run_name='test-wan-training-new' \ - output_dir=${OUTPUT_DIR} \ - train_data_dir=${DATASET_DIR} \ - load_tfrecord_cached=True \ - height=1280 \ - width=720 \ - num_frames=81 \ - num_inference_steps=50 \ - prompt='a japanese pop star young woman with black hair is singing with a smile. 
She is inside a studio with dim lighting and musical instruments.' \ - jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ - enable_profiler=True \ - dataset_save_location=${SAVE_DATASET_DIR} \ - remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ - flash_min_seq_length=0 \ - seed=$RANDOM \ - skip_first_n_steps_for_profiler=3 \ - profiler_steps=3 \ - per_device_batch_size=0.5 \ - ici_data_parallelism=64 \ - ici_fsdp_parallelism=2 \ - ici_tensor_parallelism=1 \ - allow_split_physical_axes=True \ - max_train_steps=150 \ - scan_layers=true \ - flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' \ - " \ ---base-docker-image=${IMAGE_DIR} \ ---enable-debug-logs \ ---workload=${RUN_NAME} \ ---priority=medium \ ---max-restarts=0 diff --git a/requirements.txt b/requirements.txt index 478359fe..0516b9f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ ftfy tensorboard>=2.17.0 tensorboardx>=2.6.2.2 tensorboard-plugin-profile>=2.15.2 +tokamax Jinja2 scikit-image parameterized diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 51fe2b8d..71b3735d 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -33,7 +33,11 @@ BlockSizes = splash_attention_kernel.BlockSizes AxisNames = tuple[str, ...] - +# Physical axis names for device meshes. +DATA = "data" +FSDP = "fsdp" +TENSOR = "tensor" +# Logical axis names for model parameters and activations. BATCH = "activation_batch" LENGTH = "activation_length" KV_LENGTH = "activation_kv_length" @@ -48,3 +52,33 @@ WAN2_2 = "wan2.2" WAN_MODEL = WAN2_1 + +# For setting self/cross attention independently in splash kernel +SELF_ATTN_HEAD = "activation_self_attn_heads" +SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length" +SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length" +CROSS_ATTN_HEAD = "activation_cross_attn_heads" +CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length" +CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length" + + +WAN_MODEL = "Wan2.1" + +### Common axis rules for ring attention ### +RING_ATTENTION_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, FSDP], + [SELF_ATTN_KV_LENGTH, FSDP], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, FSDP], + [CROSS_ATTN_KV_LENGTH, FSDP], +] + +SEQUENCE_PARALLEL_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, FSDP], + [SELF_ATTN_KV_LENGTH, None], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, FSDP], + [CROSS_ATTN_KV_LENGTH, None], +] \ No newline at end of file diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index 80daf9ea..7bd8ae70 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -50,6 +50,15 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'dot_product' # Supported attention: dot_product, flash +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. 
attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index d02af595..24dffe40 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -49,6 +49,16 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'dot_product' # Supported attention: dot_product, flash +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True + flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index b535762e..7b224058 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -50,6 +50,16 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True + flash_block_sizes: {} # to override default block sizes for flash attention # flash_block_sizes: diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index a7ca1350..0036b363 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -63,6 +63,15 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. 
attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True flash_block_sizes: {} # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index 0da843fd..ac0a0f8c 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -63,6 +63,15 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True #flash_block_sizes: {} # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 300ec039..c60dd79e 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -62,6 +62,15 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True flash_block_sizes: { "block_q" : 256, "block_kv_compute" : 256, diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 8cd7e70f..e8146a70 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -60,7 +60,17 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring -flash_min_seq_length: 4096 +flash_min_seq_length: 0 + +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. 
+mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { @@ -70,7 +80,7 @@ flash_block_sizes: { "block_q_dkv" : 2048, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 512, - "use_fused_bwd_kernel" : True + "use_fused_bwd_kernel": True } # Use on v6e # flash_block_sizes: { @@ -79,11 +89,22 @@ flash_block_sizes: { # "block_kv" : 2048, # "block_q_dkv" : 3024, # "block_kv_dkv" : 2048, -# "block_kv_dkv_compute" : 2048, +# "block_kv_dkv_compute" : 1024, # "block_q_dq" : 3024, # "block_kv_dq" : 2048, # "use_fused_bwd_kernel": False, # } +# Use on v5p +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 1024, +# "block_kv_dkv" : 3072, +# "block_kv_dkv_compute" : 256, +# "block_q_dq" : 1024, +# "block_kv_dq" : 3072 +# } # GroupNorm groups norm_num_groups: 32 @@ -144,8 +165,9 @@ mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', 'data'], + ['activation_self_attn_heads', ['fsdp', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp', 'tensor']], ['activation_length', 'fsdp'], - ['activation_heads', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], @@ -279,7 +301,7 @@ flow_shift: 3.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 -fps: 24 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index ffdf02eb..1b93a32a 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -61,6 +61,15 @@ from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring flash_min_seq_length: 4096 +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True dropout: 0.1 flash_block_sizes: { diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index aa07940e..49e53ae5 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -50,6 +50,15 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'dot_product' # Supported attention: dot_product, flash +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. 
+# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index ee2e59d5..6f6662b0 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -48,6 +48,15 @@ jit_initializers: True from_pt: False split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True flash_block_sizes: {} # GroupNorm groups norm_num_groups: 32 diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 900ae7f9..48c6ca44 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -501,17 +501,26 @@ def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: - use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False) + attention_is_tokamax = "tokamax" in config.attention + user_block_sizes:Dict[str, int] = config.flash_block_sizes + if attention_is_tokamax: + max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." 
+ "Hence following flash block properties specified will be ignored:" + f"block_q: {user_block_sizes['block_q']}," + f"block_q_dq: {user_block_sizes.get('block_q_dq')}," + f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," + f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" + ) flash_block_sizes = splash_attention_kernel.BlockSizes( - block_q=config.flash_block_sizes["block_q"], - block_kv_compute=config.flash_block_sizes["block_kv_compute"], - block_kv=config.flash_block_sizes["block_kv"], - block_q_dkv=config.flash_block_sizes["block_q_dkv"], - block_kv_dkv=config.flash_block_sizes["block_kv_dkv"], - block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"], - block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"), - block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"), - use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"), + block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"], + block_kv_compute=user_block_sizes["block_kv_compute"], + block_kv=user_block_sizes["block_kv"], + block_q_dkv=user_block_sizes["block_q_dkv"], + block_kv_dkv=user_block_sizes["block_kv_dkv"], + block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"], + block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"), + block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"), + use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"), ) return flash_block_sizes @@ -641,4 +650,4 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") else: - jax.distributed.initialize() \ No newline at end of file + jax.distributed.initialize() diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index e95a8b25..520c4071 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -25,6 +25,9 @@ from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel +from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel from einops import rearrange from .. 
import common_types, max_logging @@ -46,6 +49,13 @@ EMBED = common_types.EMBED Quant = quantizations.AqtQuantization +SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD +SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH +SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH +CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD +CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH +CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH + def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() @@ -163,6 +173,40 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): return tensor, kv_size, seq_len +def convert_to_tokamax_splash_config( block_sizes: BlockSizes, + q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + residual_checkpoint_name: str | None = None, + attn_logits_soft_cap: float | None = None, + fuse_reciprocal: bool = True, + use_base2_exp: bool = False, + max_logit_const: float | None = None, + interpret: bool = False, + dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig: + assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel." + return tokamax_splash_attention_kernel.SplashConfig( + block_q=block_sizes.block_q, + block_kv=block_sizes.block_kv, + block_kv_compute=block_sizes.block_kv_compute, + block_q_dkv=block_sizes.block_q_dkv, + block_kv_dkv=block_sizes.block_kv_dkv, + block_kv_dkv_compute=block_sizes.block_kv_dkv_compute, + block_q_dq= None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, + block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq, + use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel, + q_layout=q_layout, + k_layout=k_layout, + v_layout=v_layout, + residual_checkpoint_name=residual_checkpoint_name, + attn_logits_soft_cap=attn_logits_soft_cap, + fuse_reciprocal=fuse_reciprocal, + use_base2_exp=use_base2_exp, + max_logit_const=max_logit_const, + interpret=interpret, + dq_reduction_steps=dq_reduction_steps, + ) + def _tpu_flash_attention( query: jax.Array, @@ -175,6 +219,7 @@ def _tpu_flash_attention( flash_block_sizes: BlockSizes, dtype: jnp.dtype = jnp.float32, attention_kernel: str = "flash", + mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, ) -> jax.Array: """TPU Flash Attention""" @@ -186,17 +231,19 @@ def _tpu_flash_attention( kv_max_block_size = key.shape[1] else: kv_max_block_size = q_max_block_size - if flash_block_sizes: + # ensure that for cross attention we override the block sizes. 
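+    # The user-specified block sizes are applied only when the Q and KV sequence lengths match (self attention);
+    # for cross attention (e.g. a short text KV sequence), safe default block sizes are derived in the else branch below.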
+ if flash_block_sizes and key.shape[1] == query.shape[1]: block_sizes = flash_block_sizes else: + block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(q_max_block_size, query.shape[2]), + block_q=block_size_q, block_kv_compute=min(kv_max_block_size, key.shape[2]), block_kv=min(kv_max_block_size, key.shape[2]), - block_q_dkv=min(q_max_block_size, query.shape[2]), + block_q_dkv=block_size_q, block_kv_dkv=min(kv_max_block_size, key.shape[2]), block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]), - block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq, + block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q, block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]), use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False, ) @@ -253,55 +300,98 @@ def wrap_flash_attention(query, key, value): # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=1, # the sizes of the axis is sharding over heads - q_seq_shards=1, # the sizes of the axis is sharding over seq_len - block_sizes=block_sizes, - save_residuals=True if attention_kernel == "ring" else False, - residual_checkpoint_name=residual_checkpoint_name, - ) - vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + if attention_kernel == "tokamax_flash": + mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) + splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( + mask=mask, + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), + save_residuals=True if "ring" in attention_kernel else False, + ) + elif attention_kernel == "tokamax_ring": + mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) + splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( + mask=mask, + is_mqa=False, + config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), + save_residuals=True, + ring_axis="fsdp", + ) + else: + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, # the sizes of the axis is sharding over heads + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, + save_residuals=True if "ring" in attention_kernel else False, + residual_checkpoint_name=residual_checkpoint_name + ) - if attention_kernel == "flash": - attention_output = vmapped_splash(query, key, value, segment_ids) + if attention_kernel == "tokamax_ring": + # For tokamax_ring, use the kernel directly without vmap + # The ring attention kernel handles the ring topology internally + if not mask_padding_tokens: + segment_ids = None + attention_output = splash_kernel( + fwd_mask_info=None, + dkv_mask_info=None, + q=query, + k=key, + v=value, + segment_ids=segment_ids, + is_mqa=False, + config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), + mask_value=-jnp.inf, + mask_function=None, + fwd_mask_sparsity=1.0, + save_residuals=True, + ) else: - if num_fsdp_shards > 1: - out, (lse,) = vmapped_splash(query, key, 
value, segment_ids) - m = lse.astype(jnp.float32) - l = jnp.exp(lse - m) - o = out.astype(jnp.float32) * l[..., None] + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) - perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] + if not mask_padding_tokens: + segment_ids = None + if attention_kernel in ["flash", "tokamax_flash"]: + attention_output = vmapped_splash(query, key, value, segment_ids) + else: + if num_fsdp_shards > 1: + out, (lse,) = vmapped_splash(query, key, value, segment_ids) + m = lse.astype(jnp.float32) + l = jnp.exp(lse - m) + o = out.astype(jnp.float32) * l[..., None] - k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) - v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) + perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] - def ring_scan_body(carry, _): - m, l, o, k_current, v_current = carry - k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) - v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) + k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) + v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) - out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) + def ring_scan_body(carry, _): + m, l, o, k_current, v_current = carry + k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) + v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) - m_chunk = lse_chunk.astype(jnp.float32) - m_old = m - m = jnp.maximum(m_old, m_chunk) + out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) - exp_m_diff = jnp.exp(m_old - m) - exp_m_chunk_diff = jnp.exp(m_chunk - m) + m_chunk = lse_chunk.astype(jnp.float32) + m_old = m + m = jnp.maximum(m_old, m_chunk) - l = l * exp_m_diff + jnp.exp(lse_chunk - m) - o = o * exp_m_diff[..., None] - o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) + exp_m_diff = jnp.exp(m_old - m) + exp_m_chunk_diff = jnp.exp(m_chunk - m) - # Return the updated state for the next iteration - return (m, l, o, k_next, v_next), None + l = l * exp_m_diff + jnp.exp(lse_chunk - m) + o = o * exp_m_diff[..., None] + o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) - initial_carry = (m, l, o, k1, v1) - (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) + # Return the updated state for the next iteration + return (m, l, o, k_next, v_next), None - attention_output = o_final / l_final[..., None] + initial_carry = (m, l, o, k1, v1) + (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) + + attention_output = o_final / l_final[..., None] + else: + raise ValueError("ring attention requires fsdp > 1") return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) @@ -442,6 +532,7 @@ def _apply_attention( axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dpa_layer: Callable, + mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, ): """Routes to different attention kernels.""" @@ -449,7 +540,7 @@ def _apply_attention( seq_len_idx = 1 if query.ndim == 4: seq_len_idx = 2 - if attention_kernel == "flash": + if attention_kernel in ["flash", "tokamax_flash"]: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length @@ -461,7 +552,7 @@ def _apply_attention( return _apply_attention_dot( query, key, value, dtype, heads, dim_head, 
scale, split_head_dim, float32_qk_product, use_memory_efficient_attention ) - elif attention_kernel == "flash": + elif attention_kernel in ["flash", "tokamax_flash"]: return _tpu_flash_attention( query, key * scale, @@ -472,11 +563,14 @@ def _apply_attention( axis_names_kv, flash_block_sizes, dtype, + attention_kernel, + mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, ) - elif attention_kernel == "ring": + elif "ring" in attention_kernel: return _tpu_flash_attention( - query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel + query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, + mask_padding_tokens=mask_padding_tokens, ) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) @@ -484,6 +578,7 @@ def _apply_attention( raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") + def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" num_kv, num_heads, k_features = key.shape[-3:] @@ -607,6 +702,7 @@ def __init__( flash_block_sizes: BlockSizes = None, dtype: DType = jnp.float32, quant: Quant = None, + mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, ): self.dpa_layer = None @@ -627,6 +723,7 @@ def __init__( self.flash_block_sizes = flash_block_sizes self.dtype = dtype self.quant = quant + self.mask_padding_tokens = mask_padding_tokens self.residual_checkpoint_name = residual_checkpoint_name def apply_attention(self, query: Array, key: Array, value: Array): @@ -648,6 +745,7 @@ def apply_attention(self, query: Array, key: Array, value: Array): axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, + mask_padding_tokens=self.mask_padding_tokens, residual_checkpoint_name=self.residual_checkpoint_name, ) @@ -737,6 +835,8 @@ def __init__( precision: jax.lax.Precision = None, qkv_bias: bool = False, quant: Quant = None, + is_self_attention: bool = True, + mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, ): @@ -750,11 +850,18 @@ def __init__( self.inner_dim = dim_head * heads scale = dim_head**-0.5 self.qk_norm = qk_norm - self.enable_jax_named_scopes = enable_jax_named_scopes self.query_axis_names = query_axis_names self.key_axis_names = key_axis_names self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names + self.enable_jax_named_scopes = enable_jax_named_scopes + + if is_self_attention: + axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV) + axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV) + else: + axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) + axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) self.attention_op = NNXAttentionOp( mesh=mesh, @@ -765,10 +872,13 @@ def __init__( use_memory_efficient_attention=use_memory_efficient_attention, split_head_dim=split_head_dim, float32_qk_product=False, + axis_names_q=axis_names_q, + axis_names_kv=axis_names_kv, flash_min_seq_length=flash_min_seq_length, flash_block_sizes=flash_block_sizes, dtype=dtype, quant=quant, + mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, ) # None axes corresponds to the stacked weights across all blocks diff --git 
a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 0226a859..77f35073 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -16,6 +16,7 @@ from typing import Tuple, List, Sequence, Union, Optional +import flax import jax import jax.numpy as jnp from flax import nnx @@ -27,7 +28,7 @@ BlockSizes = common_types.BlockSizes CACHE_T = 2 - +flax.config.update('flax_always_shard_variable', False) # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 128c2203..5d7aec10 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -273,6 +273,7 @@ def __init__( precision: jax.lax.Precision = None, attention: str = "dot_product", dropout: float = 0.0, + mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, ): @@ -295,6 +296,8 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, + is_self_attention=True, + mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, ) @@ -315,6 +318,8 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, + is_self_attention=False, + mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, ) @@ -362,43 +367,50 @@ def __call__( hidden_states = checkpoint_name(hidden_states, "hidden_states") encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) - # 1. Self-attention - with self.conditional_named_scope("self_attn"): - with self.conditional_named_scope("self_attn_norm"): - norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( - hidden_states.dtype - ) - with self.conditional_named_scope("self_attn_attn"): - attn_output = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - rotary_emb=rotary_emb, - deterministic=deterministic, - rngs=rngs, - ) - with self.conditional_named_scope("self_attn_residual"): - hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) - - # 2. Cross-attention - norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) - attn_output = self.attn2( - hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs - ) - hidden_states = hidden_states + attn_output - - # 3. Feed-forward - with self.conditional_named_scope("mlp"): - with self.conditional_named_scope("mlp_norm"): - norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( - hidden_states.dtype - ) - with self.conditional_named_scope("mlp_ffn"): - ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) - with self.conditional_named_scope("mlp_residual"): - hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( - hidden_states.dtype - ) - return hidden_states + # 1. 
Self-attention + with self.conditional_named_scope("self_attn"): + with self.conditional_named_scope("self_attn_norm"): + norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + hidden_states.dtype + ) + with self.conditional_named_scope("self_attn_attn"): + attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + deterministic=deterministic, + rngs=rngs, + ) + with self.conditional_named_scope("self_attn_residual"): + hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) + + # 2. Cross-attention + with self.conditional_named_scope("cross_attn"): + with self.conditional_named_scope("cross_attn_norm"): + norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) + with self.conditional_named_scope("cross_attn_attn"): + attn_output = self.attn2( + hidden_states=norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + deterministic=deterministic, + rngs=rngs, + ) + with self.conditional_named_scope("cross_attn_residual"): + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + with self.conditional_named_scope("mlp"): + with self.conditional_named_scope("mlp_norm"): + norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + hidden_states.dtype + ) + with self.conditional_named_scope("mlp_ffn"): + ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) + with self.conditional_named_scope("mlp_residual"): + hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( + hidden_states.dtype + ) + return hidden_states class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -435,6 +447,7 @@ def __init__( remat_policy: str = "None", names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], + mask_padding_tokens: bool = True, scan_layers: bool = True, enable_jax_named_scopes: bool = False, ): @@ -493,6 +506,8 @@ def init_block(rngs): precision=precision, attention=attention, dropout=dropout, + mask_padding_tokens=mask_padding_tokens, + enable_jax_named_scopes=enable_jax_named_scopes, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -562,14 +577,15 @@ def __call__( post_patch_width = width // p_w hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) - rotary_emb = self.rope(hidden_states) - - hidden_states = self.patch_embedding(hidden_states) - hidden_states = jax.lax.collapse(hidden_states, 1, -1) - - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + with self.conditional_named_scope("rotary_embedding"): + rotary_emb = self.rope(hidden_states) + with self.conditional_named_scope("patch_embedding"): + hidden_states = self.patch_embedding(hidden_states) + hidden_states = jax.lax.collapse(hidden_states, 1, -1) + with self.conditional_named_scope("condition_embedder"): + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 
153c225d..f2c4a41e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -113,6 +113,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded wan_config["flash_min_seq_length"] = config.flash_min_seq_length wan_config["dropout"] = config.dropout + wan_config["mask_padding_tokens"] = config.mask_padding_tokens wan_config["scan_layers"] = config.scan_layers wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes @@ -533,13 +534,14 @@ def _prepare_model_inputs( batch_size = len(prompt) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) + with jax.named_scope("Encode-Prompt"): + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) num_channel_latents = self._get_num_channel_latents() if latents is None: diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 9488f106..060cc1bf 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -27,7 +27,7 @@ from . import max_logging from . import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH -from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2 +from maxdiffusion.common_types import LENGTH, KV_LENGTH, WAN2_1, WAN2_2, RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2} _ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1} @@ -46,7 +46,6 @@ def _validate_training_model_name(model_name: str | None): if model_name not in _ALLOWED_TRAINING_MODEL_NAMES: raise ValueError(f"Invalid config.model_name '{model_name}' for training. Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}") - def string_to_bool(s: str) -> bool: if s.lower() == "true": return True @@ -196,15 +195,29 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. 
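+  # For ring attention or uniform attention sharding, extend the logical axis rules with the per-attention-type
+  # sequence sharding rules defined in common_types (RING_ATTENTION_AXIS_RULES / SEQUENCE_PARALLEL_AXIS_RULES),
+  # skipping any rule that is already present.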
- if raw_keys["attention"] == "ring": + if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]: + max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.") logical_axis_rules = list(raw_keys["logical_axis_rules"]) + max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") + new_rules = [] q_seq_sharding = (LENGTH, "fsdp") kv_seq_sharding = (KV_LENGTH, "fsdp") if q_seq_sharding not in logical_axis_rules: logical_axis_rules.append(q_seq_sharding) if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) - raw_keys["logical_axis_rules"] = tuple(logical_axis_rules) + if "ring" in raw_keys["attention"]: + for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: + if ring_attention_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") + new_rules.append(ring_attention_axis_rule) + else: # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention + for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES: + if seq_parallel_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}") + new_rules.append(seq_parallel_axis_rule) + raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules) + max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}") raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 34f0ef64..5c22c3c8 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -23,7 +23,7 @@ from absl.testing import absltest from flax import nnx from jax.sharding import Mesh - +from flax.linen import partitioning as nn_partitioning from .. 
import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -53,6 +53,18 @@ class WanTransformerTest(unittest.TestCase):

   def setUp(self):
     WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -70,18 +82,20 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      assert dummy_output.shape == (1, 512, 5120)

   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)

   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
@@ -89,9 +103,10 @@
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape

   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -102,20 +117,21 @@
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

   def test_wan_block(self):
     key = jax.random.key(0)
@@ -163,86 +179,83 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape

   def test_wan_attention(self):
-    pyconfig.initialize(
-        [
-            None,
-            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-        ],
-        unittest=True,
-    )
-    config = pyconfig.config
-
-    batch_size = 1
-    channels = 16
-    frames = 21
-    height = 90
-    width = 160
-    hidden_states_shape = (batch_size, frames, height, width, channels)
-    dummy_hidden_states = jnp.ones(hidden_states_shape)
-    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
-
-    key = jax.random.key(0)
-    rngs = nnx.Rngs(key)
-    devices_array = create_device_mesh(config)
-
-    flash_block_sizes = get_flash_block_sizes(config)
-
-    mesh = Mesh(devices_array, config.mesh_axes)
-    batch_size = 1
-    query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-    assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
-      attention = FlaxWanAttention(
-          rngs=rngs,
-          query_dim=query_dim,
-          heads=40,
-          dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
-          mesh=mesh,
-          flash_block_sizes=flash_block_sizes,
+    for attention_kernel in ["flash", "tokamax_flash"]:
+      pyconfig.initialize(
+          [
+              None,
+              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+              f"attention={attention_kernel}"
+          ],
+          unittest=True
       )
-    except NotImplementedError:
-      pass
+      config = pyconfig.config
+      batch_size = 1
+      channels = 16
+      frames = 21
+      height = 90
+      width = 160
+      hidden_states_shape = (batch_size, frames, height, width, channels)
+      dummy_hidden_states = jnp.ones(hidden_states_shape)
+      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+
+      key = jax.random.key(0)
+      rngs = nnx.Rngs(key)
+      devices_array = create_device_mesh(config)
+      mesh = Mesh(devices_array, config.mesh_axes)
+      batch_size = 1
+      query_dim = 5120
+      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+        flash_block_sizes = get_flash_block_sizes(config)
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel=attention_kernel,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+        dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+        dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+        dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+        dummy_output = attention(
+            hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+        )
+        assert dummy_output.shape == dummy_hidden_states_shape
+
+        # dot product
+        try:
+          attention = FlaxWanAttention(
+              rngs=rngs,
+              query_dim=query_dim,
+              heads=40,
+              dim_head=128,
+              attention_kernel="dot_product",
+              split_head_dim=True,
+              mesh=mesh,
+              flash_block_sizes=flash_block_sizes,
+          )
+        except NotImplementedError:
+          pass

   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -272,7 +285,8 @@
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index 2268411c..b2ffbc3b 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -22,6 +22,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh
 from .. import pyconfig
 from ..max_utils import (
@@ -163,6 +164,17 @@ class WanVaeTest(unittest.TestCase):

   def setUp(self):
     WanVaeTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)

   def test_wanrms_norm(self):
     """Test against the Pytorch implementation"""
@@ -212,12 +224,13 @@ def test_zero_padded_conv(self):
     output_torch = resample(input)
     assert output_torch.shape == (1, 96, 240, 360)
-    model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
-    dummy_input = jnp.ones(input_shape)
-    dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
-    output = model(dummy_input)
-    output = jnp.transpose(output, (0, 3, 1, 2))
-    assert output.shape == (1, 96, 240, 360)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
+      dummy_input = jnp.ones(input_shape)
+      dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
+      output = model(dummy_input)
+      output = jnp.transpose(output, (0, 3, 1, 2))
+      assert output.shape == (1, 96, 240, 360)

   def test_wan_upsample(self):
     batch_size = 1
@@ -249,13 +262,13 @@ def test_wan_resample(self):
     torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
     torch_output = torch_wan_resample(dummy_input)
     assert torch_output.shape == (batch, dim, t, h // 2, w // 2)
-
-    wan_resample = WanResample(dim, mode=mode, rngs=rngs)
-    # channels is always last here
-    input_shape = (batch, t, h, w, dim)
-    dummy_input = jnp.ones(input_shape)
-    output = wan_resample(dummy_input)
-    assert output.shape == (batch, t, h // 2, w // 2, dim)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_resample = WanResample(dim, mode=mode, rngs=rngs)
+      # channels is always last here
+      input_shape = (batch, t, h, w, dim)
+      dummy_input = jnp.ones(input_shape)
+      output = wan_resample(dummy_input)
+      assert output.shape == (batch, t, h // 2, w // 2, dim)

   def test_3d_conv(self):
     key = jax.random.key(0)
@@ -286,28 +299,29 @@ def test_3d_conv(self):
     dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels))

     # Instantiate the module
-    causal_conv_layer = WanCausalConv3d(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=(kernel_d, kernel_h, kernel_w),
-        padding=(padding_d, padding_h, padding_w),
-        rngs=rngs,  # Pass rngs for initialization,
-        mesh=mesh,
-    )
+    with self.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      causal_conv_layer = WanCausalConv3d(
+          in_channels=in_channels,
+          out_channels=out_channels,
+          kernel_size=(kernel_d, kernel_h, kernel_w),
+          padding=(padding_d, padding_h, padding_w),
+          rngs=rngs,  # Pass rngs for initialization,
+          mesh=mesh,
+      )

-    # --- Test Case 1: No Cache ---
-    output_no_cache = causal_conv_layer(dummy_input)
-    assert output_no_cache.shape == (1, 10, 32, 32, 16)
+      # --- Test Case 1: No Cache ---
+      output_no_cache = causal_conv_layer(dummy_input)
+      assert output_no_cache.shape == (1, 10, 32, 32, 16)

-    # --- Test Case 2: With Cache ---
-    output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache)
-    assert output_with_cache.shape == (1, 10, 32, 32, 16)
+      # --- Test Case 2: With Cache ---
+      output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache)
+      assert output_with_cache.shape == (1, 10, 32, 32, 16)

-    # --- Test Case 3: With Cache larger than padding ---
-    larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
-    dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
-    output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
-    assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
+      # --- Test Case 3: With Cache larger than padding ---
+      larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
+      dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
+      output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
+      assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)

   def test_wan_residual(self):
     key = jax.random.key(0)
@@ -331,21 +345,20 @@ def test_wan_residual(self):
     dim = 96
     input_shape = (batch, t, height, width, dim)
     expected_output_shape = (batch, t, height, width, dim)
-
-    wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
-    dummy_input = jnp.ones(input_shape)
-    dummy_output = wan_residual_block(dummy_input)
-    assert dummy_output.shape == expected_output_shape
-
-    # --- Test Case 1: different in/out dim ---
-    in_dim = 96
-    out_dim = 196
-    expected_output_shape = (batch, t, height, width, out_dim)
-
-    wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
-    dummy_input = jnp.ones(input_shape)
-    dummy_output = wan_residual_block(dummy_input)
-    assert dummy_output.shape == expected_output_shape
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
+      dummy_input = jnp.ones(input_shape)
+      dummy_output = wan_residual_block(dummy_input)
+      assert dummy_output.shape == expected_output_shape
+      # --- Test Case 1: different in/out dim ---
+      in_dim = 96
+      out_dim = 196
+      expected_output_shape = (batch, t, height, width, out_dim)
+
+      wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
+      dummy_input = jnp.ones(input_shape)
+      dummy_output = wan_residual_block(dummy_input)
+      assert dummy_output.shape == expected_output_shape

   def test_wan_attention(self):
     key = jax.random.key(0)
@@ -356,10 +369,11 @@
     height = 60
     width = 90
     input_shape = (batch, t, height, width, dim)
-    wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
-    dummy_input = jnp.ones(input_shape)
-    output = wan_attention(dummy_input)
-    assert output.shape == input_shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
+      dummy_input = jnp.ones(input_shape)
+      output = wan_attention(dummy_input)
+      assert output.shape == input_shape

   def test_wan_midblock(self):
     key = jax.random.key(0)
@@ -380,10 +394,11 @@
     height = 60
     width = 90
     input_shape = (batch, t, height, width, dim)
-    wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
-    dummy_input = jnp.ones(input_shape)
-    output = wan_midblock(dummy_input)
-    assert output.shape == input_shape
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
+      dummy_input = jnp.ones(input_shape)
+      output = wan_midblock(dummy_input)
+      assert output.shape == input_shape

   def test_wan_decode(self):
     key = jax.random.key(0)
@@ -404,30 +419,31 @@
     num_res_blocks = 2
     attn_scales = []
     temperal_downsample = [False, True, True]
-    wan_vae = AutoencoderKLWan(
-        rngs=rngs,
-        base_dim=dim,
-        z_dim=z_dim,
-        dim_mult=dim_mult,
-        num_res_blocks=num_res_blocks,
-        attn_scales=attn_scales,
-        temperal_downsample=temperal_downsample,
-        mesh=mesh,
-    )
-    vae_cache = AutoencoderKLWanCache(wan_vae)
-    batch = 1
-    t = 13
-    channels = 16
-    height = 60
-    width = 90
-    input_shape = (batch, t, height, width, channels)
-    input = jnp.ones(input_shape)
-
-    latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
-    latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
-    input = input / latents_std + latents_mean
-    dummy_output = wan_vae.decode(input, feat_cache=vae_cache)
-    assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_vae = AutoencoderKLWan(
+          rngs=rngs,
+          base_dim=dim,
+          z_dim=z_dim,
+          dim_mult=dim_mult,
+          num_res_blocks=num_res_blocks,
+          attn_scales=attn_scales,
+          temperal_downsample=temperal_downsample,
+          mesh=mesh,
+      )
+      vae_cache = AutoencoderKLWanCache(wan_vae)
+      batch = 1
+      t = 13
+      channels = 16
+      height = 60
+      width = 90
+      input_shape = (batch, t, height, width, channels)
+      input = jnp.ones(input_shape)
+
+      latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
+      latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
+      input = input / latents_std + latents_mean
+      dummy_output = wan_vae.decode(input, feat_cache=vae_cache)
+      assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)

   def test_wan_encode(self):
     key = jax.random.key(0)
@@ -448,26 +464,27 @@
     num_res_blocks = 2
     attn_scales = []
     temperal_downsample = [False, True, True]
-    wan_vae = AutoencoderKLWan(
-        rngs=rngs,
-        base_dim=dim,
-        z_dim=z_dim,
-        dim_mult=dim_mult,
-        num_res_blocks=num_res_blocks,
-        attn_scales=attn_scales,
-        temperal_downsample=temperal_downsample,
-        mesh=mesh,
-    )
-    vae_cache = AutoencoderKLWanCache(wan_vae)
-    batch = 1
-    channels = 3
-    t = 49
-    height = 480
-    width = 720
-    input_shape = (batch, channels, t, height, width)
-    input = jnp.ones(input_shape)
-    output = wan_vae.encode(input, feat_cache=vae_cache)
-    assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_vae = AutoencoderKLWan(
+          rngs=rngs,
+          base_dim=dim,
+          z_dim=z_dim,
+          dim_mult=dim_mult,
+          num_res_blocks=num_res_blocks,
+          attn_scales=attn_scales,
+          temperal_downsample=temperal_downsample,
+          mesh=mesh,
+      )
+      vae_cache = AutoencoderKLWanCache(wan_vae)
+      batch = 1
+      channels = 3
+      t = 49
+      height = 480
+      width = 720
+      input_shape = (batch, channels, t, height, width)
+      input = jnp.ones(input_shape)
+      output = wan_vae.encode(input, feat_cache=vae_cache)
+      assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)

   def test_load_checkpoint(self):
     def vae_encode(video, wan_vae, vae_cache, key):
@@ -487,9 +504,9 @@ def vae_encode(video, wan_vae, vae_cache, key):
     config = pyconfig.config
     devices_array = create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
-
-    wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
-    vae_cache = AutoencoderKLWanCache(wan_vae)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
+      vae_cache = AutoencoderKLWanCache(wan_vae)

     video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
     video = load_video(video_path)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 4369d6d0..f23836a5 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -17,7 +17,7 @@
 import os
 import datetime
 import functools
-from pprint import pprint
+import pprint
 import numpy as np
 import threading
 from concurrent.futures import ThreadPoolExecutor
diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py
index d7457e56..81818d79 100644
--- a/tests/schedulers/test_scheduler_flax.py
+++ b/tests/schedulers/test_scheduler_flax.py
@@ -335,8 +335,8 @@ def test_full_loop_no_noise(self):
     result_mean = jnp.mean(jnp.abs(sample))

     if jax_device == "tpu":
-      assert abs(result_sum - 257.29) < 1.5e-2
-      assert abs(result_mean - 0.3349905) < 2e-5
+      assert abs(result_sum - 263.11) < 1.5e-2
+      assert abs(result_mean - 0.34259) < 2e-5
     else:
       assert abs(result_sum - 255.1113) < 1e-2
       assert abs(result_mean - 0.332176) < 1e-3