From 6bb59bb18a7f469377cc325db4303055032e748e Mon Sep 17 00:00:00 2001
From: michelle-yooh
Date: Tue, 23 Dec 2025 18:55:15 +0000
Subject: [PATCH] Fix block size handling for cross attention in TPU flash attention

---
 src/maxdiffusion/models/attention_flax.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 555caaf9..9b57e110 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -186,17 +186,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # ensure that for cross attention we override the block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
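
Note on the change above (not part of the patch): the sketch below is a minimal, self-contained restatement of the block-size selection that the hunk implements, useful for checking the cross-attention case. pick_block_sizes, the q_seq_len/kv_seq_len parameters, and the plain dict are illustrative stand-ins of my own; the real code compares key.shape[1] with query.shape[1], clamps against key.shape[2], and constructs a splash_attention_kernel.BlockSizes.

# Illustrative sketch only; names here are hypothetical and do not exist in maxdiffusion.
def pick_block_sizes(q_seq_len, kv_seq_len, q_max_block_size, kv_max_block_size,
                     flash_block_sizes=None):
  """Choose flash-attention block sizes, recomputing KV blocks for cross attention."""
  if flash_block_sizes and kv_seq_len == q_seq_len:
    # Self attention: the user-provided block sizes can be used as-is.
    return dict(flash_block_sizes)
  # Cross attention (or no user override): keep the requested query block size,
  # but clamp the KV block sizes to the (possibly much shorter) KV sequence length.
  block_q = flash_block_sizes["block_q"] if flash_block_sizes else q_max_block_size
  return {
      "block_q": block_q,
      "block_kv": min(kv_max_block_size, kv_seq_len),
      "block_kv_compute": min(kv_max_block_size, kv_seq_len),
  }

# Example: cross attention where the text sequence (77 tokens) is much shorter
# than the image latent sequence (4096 tokens).
sizes = pick_block_sizes(q_seq_len=4096, kv_seq_len=77,
                         q_max_block_size=512, kv_max_block_size=512,
                         flash_block_sizes={"block_q": 256, "block_kv": 256,
                                            "block_kv_compute": 256})
print(sizes)  # {'block_q': 256, 'block_kv': 77, 'block_kv_compute': 77}

The example mirrors a typical diffusion cross-attention shape: the text-encoder sequence is far shorter than the image latent sequence, so a user-supplied KV block size larger than the KV length must be clamped, while the requested query block size is kept.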