Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 75 additions & 12 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ def _segment_ids_pos_to_seqlens_offsets(
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation

# Currently, this function is only exercised for THD qkv_layout.

# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
Expand Down Expand Up @@ -693,26 +695,87 @@ def get_seqlens_and_offsets(
self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
):
"""
Acquire the seqlens/offsets for cuDNN backend
Acquire the seqlens/offsets for cuDNN backend.
"""
q_segment_ids, kv_segment_ids = self.segment_ids
q_segment_pos, kv_segment_pos = self.segment_pos
assert q_segment_ids.shape == q_segment_pos.shape
assert kv_segment_ids.shape == kv_segment_pos.shape
# No segment_ids/segment_pos
if q_segment_ids.size + kv_segment_ids.size == 0:
return self.seqlens, self.seq_offsets

if qkv_layout.is_thd():
q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets(
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
attn_mask_type,
window_size,
max_segments_per_seq,
# Allow segment_pos to have fewer leading dims than segment_ids, e.g. when segment_ids is
# vmapped but segment_pos is not. For instance, when using from_segment_ids_and_pos() to
# generate segment_pos from segment_ids, it is acceptable to have something like:
# segment_ids (B, batch, seq), segment_pos (batch, seq).
if q_segment_ids.ndim < q_segment_pos.ndim or kv_segment_ids.ndim < kv_segment_pos.ndim:
raise AssertionError(
"segment_ids must not have fewer dims than segment_pos; got"
f" q_segment_ids.ndim={q_segment_ids.ndim},"
f" q_segment_pos.ndim={q_segment_pos.ndim},"
f" kv_segment_ids.ndim={kv_segment_ids.ndim},"
f" kv_segment_pos.ndim={kv_segment_pos.ndim}"
)
if not (
q_segment_ids.shape[-q_segment_pos.ndim :] == q_segment_pos.shape
and kv_segment_ids.shape[-kv_segment_pos.ndim :] == kv_segment_pos.shape
):
raise AssertionError(
"segment_pos trailing shape must match segment_ids; got"
f" q_segment_ids.shape={q_segment_ids.shape},"
f" q_segment_pos.shape={q_segment_pos.shape},"
f" kv_segment_ids.shape={kv_segment_ids.shape},"
f" kv_segment_pos.shape={kv_segment_pos.shape}"
)
# THD: compute seqlens/offsets.
if qkv_layout.is_thd():
# If there are more leading dims on segment_ids, e.g. vmap
if q_segment_ids.ndim > q_segment_pos.ndim or kv_segment_ids.ndim > kv_segment_pos.ndim:
# Flatten leading batch dims so that segment_ids and segment_pos have the same number of leading dims,
# vmap seqlens/offsets computation with segment_pos broadcast,
# reshape back to the original leading batch dims.
n_extra_batch_dims_q = q_segment_ids.ndim - q_segment_pos.ndim
n_extra_batch_dims_kv = kv_segment_ids.ndim - kv_segment_pos.ndim
extra_batch_shape_q = q_segment_ids.shape[:n_extra_batch_dims_q]
extra_batch_shape_kv = kv_segment_ids.shape[:n_extra_batch_dims_kv]
extra_flat_batch_size_q = jnp.prod(extra_batch_shape_q)
extra_flat_batch_size_kv = jnp.prod(extra_batch_shape_kv)
# vmap below requires same batch size on axis 0 for q_flat and kv_flat; JAX will raise if they differ.
q_flat = q_segment_ids.reshape(
extra_flat_batch_size_q, *q_segment_ids.shape[n_extra_batch_dims_q:]
)
kv_flat = kv_segment_ids.reshape(
extra_flat_batch_size_kv, *kv_segment_ids.shape[n_extra_batch_dims_kv:]
)

def single_extra_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
    # Compute seqlens/offsets for a single (un-vmapped) extra batch entry.
    # Closes over attn_mask_type, window_size and max_segments_per_seq from the
    # enclosing scope; vmapped below over the flattened extra batch dims with
    # segment_pos broadcast (in_axes=None) across those dims.
    return _segment_ids_pos_to_seqlens_offsets(
        seg_id_q,
        seg_id_kv,
        seg_pos_q,
        seg_pos_kv,
        attn_mask_type,
        window_size,
        max_segments_per_seq,
    )

q_sl, kv_sl, q_off, kv_off = jax.vmap(
single_extra_batch, in_axes=(0, 0, None, None)
)(q_flat, kv_flat, q_segment_pos, kv_segment_pos)

q_seqlens = q_sl.reshape(*extra_batch_shape_q, *q_sl.shape[1:])
kv_seqlens = kv_sl.reshape(*extra_batch_shape_kv, *kv_sl.shape[1:])
q_offsets = q_off.reshape(*extra_batch_shape_q, *q_off.shape[1:])
kv_offsets = kv_off.reshape(*extra_batch_shape_kv, *kv_off.shape[1:])
else:
q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets(
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
attn_mask_type,
window_size,
max_segments_per_seq,
)
# BSHD: compute seqlens/offsets.
else:
q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
q_segment_ids,
Expand Down
6 changes: 4 additions & 2 deletions transformer_engine/jax/cpp_extensions/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,10 +626,12 @@ def convert_to_2d(offsets, batch, max_seqlen):

@staticmethod
def batcher(batched_args, batch_dims, *, config):
# batch_dims: each element is the batch axis (0, ...) or None. Only 0 or None allowed.
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims

# Pass through unchanged; segment_ids/segment_pos may carry different batch dims (e.g. vmapped
# ids with replicated pos). get_seqlens_and_offsets() in attention.py performs the conversion,
# so there is no need to broadcast/expand the replicated operands here.
out_bdims = q_bdim, q_bdim, seed_bdim
return (
FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
Expand Down Expand Up @@ -1084,7 +1086,7 @@ def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims

# Pass through; segment_ids/segment_pos may have different batch dims. Conversion is in attention.py.
out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
return (
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
Expand Down
Loading