2 changes: 1 addition & 1 deletion .gitignore
@@ -4,7 +4,6 @@
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

@@ -98,6 +97,7 @@ celerybeat-schedule

# Environments
.env
.history
.venv
env/
venv/
30 changes: 30 additions & 0 deletions docs/attention_blocks_flowchart.md
@@ -0,0 +1,30 @@
# Attention block sizes

## Description
- "block_q": Block sizes (HBM TO VMEM and VREG) to tile along Q sequence in forward pass
- "block_kv_compute" : Sub Block size (VMEM to VREG) of "block_kv" where compute is performed in forward pass. It must be factor or same as "block_kv"
- "block_kv" : Block sizes (HBM TO VMEM) to tile along KV sequence in forward pass
- "block_q_dkv" : Block sizes along Q sequence in backward pass with fused kernel to compute gradient of q, k , v. It must be factor or same as block_q
- "block_kv_dkv" : Block sizes along KV sequence in backward pass. It must be factor or same as block_kv
- "block_kv_dkv_compute" : Sub Block Sizes of block_kv_dkv, must be factor or same as "block_kv_dkv"
- "block_q_dq" : Block sizes along Q sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_q"
- "block_kv_dq" : Block sizes along KV to tiline on KV sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_kv"
- "use_fused_bwd_kernel" : This means fused bwd kernel is used where DQ, DK, DV are computed in single kernel. It usually more perfomant but comes with slight HBM memory overhead.

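The divisibility rules above can be sanity-checked in a few lines. Below is a minimal illustrative sketch in Python; the function name, the plain-dict layout, and the example values are assumptions for illustration, not maxdiffusion's actual validation code.

```python
# Illustrative sketch only: a standalone check of the block-size rules listed
# above. The helper name and dict layout are assumptions, not maxdiffusion API.

def check_flash_block_sizes(b: dict) -> None:
  """Raises if the block sizes violate the constraints described above."""
  assert b["block_kv"] % b["block_kv_compute"] == 0, "block_kv_compute must divide block_kv"
  assert b["block_q"] % b["block_q_dkv"] == 0, "block_q_dkv must divide block_q"
  assert b["block_kv"] % b["block_kv_dkv"] == 0, "block_kv_dkv must divide block_kv"
  assert b["block_kv_dkv"] % b["block_kv_dkv_compute"] == 0, "block_kv_dkv_compute must divide block_kv_dkv"
  if not b.get("use_fused_bwd_kernel", False):
    # The unfused backward kernel additionally needs its own q-gradient block sizes.
    assert b["block_q"] % b["block_q_dq"] == 0, "block_q_dq must divide block_q"
    assert b["block_kv"] % b["block_kv_dq"] == 0, "block_kv_dq must divide block_kv"
  # KV block sizes must be multiples of 128 (see the note at the end of this page).
  for key in ("block_kv", "block_kv_compute", "block_kv_dkv", "block_kv_dkv_compute"):
    assert b[key] % 128 == 0, f"{key} must be a multiple of 128"

# Example with made-up values that satisfy every rule.
check_flash_block_sizes({
    "block_q": 2048,
    "block_kv_compute": 512,
    "block_kv": 2048,
    "block_q_dkv": 1024,
    "block_kv_dkv": 1024,
    "block_kv_dkv_compute": 256,
    "use_fused_bwd_kernel": True,
})
```
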
## Flowchart

Maxdiffusion automatically adjusts the specified block sizes according to this flowchart to ensure a working configuration, and a log message will inform you of any modifications maxdiffusion makes to the block sizes you specified.

![Attention block sizes flowchart](attention_blocks_flowchart.png)

> "tokamax_flash" uses the splash attention implementation in [tokamax-repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py) This kernel only supports fused backward pass where gradients for q,k,v are computed in a single kernel so "block_q_dq" and "block_kv_dq" are not used

## How block sizes matter for performance and accuracy

Block sizes are key to saturating HBM bandwidth and to maximizing the overlap of computation on the cores with HBM and VMEM-to-VREG data movement. It is highly recommended to tune them.

Block sizes also interact with the sequence length. The sequence length is determined by the resolution and the number of frames (for video), together with the VAE downscaling factors and patchifying ratios. This sequence length, or each shard of it, needs to be a multiple of the specified block sizes, so maxdiffusion pads the sequence length up to the nearest multiple of the block sizes. It is advisable to choose block sizes that are factors of the sequence length, at least for the Q block sizes.
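
In practice the padding rule is just rounding up to the next block multiple. A minimal sketch; the helper name and the example numbers are illustrative and not taken from maxdiffusion:

```python
# Illustrative sketch of the padding rule described above; not maxdiffusion's API.
def pad_to_block_multiple(seq_len: int, block_size: int) -> int:
  """Rounds seq_len up to the nearest multiple of block_size."""
  return ((seq_len + block_size - 1) // block_size) * block_size

# Example with made-up numbers: a 5000-token sequence tiled with block_q = 1024
# is padded to 5120, i.e. 120 padding tokens participate in the attention tiling.
print(pad_to_block_multiple(5000, 1024))  # 5120
```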

> In cross attention, image or video tokens attend to text tokens. The text-token sequence length is very small and potentially smaller than the specified block size, so the KV block sizes are overwritten with safe values.

> KV block sizes must be a multiple of 128 since the register size is 8x128 and, in attention, the KV sequence dimension lies along the 128-lane dimension for the matrix multiplications because K is transposed.
Binary file added docs/attention_blocks_flowchart.png
93 changes: 0 additions & 93 deletions preview-xpk.sh

This file was deleted.

1 change: 1 addition & 0 deletions requirements.txt
@@ -13,6 +13,7 @@ ftfy
tensorboard>=2.17.0
tensorboardx>=2.6.2.2
tensorboard-plugin-profile>=2.15.2
tokamax
Jinja2
scikit-image
parameterized
36 changes: 35 additions & 1 deletion src/maxdiffusion/common_types.py
@@ -33,7 +33,11 @@
BlockSizes = splash_attention_kernel.BlockSizes

AxisNames = tuple[str, ...]

# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
TENSOR = "tensor"
# Logical axis names for model parameters and activations.
BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"
@@ -48,3 +52,33 @@
WAN2_2 = "wan2.2"

WAN_MODEL = WAN2_1

# For setting self/cross attention sharding independently in the splash kernel
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"


WAN_MODEL = "Wan2.1"

### Common axis rules for ring attention ###
RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, FSDP],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, FSDP],
]

SEQUENCE_PARALLEL_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, None],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, None],
]
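
As a reading aid for these rule tables: each entry pairs a logical axis name with the mesh axis it is sharded over, and None leaves that dimension unsharded. The following is a minimal, hypothetical lookup illustrating how such a table is read; it is not maxdiffusion's actual resolution logic.

```python
# Hypothetical helper illustrating how a [logical_axis, mesh_axis] rule table
# such as RING_ATTENTION_AXIS_RULES is read; not maxdiffusion's actual code.
def mesh_axis_for(logical_name, rules):
  """Returns the mesh axis a logical axis is sharded over, or None if unsharded."""
  for logical, physical in rules:
    if logical == logical_name:
      return physical
  return None  # No rule found: treat the dimension as unsharded.

ring_rules = [
    ["activation_self_attn_q_length", "fsdp"],
    ["activation_self_attn_heads", None],
]
print(mesh_axis_for("activation_self_attn_q_length", ring_rules))  # fsdp
```
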
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base14.yml
@@ -50,6 +50,15 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/base21.yml
@@ -49,6 +49,16 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True

flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -50,6 +50,16 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True

flash_block_sizes: {}
# to override default block sizes for flash attention
# flash_block_sizes:
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -63,6 +63,15 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True

flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -63,6 +63,15 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True

#flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -62,6 +62,15 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True
flash_block_sizes: {
"block_q" : 256,
"block_kv_compute" : 256,
32 changes: 27 additions & 5 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -60,7 +60,17 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True
dropout: 0.1

flash_block_sizes: {
@@ -70,7 +80,7 @@ flash_block_sizes: {
"block_q_dkv" : 2048,
"block_kv_dkv" : 2048,
"block_kv_dkv_compute" : 512,
"use_fused_bwd_kernel" : True
"use_fused_bwd_kernel": True
}
# Use on v6e
# flash_block_sizes: {
@@ -79,11 +89,22 @@ flash_block_sizes: {
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_kv_dkv_compute" : 1024,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048,
# "use_fused_bwd_kernel": False,
# }
# Use on v5p
# flash_block_sizes: {
# "block_q" : 3024,
# "block_kv_compute" : 1024,
# "block_kv" : 2048,
# "block_q_dkv" : 1024,
# "block_kv_dkv" : 3072,
# "block_kv_dkv_compute" : 256,
# "block_q_dq" : 1024,
# "block_kv_dq" : 3072
# }
# GroupNorm groups
norm_num_groups: 32

@@ -144,8 +165,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],

['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
@@ -279,7 +301,7 @@ flow_shift: 3.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 24
fps: 16
save_final_checkpoint: False

# SDXL Lightning parameters
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -61,6 +61,15 @@ from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True
dropout: 0.1

flash_block_sizes: {
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
@@ -50,6 +50,15 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, we pass segment ids to splash attention to avoid attending to padding tokens.
# Otherwise we do not pass segment ids, which is faster on VPU-bound hardware like Trillium.
# However, when padding tokens are significant, skipping the mask hurts quality, so this should be set to True.
mask_padding_tokens: True
# Maxdiffusion has two attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded uniformly across devices for self attention, while the sequence
#    is sharded for q in cross attention.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32