@@ -868,6 +868,97 @@ def _cudnn_attention_backward_op(
     return grad_query, grad_key, grad_value


+# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15135
+# forward declaration:
+#   aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+def _native_flash_attention_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for native flash attention.")
+
+    tensors_to_save = ()
+
+    # Inputs arrive as (batch, seq_len, heads, head_dim); the aten op expects (batch, heads, seq_len, head_dim).
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    tensors_to_save += (query, key, value)
+
+    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+        torch.ops.aten._scaled_dot_product_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        )
+    )
+
+    tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+    if _save_ctx:
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.max_q = max_q
+        ctx.max_k = max_k
+
+    # Transpose back to (batch, seq_len, heads, head_dim) before returning.
+    out = out.transpose(1, 2).contiguous()
+    if lse is not None:
+        lse = lse.transpose(1, 2).contiguous()
+    return (out, lse) if return_lse else out
+
+
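+# A minimal sketch (illustration only, not used by this backend): calling the wrapped aten op
+# directly, assuming a CUDA device, bf16 inputs, and the (batch, heads, seq_len, head_dim)
+# layout it expects. The helper name and the shapes below are hypothetical.
+def _example_direct_flash_call():
+    q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
+    k, v = torch.randn_like(q), torch.randn_like(q)
+    out, lse, *_rest = torch.ops.aten._scaled_dot_product_flash_attention(
+        q, k, v, dropout_p=0.0, is_causal=False, return_debug_mask=False, scale=None
+    )
+    # out: (2, 8, 128, 64); lse (logsumexp): (2, 8, 128)
+    return out, lse
+
+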
+# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15153
+# backward declaration:
+#   aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+def _native_flash_attention_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+    grad_out = grad_out.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+
+    grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
+        grad_out,
+        query,
+        key,
+        value,
+        out,
+        logsumexp=lse,
+        philox_seed=philox_seed,
+        philox_offset=philox_offset,
+        cum_seq_q=cum_seq_q,
+        cum_seq_k=cum_seq_k,
+        max_q=ctx.max_q,
+        max_k=ctx.max_k,
+        dropout_p=ctx.dropout_p,
+        is_causal=ctx.is_causal,
+        scale=ctx.scale,
+    )
+    grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
+
+    return grad_query, grad_key, grad_value
+
+
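+# A minimal sketch (illustration only): how a forward_op/backward_op pair such as the two functions
+# above can be wired into a torch.autograd.Function. The real code path goes through
+# _templated_context_parallel_attention instead; the class name here is hypothetical.
+class _ExampleNativeFlashAttentionFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, query, key, value):
+        # query/key/value in (batch, seq_len, heads, head_dim); the forward op populates ctx itself.
+        return _native_flash_attention_forward_op(ctx, query, key, value, return_lse=False)
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        # Returns gradients for query, key and value.
+        return _native_flash_attention_backward_op(ctx, grad_out)
+
+
+# usage (hypothetical): out = _ExampleNativeFlashAttentionFn.apply(q, k, v)
+
+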
 # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
 def _flash_attention_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
@@ -1931,6 +2022,7 @@ def _native_efficient_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_FLASH,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=True,
 )
 def _native_flash_attention(
     query: torch.Tensor,
@@ -1943,22 +2035,40 @@ def _native_flash_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if return_lse:
-        raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-        out = torch.nn.functional.scaled_dot_product_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_mask=None,  # not supported
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            scale=scale,
-            enable_gqa=enable_gqa,
+    lse = None
+    if _parallel_config is None and not return_lse:
+        # Fast path: no context parallelism and no LSE requested, so call SDPA's flash kernel directly.
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=None,  # not supported
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
+        out = out.permute(0, 2, 1, 3)
+    else:
+        # Context-parallel and/or LSE-returning path, built on the forward/backward ops defined above.
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op=_native_flash_attention_forward_op,
+            backward_op=_native_flash_attention_backward_op,
+            _parallel_config=_parallel_config,
         )
-    out = out.permute(0, 2, 1, 3)
-    return out
+    if return_lse:
+        out, lse = out
+
+    return (out, lse) if return_lse else out

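+# A minimal usage sketch (illustration only, not part of this change): with _parallel_config=None
+# and return_lse=False the SDPA fast path above is taken; requesting the log-sum-exp (or running
+# under context parallelism) routes through the templated path instead. The helper name, device,
+# dtype, and shapes below are hypothetical.
+def _example_native_flash_backend_call():
+    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
+    k, v = torch.randn_like(q), torch.randn_like(q)
+    out = _native_flash_attention(q, k, v)                         # SDPA path; out: (2, 128, 8, 64)
+    out, lse = _native_flash_attention(q, k, v, return_lse=True)   # templated path; lse: (2, 128, 8)
+    return out, lse
+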
19632073
 @_AttentionBackendRegistry.register(