[JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims#2692

Open
KshitijLakhani wants to merge 5 commits into NVIDIA:main from KshitijLakhani:klakhani/fix/vmap-get-seg-ids-pos

Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Feb 19, 2026

Description

What is the bug ?

TE provides a convenience function from_segment_ids_and_pos(), which lets users pass only segment ids; it returns a SequenceDescriptor containing the passed segment ids and internally generated segment positions.

As mentioned in issue #2685, if a user vmaps a function forward() that (i) accepts q, k, v, and segment ids, (ii) calls from_segment_ids_and_pos(), and (iii) calls DPA(), JAX treats the segment ids as vmapped and adds an extra leading dimension (e.g. (1, 2, 128)), whereas the internally generated segment positions get no leading dimension (e.g. (2, 128)). This shape mismatch between segment ids and segment pos triggers the assert in the FusedAttn primitive's impl(), as described in issue #2685.
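The user-level flow can be illustrated with a minimal, TE-free sketch. Here forward() and the stand-in position derivation are illustrative only; they mimic what from_segment_ids_and_pos() does conceptually, not the actual TE code:

```python
import jax
import jax.numpy as jnp

def forward(seg_ids):
    # Hypothetical stand-in for from_segment_ids_and_pos(): derive
    # per-token positions from the per-example segment-id shape.
    seg_pos = jnp.broadcast_to(jnp.arange(seg_ids.shape[-1]), seg_ids.shape)
    return seg_ids, seg_pos

ids, pos = jax.vmap(forward)(jnp.zeros((1, 2, 128), dtype=jnp.int32))
# At the user level both outputs come back batched as (1, 2, 128), but
# inside the trace seg_pos was built per-example as (2, 128): a primitive's
# batch rule sees (1, 2, 128) for seg_ids vs (2, 128) for seg_pos.
assert ids.shape == pos.shape == (1, 2, 128)
```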

What is the root cause for the bug ?

Debugging shows that the shapes start to differ when the batcher for the FusedAttn primitive is traced:
  • segment_ids: treated as a vmapped input and hence batched → (1, 2, 128).
  • segment_pos: treated as derived within the function and hence not batched → (2, 128).
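To make the batcher-level mismatch concrete, here is a self-contained sketch (not TE code) using jax.custom_batching.custom_vmap, whose batching rule sees exactly this situation: the vmapped operand arrives with a leading batch dim while the non-vmapped one does not. The fix inside the rule mirrors this PR's approach of expanding and broadcasting the unbatched operand; all names and shapes are illustrative:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.custom_batching.custom_vmap
def add_ids_pos(seg_ids, seg_pos):
    # Stand-in for a primitive that expects matching shapes.
    return seg_ids + seg_pos

@add_ids_pos.def_vmap
def add_ids_pos_batcher(axis_size, in_batched, seg_ids, seg_pos):
    ids_batched, pos_batched = in_batched
    # The bug's situation: seg_ids is batched, e.g. (4, 2, 8),
    # while seg_pos is not, e.g. (2, 8).
    if ids_batched and not pos_batched:
        for _ in range(seg_ids.ndim - seg_pos.ndim):
            seg_pos = lax.expand_dims(seg_pos, (0,))
        seg_pos = jnp.broadcast_to(seg_pos, seg_ids.shape)
    return seg_ids + seg_pos, True

seg_ids = jnp.zeros((4, 2, 8), dtype=jnp.int32)  # vmapped over axis 0
seg_pos = jnp.ones((2, 8), dtype=jnp.int32)      # not vmapped
out = jax.vmap(add_ids_pos, in_axes=(0, None))(seg_ids, seg_pos)
assert out.shape == (4, 2, 8)
```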

This PR resolves the issue by ensuring that segment_pos has the same leading batch dims as segment_ids, so end users can vmap-wrap the TE API calls without worrying about batching inside TE.

Fixes #2685

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Ensure that the leading batch dims of segment pos match those of segment ids in the fused attn primitive's batcher.
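The change can be sketched as a small helper (the name and signature are illustrative, not the PR's actual code): prepend the missing leading batch dims to segment pos, then broadcast to the segment ids' full shape.

```python
import jax.numpy as jnp
from jax import lax

def expand_pos_to_match_ids(seg_pos, seg_ids_shape):
    # Prepend the missing leading batch dims, then broadcast to the
    # full (batched) segment-ids shape.
    for _ in range(len(seg_ids_shape) - seg_pos.ndim):
        seg_pos = lax.expand_dims(seg_pos, (0,))
    return jnp.broadcast_to(seg_pos, seg_ids_shape)

pos = expand_pos_to_match_ids(jnp.zeros((2, 128), dtype=jnp.int32), (1, 2, 128))
assert pos.shape == (1, 2, 128)
```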

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

… the TE constructed segment pos are not thereby causing mismatches in impl()

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this Feb 19, 2026
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani KshitijLakhani marked this pull request as ready for review February 20, 2026 06:54
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 20, 2026

Greptile Summary

Fixed batch dimension mismatch between segment_ids and segment_pos in FusedAttn primitive's batcher when using vmap. When from_segment_ids_and_pos() is called inside a vmapped function, JAX adds a leading batch dimension to user-provided segment_ids but not to internally-generated segment_pos, causing shape mismatches. The fix expands segment_pos to match segment_ids batch dimensions by checking q and kv pairs independently, allowing users to vmap TE attention APIs without manual batch handling.

Confidence Score: 4/5

  • This PR is safe to merge with minor style improvements suggested
  • The fix correctly addresses the root cause of issue #2685 (JAX vmap issue with TE Attention) by independently checking and expanding batch dimensions for the q and kv segment pairs. The logic properly handles the case where segment_ids has batch dims but segment_pos does not. The implementation includes appropriate guards (empty-array checks, dimension validation, shape assertions) to prevent errors. Only minor style improvements are suggested for consistency.
  • No files require special attention

Important Files Changed

transformer_engine/jax/cpp_extensions/attention.py — Added batch dimension matching logic for segment_pos to align with segment_ids when vmapped; handles q and kv pairs independently

Last reviewed commit: 0967b86

Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM, thanks!

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


    for _ in range(leading_bdim):
        expanded = lax.expand_dims(expanded, (0,))
    batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
    updated_batch_dims[seg_pos_idx] = 0
Contributor


consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0

Suggested change:
-    updated_batch_dims[seg_pos_idx] = 0
+    updated_batch_dims[seg_pos_idx] = seg_id_bdim

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!



Development

Successfully merging this pull request may close these issues.

JAX vmap issue with TE Attention
