Commit 17c0e79

sywangyi and sayakpaul authored
support CP in native flash attention (#12829)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 1567243 commit 17c0e79
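
Editor's note: this commit lets the PyTorch-native flash attention backend participate in context parallelism by giving it explicit forward/backward ops that `_templated_context_parallel_attention` can drive. The sketch below is not part of the commit; it is a minimal, hypothetical usage example that assumes diffusers' `ContextParallelConfig`, `ModelMixin.set_attention_backend`, and `ModelMixin.enable_parallelism` APIs, and that the registry string for `AttentionBackendName._NATIVE_FLASH` is "_native_flash" (names may differ across versions).

# Hypothetical sketch, e.g. launched with `torchrun --nproc-per-node=2 demo.py`.
# ContextParallelConfig / set_attention_backend / enable_parallelism and the
# "_native_flash" backend string are assumptions, not taken from this diff.
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxTransformer2DModel

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Route attention through torch's native flash-attention SDPA kernel.
transformer.set_attention_backend("_native_flash")
# Shard the sequence dimension across ranks; the forward/backward ops added in this
# diff are what the context-parallel template invokes on each shard.
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=dist.get_world_size()))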

File tree

1 file changed: +125 -15 lines


src/diffusers/models/attention_dispatch.py

Lines changed: 125 additions & 15 deletions
@@ -868,6 +868,97 @@ def _cudnn_attention_backward_op(
     return grad_query, grad_key, grad_value


+# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15135
+# forward declaration:
+# aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+def _native_flash_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for native flash attention.")
+
+    tensors_to_save = ()
+
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    tensors_to_save += (query, key, value)
+
+    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+        torch.ops.aten._scaled_dot_product_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        )
+    )
+
+    tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+    if _save_ctx:
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.max_q = max_q
+        ctx.max_k = max_k
+
+    out = out.transpose(1, 2).contiguous()
+    if lse is not None:
+        lse = lse.transpose(1, 2).contiguous()
+    return (out, lse) if return_lse else out
+
+
+# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15153
+# backward declaration:
+# aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+def _native_flash_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+    grad_out = grad_out.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+
+    grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
+        grad_out,
+        query,
+        key,
+        value,
+        out,
+        logsumexp=lse,
+        philox_seed=philox_seed,
+        philox_offset=philox_offset,
+        cum_seq_q=cum_seq_q,
+        cum_seq_k=cum_seq_k,
+        max_q=ctx.max_q,
+        max_k=ctx.max_k,
+        dropout_p=ctx.dropout_p,
+        is_causal=ctx.is_causal,
+        scale=ctx.scale,
+    )
+    grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
+
+    return grad_query, grad_key, grad_value
+
+
 # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
 def _flash_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
@@ -1931,6 +2022,7 @@ def _native_efficient_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_FLASH,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_flash_attention(
     query: torch.Tensor,
@@ -1943,22 +2035,40 @@ def _native_flash_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if return_lse:
-        raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-        out = torch.nn.functional.scaled_dot_product_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_mask=None,  # not supported
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            scale=scale,
-            enable_gqa=enable_gqa,
+    lse = None
+    if _parallel_config is None and not return_lse:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=None,  # not supported
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
+        out = out.permute(0, 2, 1, 3)
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op=_native_flash_attention_forward_op,
+            backward_op=_native_flash_attention_backward_op,
+            _parallel_config=_parallel_config,
         )
-        out = out.permute(0, 2, 1, 3)
-    return out
+    if return_lse:
+        out, lse = out
+
+    return (out, lse) if return_lse else out


 @_AttentionBackendRegistry.register(
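
Editor's note: the standalone sketch below is not part of the diff; it illustrates the layout convention the two new ops handle. `torch.ops.aten._scaled_dot_product_flash_attention` (declared in the comments above) operates on (batch, heads, seq, head_dim) tensors, while the dispatcher passes (batch, seq, heads, head_dim), hence the `transpose(1, 2)` calls, and the backward op consumes the logsumexp and philox RNG state saved by the forward. Shapes and dtypes are arbitrary; a CUDA device with flash attention support is assumed.

import torch

# Arbitrary example shapes: batch=2, seq=128, heads=8, head_dim=64, fp16 on CUDA.
b, s, h, d = 2, 128, 8, 64
q, k, v = (torch.randn(b, s, h, d, device="cuda", dtype=torch.float16) for _ in range(3))

# Dispatcher layout (batch, seq, heads, dim) -> aten layout (batch, heads, seq, dim).
qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
out, lse, cum_seq_q, cum_seq_k, max_q, max_k, seed, offset, _ = (
    torch.ops.aten._scaled_dot_product_flash_attention(qt, kt, vt, 0.0, False, False)
)

# The backward op reuses the logsumexp and RNG state produced by the forward pass.
grad_out = torch.randn_like(out)
dq, dk, dv = torch.ops.aten._scaled_dot_product_flash_attention_backward(
    grad_out, qt, kt, vt, out, lse, cum_seq_q, cum_seq_k, max_q, max_k, 0.0, False, seed, offset
)
print(out.transpose(1, 2).shape)  # torch.Size([2, 128, 8, 64]): back to dispatcher layout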
