[JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims #2692
KshitijLakhani wants to merge 5 commits into NVIDIA:main
… the TE constructed segment pos are not, thereby causing mismatches in impl()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
/te-ci jax L0 L1 L2
Greptile Summary
Fixed batch dimension mismatch between …
Confidence Score: 4/5
Last reviewed commit: 0967b86
jberchtold-nvidia left a comment
LGTM, thanks!
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
da19f26 to 35d6d0f
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
    for _ in range(leading_bdim):
        expanded = lax.expand_dims(expanded, (0,))
    batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
    updated_batch_dims[seg_pos_idx] = 0

Consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0.

Suggested change:

    - updated_batch_dims[seg_pos_idx] = 0
    + updated_batch_dims[seg_pos_idx] = seg_id_bdim
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
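For readers following the review thread, the snippet under discussion can be sketched as a standalone helper. This is only an illustration under assumptions: `align_seg_pos_to_seg_ids` and its parameters are hypothetical names, not the TE batcher itself, and the sketch returns `seg_ids_bdim` for the new batch dim, as the suggestion above proposes.

```python
import jax.numpy as jnp
from jax import lax


def align_seg_pos_to_seg_ids(seg_ids, seg_ids_bdim, seg_pos, seg_pos_bdim):
    """Hypothetical helper: give seg_pos the leading batch dims of seg_ids.

    A bdim of None means "not batched", mirroring how a batcher sees inputs.
    """
    # Nothing to do if seg_ids is unbatched or seg_pos is already batched.
    if seg_ids_bdim is None or seg_pos_bdim is not None:
        return seg_pos, seg_pos_bdim

    # Number of leading dims that seg_ids has and seg_pos lacks.
    leading_bdim = seg_ids.ndim - seg_pos.ndim
    expanded = seg_pos
    for _ in range(leading_bdim):
        expanded = lax.expand_dims(expanded, (0,))

    # Materialize the missing batch dims so both arrays share one shape.
    return jnp.broadcast_to(expanded, seg_ids.shape), seg_ids_bdim


# Shapes taken from the PR description: ids are batched, pos are not.
seg_ids = jnp.ones((1, 2, 128), dtype=jnp.int32)
seg_pos = jnp.broadcast_to(jnp.arange(128, dtype=jnp.int32), (2, 128))
new_pos, new_bdim = align_seg_pos_to_seg_ids(seg_ids, 0, seg_pos, None)
```

Returning the ids' batch dim rather than a literal keeps the two inputs tied together if the valid batch dims ever change.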
Description
What is the bug?
TE provides a convenience function, `from_segment_ids_and_pos()`, which lets users pass only segment ids and returns a `SequenceDescriptor` holding the passed segment ids together with internally generated segment pos.

As mentioned in Issue #2685, if a user vmaps a function `forward()` that i) accepts q, k, v and segment ids, ii) calls `from_segment_ids_and_pos()`, and iii) then calls `DPA()`, JAX treats the segment ids as vmapped and adds an extra leading dimension (e.g. (1, 2, 128)), whereas the segment pos get no leading dimension (e.g. (2, 128)). The resulting shape mismatch between seg ids and seg pos triggers the assert in the FusedAttn primitive's impl(), as described in Issue #2685.

What is the root cause of the bug?
Debugging shows that the shapes start to differ when the batcher is traced for the FusedAttn primitive.
- `segment_ids` in the primitive: treated as vmapped inputs, hence batched → (1, 2, 128).
- `segment_pos` in the primitive: treated as derived within the function, hence not batched → (2, 128).

This PR resolves this by ensuring that `segment_pos` has the same leading batching dims as `segment_ids`, so the end user can wrap the TE API calls in vmap without worrying about batching in TE.

Fixes #2685
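The root cause can be reproduced in miniature without TE. The sketch below is an assumption-laden toy, not the FusedAttn primitive: `toy_attn`, `toy_attn_vmap_rule`, and `forward` are hypothetical names, and `jax.custom_batching.custom_vmap` stands in for the primitive's batcher. It shows that a value built inside the vmapped function reaches the batching rule unbatched, and that broadcasting it restores matching shapes.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


# Hypothetical stand-in for the fused-attention primitive; not TE code.
@custom_vmap
def toy_attn(segment_ids, segment_pos):
    # Like impl(), require matching shapes (checked at trace time).
    assert segment_ids.shape == segment_pos.shape
    return segment_ids * 1000 + segment_pos


observed_in_batched = []  # records which inputs vmap considered batched


@toy_attn.def_vmap
def toy_attn_vmap_rule(axis_size, in_batched, segment_ids, segment_pos):
    ids_batched, pos_batched = in_batched
    observed_in_batched.append((ids_batched, pos_batched))
    # Batched inputs carry the batch axis at position 0; unbatched ones lack it.
    if ids_batched and not pos_batched:
        segment_pos = jnp.broadcast_to(segment_pos[None], segment_ids.shape)
    elif pos_batched and not ids_batched:
        segment_ids = jnp.broadcast_to(segment_ids[None], segment_pos.shape)
    assert segment_ids.shape == segment_pos.shape
    out = segment_ids * 1000 + segment_pos
    return out, ids_batched or pos_batched


def forward(segment_ids):
    # segment_pos depends only on static shapes, so vmap does not batch it.
    seq_len = segment_ids.shape[-1]
    segment_pos = jnp.broadcast_to(
        jnp.arange(seq_len, dtype=segment_ids.dtype), segment_ids.shape
    )
    return toy_attn(segment_ids, segment_pos)


segment_ids = jnp.ones((1, 2, 128), dtype=jnp.int32)
out = jax.vmap(forward)(segment_ids)
```

Inside the rule, `segment_ids` arrives with shape (1, 2, 128) while `segment_pos` arrives with shape (2, 128), which is the same asymmetry the description reports for the FusedAttn batcher.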
Type of change
Changes
Ensure that the leading batch dims of segment pos match those of segment ids in the fused attn primitive's batcher.
Checklist: