
[None][Perf] Multi-stream attention, fuse rmsnorm add, fuse swiglu #11362

Closed
suyoggupta wants to merge 18 commits into NVIDIA:main from nv-auto-deploy:sg/glm-opt

Conversation

Collaborator

@suyoggupta suyoggupta commented Feb 8, 2026

Summary by CodeRabbit

  • New Features

    • Added GLM-4.7-Flash model deployment with comprehensive setup guide and configuration.
    • Introduced DeepSeekV3 and GLM4 MoE Lite model implementations for inference.
    • Added SwiGLU MLP fusion optimization for faster model execution.
    • Implemented Multi-Head Latent Attention (MLA) backend with FlashInfer support.
    • Added interleaved RoPE optimization with Triton kernel support.
    • Extended NVFP4 quantization support across models.
  • Refactor

    • Restructured attention operation implementations for better modularity.
    • Enhanced multi-stream execution for parallel computation.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline, ensuring that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
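
For example, a typical invocation combining the options above might look like this (the stage name is the illustrative value from the --stage-list description):

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"

A plain /bot run with no options reuses build artifacts and successful test results from the last pipeline, per the default behavior noted under --reuse-test.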

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

lucaslie and others added 15 commits February 6, 2026 06:47
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
@suyoggupta suyoggupta requested review from a team as code owners February 8, 2026 00:27
Contributor

coderabbitai bot commented Feb 8, 2026

📝 Walkthrough

This pull request introduces comprehensive support for GLM-4.7-Flash deployment, replaces the MLA attention backend with FlashInfer-based implementations, adds SwiGLU MLP custom operations with fusion transforms, implements new custom model variants (DeepSeekV3, GLM4 MoE Lite), enhances RoPE with interleaved Q/K support, and adds multi-stream processing optimizations. The changes span deployment configurations, custom operators, graph transformations, and extensive test coverage.

Changes

Cohort / File(s) | Summary
GLM-4.7-Flash Deployment
examples/auto_deploy/cookbooks/glm_4.7_flash_trtllm_cookbook.ipynb, examples/auto_deploy/glm_flash.yaml, examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml, examples/auto_deploy/model_registry/models.yaml
New deployment cookbook and model registry entries for GLM-4.7-Flash with TensorRT-LLM, including container setup, environment verification, server launch, and API usage examples with chunked prefill and multi-stream MoE enabled.
SwiGLU Custom Operations
tensorrt_llm/_torch/auto_deploy/custom_ops/linear/swiglu.py, tensorrt_llm/_torch/auto_deploy/custom_ops/linear/__init__.py
New SwiGLU MLP custom ops supporting standard, fused, and NVFP4-quantized variants with fake implementations for tracing.
SwiGLU Fusion Transform
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py
New two-stage SwiGLU fusion: pattern matching (MatchSwiGLUPattern) and weight concatenation fusion (FuseSwiGLU) for both standard and NVFP4 paths with dead-code cleanup.
MLA Backend Replacement
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
Complete replacement of MLA backend: FlashInfer-based paged KV cache implementation with planning/metadata hooks, Torch reference backend with weight absorption logic, and source attention operation with support for layout variations and causal masking.
MLA Legacy Removal
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/mla.py
Removal of legacy fused MLA implementation and MultiHeadLatentAttention registry entry in favor of new backends.
Attention KV Cache Refactoring
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py, tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
Removal of public KV cache update and MLA variants from torch_attention.py; migration to private _update_kv_cache in torch_backend_attention.py with extended signature.
DeepSeekV3 Custom Model
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek.py, tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
New export-friendly DeepSeekV3 implementation with RMS normalization, rotary embeddings (including YaRN), SwiGLU-based MLP, MoE support, and multi-head latent attention using custom ops; integrated with AutoModelForCausalLMFactory.
GLM4 MoE Lite Custom Model
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_glm4_moe_lite.py, tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
New GLM4 MoE Lite model with configuration registration, RMS normalization, rotary embeddings with YaRN support, MLA-based attention, fused MoE, and causal language modeling head for export deployments.
DeepSeek Patches Removal
tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py
Complete removal of deprecated monkey-patching system for DeepSeek attention, MoE, and RoPE in favor of custom model implementations.
RoPE Enhancements
tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope.py, tensorrt_llm/_torch/auto_deploy/custom_ops/rope/triton_rope_kernel.py, tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
New interleaved Q/K RoPE path: custom op apply_rope_on_interleaved_qk_inputs, Triton kernel rope_fwd_interleaved_kernel, and optimizer integration with input validation and cache tracing.
Multi-stream Processing
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py, tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py
New MultiStreamMLAAttn transform for auxiliary-stream KV projection parallelization; refactored MultiStreamMoE with lazy auxiliary-op derivation and generic record_event_passthrough synchronization.
Quantization & Fusion Updates
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py, tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py, tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
MoE fusion with per-expert NVFP4 scale handling; FusedAddRmsNorm refactored to direct FX graph manipulation for robustness in multi-user scenarios; quantization config expanded to exclude MLP gates.
Config & Infrastructure
tensorrt_llm/_torch/auto_deploy/config/default.yaml, tensorrt_llm/_torch/auto_deploy/utils/_graph.py, tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
Default transforms updated with SwiGLU/NVFP4 pattern matching and fusion, MLA backend switches to flashinfer_mla; new create_derived_custom_op utility for dynamic op derivation; operator documentation removed.
MLA Operation Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py
Comprehensive test suites for FlashInfer and Torch MLA backends covering context/decode/chunked-prefill paths, variable sequence lengths, CUDA graph caching, paged/unpaged cache handling, and descriptor integration.
RoPE Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/rope/test_triton_rope.py
New tests for interleaved Q/K Triton RoPE comparing against PyTorch reference across batch/sequence/head configurations and dtypes.
Multi-stream Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream_attn.py
Tests for multi-stream MoE synchronization and MLA KV/Q parallelization including fork-point detection, shared input handling, and CUDA graph compatibility.
Custom Model Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_glm4_moe_lite_modeling.py
New test modules validating DeepSeekV3 and GLM4 MoE Lite component behavior (normalization, embeddings, attention, MoE) and HuggingFace interoperability with graph export verification.
SwiGLU & Fusion Tests
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_swiglu.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_nvfp4_swiglu.py
Tests for SwiGLU pattern matching and fusion (with/without bias, multiple layers) and NVFP4-quantized SwiGLU paths with numerical equivalence validation.
Integration & Cleanup
tests/integration/defs/accuracy/test_llm_api_autodeploy.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_create_derived_custom_op.py, tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_update_kv_cache.py
New GLM4Flash accuracy test class; interleaved RoPE optimization test; refactored add-RMS-norm fusion test with multi-user patterns; derived custom op creation tests; DeepSeek config updates for flashinfer_mla compatibility; removed legacy DeepSeek patches and SDPA MLA tests.

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
  • Description check: ⚠️ Warning. The PR description is empty except for the template structure; no explanation of the changes, rationale, test coverage, or implementation details is provided. Resolution: add a description explaining the changes, their purpose, and performance benefits, and list the relevant tests that cover these new features.
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 55.71%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
  • Title check: ✅ Passed. The title clearly identifies the three main performance optimizations being added: multi-stream attention, RMSNorm+add fusion, and SwiGLU fusion.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 16

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

1-6: ⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header.

Per coding guidelines, all TensorRT-LLM .py source files should contain an NVIDIA copyright header with the year of latest meaningful modification. The file begins directly with import math. As per coding guidelines: "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification."

tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Update copyright year to 2026.

The coding guidelines require the copyright header to contain the year of latest meaningful modification. Since this file is being substantially rewritten in 2026, the header should reflect that.

-# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

As per coding guidelines: "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification."

🤖 Fix all issues with AI agents
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 41-62: The type annotation on _update_kv_cache is wrong: the
function mutates k_cache and v_cache in-place and does not return anything, so
change the signature return annotation from "-> torch.Tensor" to "-> None" (or
remove the return annotation) for correctness; update the function signature for
_update_kv_cache and, optionally, adjust the docstring to state it mutates
k_cache/v_cache in-place to make the behavior explicit.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py`:
- Around line 1-21: Add the required NVIDIA copyright header at the top of this
__init__ module: insert the standard multi-line NVIDIA copyright/license header
before the module docstring so the file begins with the header, keeping the
existing docstring and imports unchanged; this affects the module that exports
TorchBackendMLAAttention, FlashInferMLAAttention, torch_mla,
torch_backend_mla_with_cache, and flashinfer_mla_with_cache.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py`:
- Line 452: The custom op decorator for auto_deploy::flashinfer_mla_with_cache
incorrectly declares mutates_args=() even though
flashinfer.page.append_paged_mla_kv_cache mutates the KV caches; update the
`@torch.library.custom_op` on flashinfer_mla_with_cache to list the mutated
arguments (e.g. mutates_args=("ckv_cache","kpe_cache") or the correct parameter
names used in the function signature) so the tracer/compiler knows ckv_cache and
kpe_cache are modified in-place.
- Around line 1-18: Add the required NVIDIA copyright header at the top of the
file (before the module docstring) so the file complies with project guidelines;
update the header to include the standard NVIDIA ownership and license lines
used across the repo, then ensure the existing module-level docstring
(containing FlashInferMLAAttention and flashinfer_mla_with_cache descriptions)
follows the header without modifications.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py`:
- Line 297: The decorator on the custom op
auto_deploy::torch_cached_mla_with_cache incorrectly declares mutates_args=()
even though mla_cache is modified; update the `@torch.library.custom_op`(... )
declaration for "auto_deploy::torch_cached_mla_with_cache" to list the mutated
argument (e.g. mutates_args=("mla_cache",) or the appropriate positional index)
so Torch knows mla_cache is mutated in-place by
_torch_mla_generate_with_absorption and _update_mla_cache.
- Around line 1-18: Add the required NVIDIA copyright header to this source file
by inserting the standard multi-line copyright/license block at the top of
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py (before the
module docstring); ensure the header format matches other files in the repo and
remains above the existing docstring and imports so functions/classes like
torch_cached_mla_with_cache and TorchBackendMLAAttention retain their original
locations.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py`:
- Around line 1-7: The file torch_mla.py is missing the required NVIDIA
copyright header; add the project's standard NVIDIA copyright/header block at
the very top of the file (above the module docstring) following the project's
template, including the correct copyright year and organization name, and ensure
it matches other source files in the repo (use existing headers in files like
other modules under tensorrt_llm/_torch as reference).

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek.py`:
- Line 1: Update the copyright header year from 2025 to 2026 at the top of the
file (the existing comment line starting with "# Copyright (c) 2025, NVIDIA
CORPORATION. All rights reserved."); modify that line to read 2026 so the file's
header reflects the correct year.

In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_glm4_moe_lite.py`:
- Around line 423-445: In forward (modeling_glm4_moe_lite.py) the variable
shared_output is referenced unconditionally but only set inside the if
self.shared_experts is not None branch; ensure shared_output is always defined
by initializing it before the conditional (e.g., shared_output =
torch.zeros_like(identity) or None) or by only adding it to final_hidden_states
when it exists (guard the addition), so update the forward method (references:
forward, self.shared_experts, shared_output, identity, final_hidden_states) to
avoid a NameError when self.shared_experts is None.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_attn.py`:
- Around line 217-218: Remove the debug file write that dumps gm.graph to a
predictable /tmp path: delete the with
open("/tmp/after_multi_stream_mla_attn.txt", "w") block (or wrap it behind a
dedicated debug flag/env check) so the library no longer writes gm.graph to disk
on every invocation; locate the snippet that writes str(gm.graph) to
"/tmp/after_multi_stream_mla_attn.txt" in multi_stream_attn.py and remove or
gate it to avoid insecure/polluting I/O on the hot path.
- Around line 1-12: Add the required NVIDIA copyright header to the top of the
file (above the module docstring) with the appropriate year of latest meaningful
modification and the standard NVIDIA header text used in this repo; update the
header year if this file was modified recently. Locate the file/module
identified by multi_stream_attn.py and ensure the header appears before the
existing triple-quoted docstring so it applies to the entire source file,
matching the repository's copyright header format used across other TensorRT-LLM
sources.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py`:
- Around line 773-778: The context-phase assertion in test_flashinfer_mla_op.py
is using tighter tolerances (atol=0.01, rtol=0.01) on the tensors
flashinfer_output_context_reshaped and torch_output_context_reshaped which can
cause flakes for large batch/sequence sizes; update this assertion to use the
same tolerances as the other context test (atol=0.05, rtol=0.02) so the
comparison in that block matches the rest of the suite.
- Around line 2135-2138: The test function
test_flashinfer_mla_init_decode_wrapper_with_buffers declares and parametrizes
dtype (torch.bfloat16) but never uses it; either remove the dtype param from the
`@pytest.mark.parametrize` and the function signature, or actually apply dtype
when constructing tensors in that test (e.g., pass dtype to
torch.tensor/torch.empty calls). Locate the test by the function name
test_flashinfer_mla_init_decode_wrapper_with_buffers and update the
decorator/signature to drop dtype if unused, or update all tensor creations
inside the function to use the dtype parameter so the parameterization is
meaningful.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py`:
- Around line 33-34: Change the parameters that are currently annotated as
"scale: float = None" and "kv_lora_rank: int = None" to use explicit Optional
type hints (e.g. scale: Optional[float] = None, kv_lora_rank: Optional[int] =
None) and add "from typing import Optional" at the top of the module (or use the
3.10 union syntax if you prefer but the codebase targets 3.8+, so use
typing.Optional). This ensures the function signature (the parameters named
scale and kv_lora_rank) is PEP 484-compliant and avoids implicit Optional
typing.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py`:
- Around line 445-449: The test creates batch_info_host on the compute device
but it should be on CPU; change the allocation of batch_info_host in the test
(variable name batch_info_host in
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py)
to use device="cpu" so that torch_backend_mla.py (which calls
batch_info_host.tolist()) does not trigger an implicit GPU->CPU transfer; keep
the other tensors on the device as-is.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_swiglu.py`:
- Line 5: Replace the wildcard import from
tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu with an explicit import
or a namespace import: either import the module (e.g., import
tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu as swiglu) or
explicitly import the required symbols (e.g., torch_swiglu_mlp) so the test file
no longer uses `from ...swiglu import *`; if the import is only performed for
module-level side effects, keep a module import and add a `# noqa: F401` comment
to indicate an intentional unused-import-side-effect.
🧹 Nitpick comments (34)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

689-693: all_scales_equal performs three GPU comparisons eagerly—consider documenting or simplifying.

torch.all(...) returns a GPU tensor that is implicitly converted to bool in Python `and` expressions, causing three separate GPU→CPU syncs. This is fine at transform time (not on the inference hot path), but worth a brief comment noting this runs once during graph rewriting.
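
If simplification is preferred over a comment, a minimal sketch of one way to collapse the check into a single explicit sync (the helper name and parameters here are placeholders, not the actual signature in fused_moe.py):

import torch


def _all_scales_equal(s0: torch.Tensor, s1: torch.Tensor, s2: torch.Tensor) -> bool:
    # Placeholder sketch: compare three scale tensors with one explicit GPU->CPU sync.
    # Runs once at graph-rewrite time, not on the inference hot path.
    stacked = torch.stack([s0.flatten(), s1.flatten(), s2.flatten()])
    return bool((stacked == stacked[0]).all().item())  # single .item() sync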

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_update_kv_cache.py (1)

31-40: Test has no assertions — only smoke-tests the function.

The call signature update is correct, but the test only prints cache state without verifying correctness. Consider adding assertions to validate that k_cache and v_cache contain the expected values after the update.

Example assertion
     _update_kv_cache(
         k.view(batch_size * seq_length, n_heads, K_D_HEAD),
         v.view(batch_size * seq_length, n_heads, V_D_HEAD),
         k_cache,
         v_cache,
         torch.tensor([3, 1]).long(),
         torch.tensor([0, 0]),
         slot_idx=torch.tensor([0, 1]),
         seq_start=torch.tensor([0, 3]).long(),
     )
 
-    print("k_cache: " + str(k_cache))
-    print("v_cache: " + str(v_cache))
+    # Slot 0 should have tokens at positions 0-2, slot 1 at position 0
+    assert torch.all(k_cache[0, :3] == 1.0), "k_cache slot 0 not updated correctly"
+    assert torch.all(k_cache[0, 3:] == 0.0), "k_cache slot 0 has unexpected values"
+    assert torch.all(k_cache[1, 0] == 1.0), "k_cache slot 1 not updated correctly"
+    assert torch.all(k_cache[1, 1:] == 0.0), "k_cache slot 1 has unexpected values"
+    assert torch.all(v_cache[0, :3] == 1.0), "v_cache slot 0 not updated correctly"
+    assert torch.all(v_cache[1, 0] == 1.0), "v_cache slot 1 not updated correctly"
examples/auto_deploy/cookbooks/glm_4.7_flash_trtllm_cookbook.ipynb (1)

74-95: Committed output cells leak environment-specific details.

The output cell shows 8× H100 GPUs, Python 3.12.3, etc. This is fine for a demo notebook, but consider clearing outputs before committing so users aren't confused when their environment differs (e.g., different GPU count, different Python version). Alternatively, add a note that the output is illustrative.

tensorrt_llm/_torch/auto_deploy/utils/_graph.py (1)

32-33: Module-level registries are not thread-safe, but acceptable for current usage.

_derived_op_libs and _derived_op_registry are plain dicts without locking. This is fine since graph transforms run single-threaded per rank. If this utility is ever called from concurrent threads, consider adding a lock (similar to CudaStreamManager._Singleton).
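
If concurrent callers ever become a possibility, a minimal sketch of a guarded registration helper (the lock and helper are hypothetical additions; only the registry names come from the file):

import threading
from typing import Any, Dict

_derived_op_libs: Dict[str, Any] = {}
_derived_op_registry: Dict[str, Any] = {}
_derived_op_lock = threading.Lock()


def _register_derived_op(name: str, lib: Any, op: Any) -> None:
    # Hypothetical guarded registration; the two dicts mirror the module-level registries.
    with _derived_op_lock:
        _derived_op_libs[name] = lib
        _derived_op_registry[name] = op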

tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py (1)

147-215: Reworked _execute_op_in_aux_stream is cleaner and more extensible.

The lazy aux-op creation (lines 173-175), routing-input detection heuristic (lines 185-189), and event placement after the latest routing input (line 195) are well-reasoned. One minor observation: node_order.get(inp, 0) on line 195 defaults to 0 for nodes not in the order map. All routing_inputs come from n.all_input_nodes which should be in the graph, so this default is safe but could mask bugs. Consider using node_order[inp] if you want to fail fast on unexpected state.

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_glm4_moe_lite.py (2)

1-1: Copyright header format differs from the repo standard.

Other files in this repo use the SPDX format (e.g., SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES). This file uses a different format. As per coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification.


686-699: attention_scaling computation is duplicated between _init_rope and Glm4MoeLiteAttention.__init__.

The mscale computation logic (check for rope_scaling, extract mscale_all_dim, call _yarn_get_mscale) is repeated in Glm4MoeLiteAttention.__init__ (lines 502–513) and Glm4MoeLiteModel._init_rope (lines 688–699). Consider extracting this to a shared helper to keep the two in sync.
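
A rough sketch of what such a helper could look like, assuming the standard YaRN mscale formula; the rope_scaling keys and _yarn_get_mscale come from the description above, and the exact fallback behavior should be copied from the existing code rather than from this sketch:

import math


def _yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Standard YaRN mscale formula, included only to keep this sketch self-contained;
    # the model file already defines its own version.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def _compute_attention_scaling(config) -> float:
    # Hypothetical shared helper for the duplicated attention_scaling logic.
    rope_scaling = getattr(config, "rope_scaling", None) or {}
    mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
    factor = rope_scaling.get("factor", 1.0)
    if mscale_all_dim:
        return _yarn_get_mscale(factor, mscale_all_dim)
    return 1.0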

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/rope/test_triton_rope.py (2)

1-1: Unused Optional import.

Optional is imported but never used in this file. Only Tuple is needed (for _precompute_cos_sin_cache).

Proposed fix
-from typing import Optional, Tuple
+from typing import Tuple

7-8: Unnecessary noqa: F401 directive.

The static analysis tool reports this noqa directive targets a non-enabled rule (F401). The import is needed for its side-effect (custom op registration), but the noqa comment is redundant. Consider removing it or replacing with an inline comment explaining the side-effect purpose.

Proposed fix
 # Import after we've imported torch (to ensure custom ops are registered)
-from tensorrt_llm._torch.auto_deploy.custom_ops import triton_rope  # noqa: F401
+from tensorrt_llm._torch.auto_deploy.custom_ops import triton_rope  # registers custom ops
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_glm4_moe_lite_modeling.py (2)

486-522: Remove debug print statements from test code.

Multiple verbose debug print calls (state dict keys, shapes, tensor slices) appear to be leftover debugging artifacts. They will clutter CI output without providing actionable information during normal test runs. Consider removing them or converting to conditional logging.

Proposed fix: remove debug prints
     hf_state_dict = hf_moe.state_dict()
 
-    # Debug: print state dict keys and shapes
-    print("\n=== HF MoE state_dict keys and shapes ===")
-    for k, v in hf_state_dict.items():
-        print(f"  {k}: {v.shape}")
-
     custom_state_dict = _convert_hf_moe_state_dict_to_custom(hf_state_dict, config.n_routed_experts)
 
-    print("\n=== Converted custom state_dict keys and shapes ===")
-    for k, v in custom_state_dict.items():
-        print(f"  {k}: {v.shape}")
-
-    print("\n=== Expected custom MoE state_dict keys ===")
-    for k, v in custom_moe.state_dict().items():
-        print(f"  {k}: {v.shape}")
-
     custom_moe.load_state_dict(custom_state_dict)
     custom_moe.eval()
 
     ...
 
-    print(f"\n=== Debug: intermediate_size = {intermediate_size} ===")
-    print(f"hf_gate_up shape: {hf_gate_up.shape}")
-    print(f"hf_gate_up[0, :2, :2]: {hf_gate_up[0, :2, :2]}")
-
     # Get the converted state dict values for comparison
     converted_gate_0 = custom_state_dict["experts.0.gate_proj.weight"]
-    print(f"converted_gate_0 shape: {converted_gate_0.shape}")
-    print(f"converted_gate_0[:2, :2]: {converted_gate_0[:2, :2]}")
-
     # After load_state_dict
     loaded_gate_0 = custom_moe.experts[0].gate_proj.weight
-    print(f"loaded_gate_0 shape: {loaded_gate_0.shape}")
-    print(f"loaded_gate_0[:2, :2]: {loaded_gate_0[:2, :2]}")

275-335: Move return to else block in try/except patterns.

Per coding guidelines: "keep the body of the try as small as possible and use the else block for the main logic." The return statements after the import should be in an else block. This applies to all five _get_hf_*_class() helpers.

Example fix for one function (apply similarly to all five)
 def _get_hf_model_class():
     try:
         from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
             Glm4MoeLiteForCausalLM as HFGlm4MoeLiteForCausalLM,
         )
-
-        return HFGlm4MoeLiteForCausalLM
     except ImportError:
         return None
+    else:
+        return HFGlm4MoeLiteForCausalLM
tests/integration/defs/accuracy/test_llm_api_autodeploy.py (1)

429-431: Remove commented-out code.

Line 430 has a commented-out alternative MODEL_NAME. This is dead code that should be removed or tracked in an issue.

Proposed fix
     MODEL_NAME = "zai-org/GLM-4.7-Flash"
-    #MODEL_NAME = "DeepInfra/GLM-4.7-Flash-NVFP4"
     MODEL_PATH = MODEL_NAME  # Model is in HF_CACHE
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

617-617: Rename unused rest variable.

Static analysis correctly flags rest as unused. Prefix it with an underscore.

Proposed fix
-    q_node, k_node, cos_node, sin_node, *rest = node.args
+    q_node, k_node, cos_node, sin_node, *_rest = node.args
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_mla.py (1)

110-115: Causal masking silently skipped when s_q != s_k.

When is_causal=True but s_q != s_k (e.g., decode with s_q=1), the causal mask is silently not applied. This is correct for MLA's decode path (single query attends to all past KV), but a brief inline comment explaining this design choice would help future readers avoid confusion.

Proposed clarification
     # Apply causal mask if specified
-    if is_causal and s_q == s_k:
+    # For prefill (s_q == s_k), apply the standard upper-triangular causal mask.
+    # When s_q != s_k (e.g., decode), no mask is needed since the query token(s)
+    # should attend to all available KV entries.
+    if is_causal and s_q == s_k:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_nvfp4_swiglu.py (1)

13-13: Remove unused noqa directive.

Ruff reports that # noqa: F401 is unnecessary here since F401 is not enabled. The side-effect import is fine on its own.

-import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
+import tensorrt_llm._torch.auto_deploy.custom_ops
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py (3)

19-19: Remove unused noqa directive.

Ruff reports that # noqa: F401 is unnecessary here.

-import tensorrt_llm._torch.auto_deploy  # noqa: F401
+import tensorrt_llm._torch.auto_deploy

124-127: Redundant if/else branches — both produce the same result.

Both the is_generate and else branches compute query_full identically.

-        if is_generate:
-            query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1)
-        else:
-            query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1)
+        query_full = np.concatenate([q_nope_seq, q_pe_seq], axis=-1)

630-675: Verify the numpy reference receives a pre-mutation cache snapshot.

_run_cached_mla(data) on line 652 mutates data["mla_cache"] in-place. The numpy reference on line 661 then receives a copy of the already-updated cache (via .cpu().float().numpy()). Since the numpy reference also writes the same values to the same positions, this happens to be idempotent. However, this is fragile—if the test ever changes to check prefill or multi-token scenarios, the double-update could silently mask bugs. Consider snapshotting the cache before running the op.
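
A small sketch of the snapshotting idea (the helper itself is hypothetical; data["mla_cache"] and the run function mirror the existing test):

def _run_with_cache_snapshot(data: dict, run_op):
    # Hypothetical helper: capture the MLA cache before the op mutates it in-place,
    # so the numpy reference can be fed the pre-update state instead of the result.
    cache_before = data["mla_cache"].detach().clone()
    output = run_op(data)  # e.g. _run_cached_mla, which updates data["mla_cache"]
    return output, cache_before.cpu().float().numpy()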

tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py (1)

340-340: Prefix unused unpacked variable with _.

num_prefill_tokens is unpacked but never used.

-    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
+    num_prefill, _num_prefill_tokens, num_decode = batch_info_host.tolist()
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py (2)

289-289: Remove unused noqa directive.

Ruff reports the # noqa: E402 is unnecessary.

-from ...custom_ops.linear.swiglu import torch_nvfp4_swiglu_mlp  # noqa: E402
+from ...custom_ops.linear.swiglu import torch_nvfp4_swiglu_mlp

519-521: No validation that gate and up share the same input_scale/alpha.

The fusion uses gate_input_scale_node and gate_alpha_node for the fused op (lines 557, 559), discarding the up counterparts. The comment on line 549 states they share input, but there's no check. If a quantizer ever produces mismatched scales (e.g., different calibration), the fusion would silently produce wrong results.

Consider adding a debug assertion:

Suggested assertion
             up_input_scale_node = node.args[7]  # noqa: F841
             up_weight_scale_node = node.args[8]
             up_alpha_node = node.args[9]  # noqa: F841
+
+            # Verify shared input assumption for gate and up projections
+            gate_is = get_attr_by_name(gm, gate_input_scale_node.target)
+            up_is = get_attr_by_name(gm, up_input_scale_node.target)
+            assert torch.equal(gate_is, up_is), (
+                "NVFP4 SwiGLU fusion requires gate and up to share input_scale"
+            )
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py (2)

439-439: Prefix unused unpacked variable with _.

-    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
+    num_prefill, _num_prefill_tokens, num_decode = batch_info_host.tolist()

582-596: Duplicated W_kn / W_v extraction between chunked prefill and decode.

The reshape-and-slice of kv_b_proj_weight into w_kn and w_v is identical in both the chunked prefill path (lines 584-590) and the decode path (lines 710-716). Consider extracting this into a small helper.

Also applies to: 710-722
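
A possible shape for that helper, assuming the usual MLA layout where kv_b_proj_weight is [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]; names mirror the comment above, but the exact slicing should be lifted from the existing code:

from typing import Tuple

import torch


def _split_kv_b_proj(
    kv_b_proj_weight: torch.Tensor,
    num_heads: int,
    qk_nope_head_dim: int,
    v_head_dim: int,
    kv_lora_rank: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper consolidating the duplicated reshape-and-slice of W_kn / W_v.
    w = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
    w_kn = w[:, :qk_nope_head_dim, :]  # absorbed into the query path
    w_v = w[:, qk_nope_head_dim:, :]  # absorbed into the output path
    return w_kn, w_v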

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_deepseek.py (2)

218-221: Unused unpacked variables bsz and seq_len in MoEGate.forward.

These are destructured but never referenced.

-        bsz, seq_len, hidden_dim = hidden_states.shape
+        _bsz, _seq_len, hidden_dim = hidden_states.shape

Or simply:

-        bsz, seq_len, hidden_dim = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]

574-603: **kwargs captured but unused in DeepSeekV3Model.forward.

The **kwargs parameter is accepted but never forwarded or used. If this is for future extensibility or HuggingFace compatibility, consider adding a brief comment. Otherwise, it could mask typos in caller keyword args.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py (3)

19-19: Remove unused noqa directive.

Ruff reports the # noqa: F401 directive is unnecessary (F401 rule is not enabled). The side-effect import is fine; just remove the suppression comment.

-import tensorrt_llm._torch.auto_deploy  # noqa: F401
+import tensorrt_llm._torch.auto_deploy

636-640: Large parametrization may cause slow execution or OOM on smaller GPUs.

The combination batch_size=256, prefill_seq_length=1024, num_heads=8 allocates substantial GPU memory (ckv_cache alone ≈ 500+ MB in bf16, plus inputs and intermediaries). Consider either reducing the upper bounds or marking the largest combinations with @pytest.mark.slow so they can be skipped in fast CI runs.
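
One way to gate the heavy case, sketched with placeholder values for the smaller shapes (only 256/1024 comes from the parametrization above) and assuming a slow marker is registered in the pytest config:

import pytest

_SHAPES = [
    (4, 64),  # placeholder small shape
    (16, 256),  # placeholder medium shape
    pytest.param(256, 1024, marks=pytest.mark.slow),  # heaviest case, opt-in via -m slow
]


@pytest.mark.parametrize("batch_size,prefill_seq_length", _SHAPES)
def test_flashinfer_mla_op_context(batch_size, prefill_seq_length):
    ...  # existing test body unchanged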


1604-1618: Consider extracting repeated patterns into helpers to reduce duplication.

Several patterns appear 3–12 times across this file:

  1. Cache pre-fill loop (lines 1604–1618, 1753–1767, 1878–1892)
  2. Planner reset (lines 1621–1623, 1783–1785, 1908–1910, 2049–2051)
  3. qo_indptr + batch_indices/positions computation (~10 occurrences)
  4. The 18-argument flashinfer_mla_with_cache call (~12 occurrences)

Extracting helpers like _prefill_paged_cache(...), _reset_planner(device), _compute_batch_indices(...), and _run_flashinfer_mla(...) would reduce ~200+ lines of boilerplate and make future maintenance easier.

tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py (3)

66-70: Fixture named setup_method shadows the pytest hook of the same name.

In plain pytest classes, setup_method is a recognized hook that pytest calls automatically before each test method. Decorating it with @pytest.fixture(autouse=True) works but is confusing — readers may not know whether pytest treats it as a fixture or a hook. Consider renaming to something like _setup or setup_fixtures to avoid ambiguity. This same pattern repeats in every test class in this file.


394-420: MLA op callability tests may fail on CPU-only environments.

The test_torch_mla_callable and test_torch_cached_mla_callable tests fall back to "cpu" when CUDA is unavailable, but custom ops registered under torch.ops.auto_deploy (especially MLA/attention ops) are very likely CUDA-only. Consider marking these tests (and potentially other GPU-dependent tests in this file) with @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") to avoid spurious failures.

Also applies to: 422-471


482-494: Test relies on the private attribute _custom_model_mapping.

This is acceptable for a unit test verifying registration internals, but note that it will break silently if the attribute is renamed. If there's a public API to query registered models, prefer using that instead.

tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py (2)

86-93: erased check should also cover cast_node.

The stale-match guard on line 91 checks add_node and norm_node but not cast_node. If two norm nodes were to share the same intermediate cast (unlikely but possible), the second iteration would operate on an already-erased cast node and could fail.

Suggested fix
         for add_node, cast_node, norm_node in matches:
             # Safety: skip if a node in this match was already consumed
-            if id(add_node) in erased or id(norm_node) in erased:
+            if id(add_node) in erased or id(norm_node) in erased or (
+                cast_node is not None and id(cast_node) in erased
+            ):
                 continue

130-140: Use next(iter(cast_node.users)) instead of list(cast_node.users)[0].

This avoids materializing the entire users dict into a list just to get the first element.

Suggested fix
-                    with graph.inserting_before(list(cast_node.users)[0]):
+                    with graph.inserting_before(next(iter(cast_node.users))):
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py (1)

6-8: Remove unused # noqa directive.

Ruff reports the blanket noqa on this import is unnecessary — the import itself is valid and used (referenced by _count_fused_ops and test assertions). The noqa can be safely removed.

-from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import (  # noqa
+from tensorrt_llm._torch.auto_deploy.custom_ops.normalization.flashinfer_fused_add_rms_norm import (
     flashinfer_fused_add_rms_norm,
 )

Comment on lines +41 to +62
def _update_kv_cache(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    seq_len: torch.Tensor,  # metadata
    input_pos: torch.Tensor,  # metadata
    slot_idx: torch.Tensor,
    seq_start: torch.Tensor,
) -> torch.Tensor:
    """
    Reference implementation for update kv cache function. Assumes KV cache layout to be [B,S,N,D].
    This function can be used to build reference attention implementations that use KV cache.
    """

    for idx in range(seq_len.shape[0]):
        k_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = key_states[
            seq_start[idx] : seq_start[idx] + seq_len[idx], ...
        ]
        v_cache[slot_idx[idx], input_pos[idx] : input_pos[idx] + seq_len[idx], :, :] = value_states[
            seq_start[idx] : seq_start[idx] + seq_len[idx], ...
        ]

⚠️ Potential issue | 🟡 Minor

Incorrect return type annotation: function returns None, not torch.Tensor.

_update_kv_cache mutates the caches in-place and has no return statement, yet the signature declares -> torch.Tensor.

Proposed fix
 def _update_kv_cache(
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     seq_len: torch.Tensor,  # metadata
     input_pos: torch.Tensor,  # metadata
     slot_idx: torch.Tensor,
     seq_start: torch.Tensor,
-) -> torch.Tensor:
+) -> None:

Comment on lines +1 to 21
"""MLA (Multi-head Latent Attention) custom ops.

"""Multi-head Latent Attention operations.

This module provides Multi-head Latent Attention (MLA) implementations:
- mla: MLA operations and attention descriptor
Exports:
- TorchBackendMLAAttention: Attention descriptor for MLA (registered as "torch_mla")
- FlashInferMLAAttention: Attention descriptor for FlashInfer MLA (registered as "flashinfer_mla")
- torch_mla: Source op for MLA attention
- torch_backend_mla_with_cache: Cached backend op with FlashInfer-compatible cache
- flashinfer_mla_with_cache: Cached backend op using FlashInfer MLA kernels
"""

from .flashinfer_mla import FlashInferMLAAttention, flashinfer_mla_with_cache
from .torch_backend_mla import TorchBackendMLAAttention, torch_backend_mla_with_cache
from .torch_mla import torch_mla

__all__ = [
    "mla",
    "TorchBackendMLAAttention",
    "FlashInferMLAAttention",
    "torch_mla",
    "torch_backend_mla_with_cache",
    "flashinfer_mla_with_cache",
]

⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header.

This file should contain a copyright header per coding guidelines. As per coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header.

The re-export surface and docstring are clean and well-organized.

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 15-21: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


Comment on lines +1 to +18
"""FlashInfer-based MLA (Multi-head Latent Attention) backend with paged caching.

This module provides:
- FlashInferMLAAttention: attention descriptor using FlashInfer MLA kernels
- flashinfer_mla_with_cache: cached backend op with paged KV cache

FlashInfer MLA uses:
- Regular prefill (input_pos == 0): BatchPrefillWithRaggedKVCacheWrapper with expanded K, V
- Chunked prefill (input_pos > 0): BatchMLAPagedAttentionWrapper with matrix absorption
- Decode: BatchMLAPagedAttentionWrapper with paged compressed KV cache

FlashInfer MLA Cache Layout (two separate caches):
ckv_cache: [num_pages, page_size, kv_lora_rank]
kpe_cache: [num_pages, page_size, qk_rope_head_dim]
- No num_heads dimension (MLA-specific optimization)

Reference: https://docs.flashinfer.ai/api/mla.html
"""

⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header.

As per coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header.


)


@torch.library.custom_op("auto_deploy::flashinfer_mla_with_cache", mutates_args=())

⚠️ Potential issue | 🔴 Critical

mutates_args=() is incorrect — ckv_cache and kpe_cache are mutated in-place.

flashinfer.page.append_paged_mla_kv_cache (line 541) writes to both ckv_cache and kpe_cache. Declaring empty mutates_args will cause the compiler/tracer to treat these caches as unmodified, potentially eliminating or reordering the writes.

-@torch.library.custom_op("auto_deploy::flashinfer_mla_with_cache", mutates_args=())
+@torch.library.custom_op("auto_deploy::flashinfer_mla_with_cache", mutates_args=("ckv_cache", "kpe_cache"))

Comment on lines +1 to +18
"""Custom ops for MultiHead Latent Attention (MLA) with FlashInfer-compatible cache.

This module provides:
- torch_cached_mla_with_cache: cached backend op
- TorchBackendMLAAttention: attention descriptor

FlashInfer MLA Cache Layout:
mla_cache: [max_batch, max_seq, kv_lora_rank + qk_rope_head_dim]
- No num_heads dimension (MLA-specific optimization)
- compressed_kv_cached = mla_cache[:, :, :kv_lora_rank] (zero-copy slice)
- kpe_cached = mla_cache[:, :, kv_lora_rank:] (zero-copy slice)

The implementation uses:
- Prefill: Expand compressed_kv -> full K, V, compute normal attention
- Generate: Weight absorption for efficiency (Q @ W^T instead of expanding cached KV)

Reference: https://docs.flashinfer.ai/tutorials/kv_layout.html#mla-page-layout
"""

⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header.

As per coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header.


Comment on lines +773 to +778
assert torch.allclose(
    flashinfer_output_context_reshaped.cpu().to(torch.float32),
    torch_output_context_reshaped.cpu().to(torch.float32),
    atol=0.01,
    rtol=0.01,
), "Context phase outputs don't match"

⚠️ Potential issue | 🟡 Minor

Inconsistent tolerance: context assertion here uses atol=0.01 while other context tests use atol=0.05.

test_flashinfer_mla_op_context uses atol=0.05, rtol=0.02 for the same context-phase comparison, but this test uses atol=0.01, rtol=0.01. With larger batch sizes (256) and sequence lengths (1024), the tighter tolerance is more likely to cause flaky failures. Consider aligning tolerances across context-phase assertions for consistency.

🤖 Prompt for AI Agents
In
`tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py`,
lines 773 - 778: the context-phase assertion uses tighter tolerances (atol=0.01, rtol=0.01) on flashinfer_output_context_reshaped and torch_output_context_reshaped, which can flake at large batch/sequence sizes. Update the assertion to use the same tolerances as the other context test (atol=0.05, rtol=0.02) so this comparison matches the rest of the suite.
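
If the assertion is being touched anyway, one option (a suggestion, not something the PR does) is to switch to torch.testing.assert_close with the suite-wide tolerances, which also reports the largest mismatch on failure:

# Same comparison, aligned with the other context-phase test's tolerances.
torch.testing.assert_close(
    flashinfer_output_context_reshaped.cpu().to(torch.float32),
    torch_output_context_reshaped.cpu().to(torch.float32),
    atol=0.05,
    rtol=0.02,
)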

Comment on lines +2135 to +2138
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("device", ["cuda"])
def test_flashinfer_mla_init_decode_wrapper_with_buffers(batch_size, dtype, device):

⚠️ Potential issue | 🟡 Minor

Unused dtype parameter.

dtype is parametrized and accepted as a function argument but never referenced in the test body (all tensors in this test use torch.int32). Either use it or remove it from the parametrize/signature.

Proposed fix
-@pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("device", ["cuda"])
-def test_flashinfer_mla_init_decode_wrapper_with_buffers(batch_size, dtype, device):
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("device", ["cuda"])
+def test_flashinfer_mla_init_decode_wrapper_with_buffers(batch_size, device):
📝 Committable suggestion


Suggested change
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("device", ["cuda"])
def test_flashinfer_mla_init_decode_wrapper_with_buffers(batch_size, device):
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 2138-2138: Unused function argument: dtype

(ARG001)

🤖 Prompt for AI Agents
In
`tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_flashinfer_mla_op.py`,
lines 2135 - 2138: the test function test_flashinfer_mla_init_decode_wrapper_with_buffers parametrizes and accepts dtype (torch.bfloat16) but never uses it. Either drop dtype from the @pytest.mark.parametrize decorator and the function signature, or actually apply it when constructing tensors in the test (e.g. pass dtype to the torch.tensor/torch.empty calls) so the parametrization is meaningful.

Comment on lines +33 to +34
scale: float = None,
kv_lora_rank: int = None,

⚠️ Potential issue | 🟡 Minor

Use explicit Optional type hints instead of implicit None defaults.

PEP 484 prohibits implicit Optional. These parameters accept None but are typed as float / int.

-    scale: float = None,
-    kv_lora_rank: int = None,
+    scale: Optional[float] = None,
+    kv_lora_rank: Optional[int] = None,

You'll need to add from typing import Optional at the top (or use float | None for Python 3.10+, but the codebase targets 3.8+).

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 33-33: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


[warning] 34-34: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

🤖 Prompt for AI Agents
In
`tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/mla/test_torch_mla_op.py`,
lines 33 - 34: change the parameters currently annotated as "scale: float = None" and "kv_lora_rank: int = None" to explicit Optional type hints (scale: Optional[float] = None, kv_lora_rank: Optional[int] = None) and add "from typing import Optional" at the top of the module (the 3.10 `float | None` syntax is not an option because the codebase targets 3.8+). This keeps the signature PEP 484-compliant and avoids implicit Optional typing.
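
For illustration, the corrected signature would look roughly like this. The function name and the non-highlighted parameters are placeholders; only scale and kv_lora_rank come from the diff above.

from typing import Optional

import torch

def _mla_ref(  # placeholder name; the other parameters are elided for brevity
    q: torch.Tensor,
    scale: Optional[float] = None,
    kv_lora_rank: Optional[int] = None,
) -> torch.Tensor:
    # Explicit Optional makes the "may be None" contract visible to type checkers.
    if scale is None:
        scale = q.shape[-1] ** -0.5
    return q * scale  # stand-in body; the real reference attention goes here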

Comment on lines +445 to +449
batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device)
seq_len_tensor = torch.tensor([seq_len], dtype=torch.int32, device=device)
input_pos = torch.tensor([0], dtype=torch.int32, device=device)
cache_loc = torch.tensor([0], dtype=torch.int32, device=device)
cu_seqlen = torch.tensor([0], dtype=torch.int32, device=device)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed:

rg -n "torch_cached_mla_with_cache" --type=py -C 5 -g '!*test*'
sed -n '297,350p' tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
rg "torch_cached_mla_with_cache\|torch\.ops\.auto_deploy\.torch_cached_mla_with_cache" --type=py -B 3 -A 3 -g '!*test*'
sed -n '420,500p' tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py
rg "host.*torch\.tensor\|torch\.tensor.*host" --type=py -B 2 -A 2 -g '!*test*' | head -50
sed -n '297,450p' tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
head -20 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py


batch_info_host should be allocated on CPU (device="cpu"), not on the compute device.

The _host suffix follows the PyTorch/CUDA convention for CPU tensors. The implementation (line 339 of torch_backend_mla.py) calls batch_info_host.tolist() directly, which is inefficient and an anti-pattern when the tensor lives on the GPU, since it forces an implicit GPU-to-CPU transfer. By convention, metadata arguments like batch_info_host should reside on the host (CPU).

Fix
-        batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device=device)
+        batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device="cpu")
🤖 Prompt for AI Agents
In
`tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_custom.py`,
lines 445 - 449: the test creates batch_info_host on the compute device but it should be on CPU. Change the allocation of batch_info_host to use device="cpu" so that torch_backend_mla.py (which calls batch_info_host.tolist()) does not trigger an implicit GPU->CPU transfer; keep the other tensors on the device as-is.
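
A short sketch of the host/device split; the buffer names mirror the test, and the meaning of the metadata fields is left to the backend:

import torch

batch_size, seq_len = 4, 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Host-side metadata: keep it on CPU so the backend's batch_info_host.tolist()
# reads local memory instead of forcing a GPU->CPU copy and synchronization.
batch_info_host = torch.tensor([0, 0, batch_size], dtype=torch.int32, device="cpu")

# Device-side index buffers stay on the compute device, as in the test.
seq_len_tensor = torch.tensor([seq_len], dtype=torch.int32, device=device)
cache_loc = torch.tensor([0], dtype=torch.int32, device=device)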

import torch
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu import * # noqa

🛠️ Refactor suggestion | 🟠 Major

Avoid wildcard imports; maintain namespace.

The wildcard import makes it unclear which symbols are used and violates the coding guideline: "Always maintain the namespace when importing Python modules." If the intent is side-effect registration, use a specific import or a noqa: F401-annotated side-effect import.

Suggested fix
-from tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu import *  # noqa
+import tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu  # noqa: F401

If specific symbols (e.g. torch_swiglu_mlp) are needed directly, import them explicitly.

As per coding guidelines, "Always maintain the namespace when importing Python modules, even if only one class or function from a module is used."

📝 Committable suggestion


Suggested change
import tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu  # noqa: F401
🤖 Prompt for AI Agents
In
`tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_swiglu.py`
at line 5: replace the wildcard import from tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu with a namespace or explicit import. Either import the module (e.g. import tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu as swiglu) or import the required symbols explicitly (e.g. torch_swiglu_mlp) so the test file no longer uses `from ...swiglu import *`. If the import exists only for module-level side effects, keep a plain module import and add a `# noqa: F401` comment to mark the intentional unused import.
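
Put together, a hedged sketch of the two acceptable forms; the symbol name torch_swiglu_mlp is taken from the review comment above and may not match the module exactly:

# Option 1: side-effect import only -- triggers custom-op registration
# while keeping the namespace and signalling intent to linters.
import tensorrt_llm._torch.auto_deploy.custom_ops.linear.swiglu  # noqa: F401

# Option 2: if symbols are referenced directly, go through the module namespace.
from tensorrt_llm._torch.auto_deploy.custom_ops.linear import swiglu as swiglu_ops
# e.g. swiglu_ops.torch_swiglu_mlp(...)  -- name per the reviewer's example above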

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
with AutoDeployLLM(model=self.MODEL_PATH,
                   tokenizer=self.MODEL_PATH,
                   **kwargs) as llm:
    task = MMLU(self.MODEL_NAME)

Need to add a baseline at tests/integration/defs/accuracy/references/mmlu.yaml.

                   **kwargs) as llm:
    task = MMLU(self.MODEL_NAME)
    task.evaluate(llm, sampling_params=sampling_params)
    task = GSM8K(self.MODEL_NAME)

Same here: a baseline is needed at tests/integration/defs/accuracy/references/gsm8k.yaml.

@juney-nvidia

@suyoggupta Hi Suyog, out of curiosity, why did this PR become so large? :)

@suyoggupta

@suyoggupta Hi Suyog, out of curiosity, why did this PR become so large? :)

It'll shrink considerably once #11324 is merged and this branch is rebased.

@suyoggupta

#11520

@suyoggupta closed this Feb 14, 2026