310 changes: 151 additions & 159 deletions cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Large diffs are not rendered by default.

46 changes: 30 additions & 16 deletions cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
@@ -20,6 +20,11 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Forward declarations for NCCL device API types
struct ncclComm;
typedef struct ncclComm* ncclComm_t;
struct ncclDevComm;

TRTLLM_NAMESPACE_BEGIN

namespace kernels::moe_comm
@@ -47,10 +52,8 @@ struct DispatchKernelPointers
void* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers
int payload_bytes_per_token[kMaxPayloads]; // Bytes per token for each payload

// Completion flags for synchronization
uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
// rank has signaled the target rank
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
// NCCL device communicator for LSA barrier synchronization
ncclDevComm const* dev_comm;

// Local aux data pointers
int* send_counters; // [ep_size] How many tokens have been sent to each target rank
@@ -74,10 +77,12 @@ struct CombineKernelPointers
void* src_data_ptrs[kMaxPayloads]; // src_data_ptrs[0] is output
void const* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers (const)

// Completion flags for synchronization
uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
// rank has signaled the target rank
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
// NCCL device communicator for LSA barrier synchronization
ncclDevComm const* dev_comm;

// Atomic counter reused for intra-rank barrier notification.
// Block 0 sets it to 1 after NCCL barrier; other blocks spin on it.
int* local_token_counter;

// Top-K compact routing info per local token (size: [local_num_tokens, top_k])
int const* topk_target_ranks; // target rank per k, -1 for duplicates
@@ -107,8 +112,10 @@ struct MoeA2ADispatchParams
int num_payloads; // Number of different payload types
PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors

// NCCL device communicator for LSA barrier synchronization
ncclDevComm const* dev_comm;

// Local aux data
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
int* local_token_counter; // Atomic counter for completed tokens on this rank
int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank
int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
@@ -118,8 +125,6 @@ struct MoeA2ADispatchParams

// Distributed aux data and recv buffers
int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
// rank has signaled the target rank
void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload

// Optional: Statistics for EPLB
@@ -134,7 +139,7 @@

// Dispatch kernels
void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params);
// Prepare for dispatch: zero send_counters, local_token_counter and increment flag_val
// Prepare for dispatch: zero send_counters and local_token_counter
void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params);

// Combine phase parameters
@@ -160,18 +165,19 @@ struct MoeA2ACombineParams
int elements_per_token; // Number of elements per token
nvinfer1::DataType dtype; // Data type for proper summation

// NCCL device communicator for LSA barrier synchronization
ncclDevComm const* dev_comm;

// Local aux data
uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
int* local_token_counter; // Reused as a flag for intra-rank synchronization for combine
int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
// per k, -1 for duplicates
int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
// per k, -1 for duplicates
int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target

// Distributed aux data and recv buffers
uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
// rank has signaled the target rank
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)

// CUDA stream
cudaStream_t stream;
@@ -182,6 +188,14 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params);

void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);

// NCCL DevComm lifecycle helpers (defined in .cu, callable from host code).
// create_moe_nccl_dev_comm creates an NCCL communicator + device communicator with the
// specified number of LSA barriers. The returned pointer is a DEVICE pointer to ncclDevComm
// that can be passed directly to kernel parameters.
// This is a collective call -- all ranks in the group must call it.
ncclDevComm* create_moe_nccl_dev_comm(int ep_size, int ep_rank, int num_lsa_barriers);
void destroy_moe_nccl_dev_comm(ncclDevComm* dev_comm);

// Sanitize expert IDs for invalid tokens
// expert_ids: [ep_size, max_tokens_per_rank, top_k] (int32)
// recv_counters: [ep_size] (int32), number of valid tokens per source
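The new lifecycle helpers are the only host-visible entry points to the device communicator, so it is worth sketching how a caller wires them into the parameter structs above. This is a minimal sketch, not code from the PR; it assumes the namespace tensorrt_llm::kernels::moe_comm and the two-barrier configuration used later in moeAlltoAllOp.cpp:

// Sketch (not from this PR): host-side lifecycle of the NCCL device communicator.
using tensorrt_llm::kernels::moe_comm::create_moe_nccl_dev_comm;
using tensorrt_llm::kernels::moe_comm::destroy_moe_nccl_dev_comm;
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;

void setupMoeA2AComm(int epSize, int epRank, MoeA2ADispatchParams& dispatch, MoeA2ACombineParams& combine)
{
    // Collective call: every rank in the EP group must reach this point.
    // Two LSA barriers are requested: index 0 for dispatch, index 1 for combine.
    ncclDevComm* devComm = create_moe_nccl_dev_comm(epSize, epRank, /*num_lsa_barriers=*/2);

    // The returned pointer is a device pointer; it is stored in the parameter
    // structs and dereferenced only inside the kernels.
    dispatch.dev_comm = devComm;
    combine.dev_comm = devComm;

    // Teardown (once all kernels using devComm have completed):
    // destroy_moe_nccl_dev_comm(devComm);
}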
24 changes: 8 additions & 16 deletions cpp/tensorrt_llm/thop/moeAlltoAllMeta.h
@@ -33,32 +33,24 @@ namespace moe_comm
// Enum for indexing into moe_a2a_metainfo tensor
enum MoeA2AMetaInfoIndex : int64_t
{
FLAG_VAL_OFFSET_INDEX = 0,
LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1,
SEND_COUNTERS_OFFSET_INDEX = 2,
RECV_COUNTERS_OFFSET_INDEX = 3,
// Dispatch completion flags offset
DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4,
// Combine completion flags offset
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5,
TOPK_TARGET_RANKS_OFFSET_INDEX = 6,
TOPK_SEND_INDICES_OFFSET_INDEX = 7,
EPLB_GATHERED_STATS_OFFSET_INDEX = 8,
PAYLOAD_DATA_OFFSET_INDEX = 9,
NUM_METAINFO_FIELDS = 10
LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 0,
SEND_COUNTERS_OFFSET_INDEX = 1,
RECV_COUNTERS_OFFSET_INDEX = 2,
TOPK_TARGET_RANKS_OFFSET_INDEX = 3,
TOPK_SEND_INDICES_OFFSET_INDEX = 4,
EPLB_GATHERED_STATS_OFFSET_INDEX = 5,
PAYLOAD_DATA_OFFSET_INDEX = 6,
NUM_METAINFO_FIELDS = 7
};

using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;

inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs()
{
return {
{"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX},
{"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX},
{"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX},
{"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX},
{"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX},
{"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX},
{"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX},
{"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX},
{"MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX", EPLB_GATHERED_STATS_OFFSET_INDEX},
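With the flag-value and completion-flag slots gone, every remaining index is a plain byte offset into the per-rank workspace. A minimal sketch of how such an offset is resolved, mirroring what moeAlltoAllOp.cpp does with the metainfo produced by moeA2AInitializeOp (variable and function names here are illustrative):

// Sketch: turn a metainfo byte offset into a typed pointer inside this rank's workspace.
// `metainfo` is the int64 data returned by moeA2AInitializeOp;
// `rankWorkspacePtr` is the base address of this rank's workspace slice.
#include <cstdint>

int* resolveSendCounters(uint8_t* rankWorkspacePtr, int64_t const* metainfo)
{
    // SEND_COUNTERS_OFFSET_INDEX is 1 after this change; metainfo stores byte offsets.
    int64_t byteOffset = metainfo[SEND_COUNTERS_OFFSET_INDEX];
    return reinterpret_cast<int*>(rankWorkspacePtr + byteOffset);
}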
51 changes: 27 additions & 24 deletions cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
@@ -25,6 +25,9 @@
#include <torch/types.h>
#include <vector>

// Forward declaration for NCCL device communicator (defined in moeAlltoAllKernels.cu)
struct ncclDevComm;

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
@@ -52,10 +55,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
MoeA2ADataOffsets offsets;
size_t offset = 0;

// flag_val
offsets[FLAG_VAL_OFFSET_INDEX] = offset;
offset += SIZEOF_INT32;

// local_token_counter
offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX] = offset;
offset += SIZEOF_INT32;
@@ -68,16 +67,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
offsets[RECV_COUNTERS_OFFSET_INDEX] = offset;
offset += epSize * SIZEOF_INT32;

// dispatch completion flags
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX] = offset;
offset += epSize * SIZEOF_INT32;

// combine completion flags
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX] = offset;
offset += epSize * SIZEOF_INT32;

// topk_target_ranks: [maxNumTokens, kMaxTopK]
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset;
@@ -103,7 +92,7 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
}

// Initialize auxiliary data in workspace
// This function sets up the initial values for flag_val and completion_flags
// This function sets up the initial workspace values and creates the NCCL device communicator
//
// Inputs:
// - workspace: [ep_size, size_per_rank] unified virtual memory workspace
@@ -114,6 +103,10 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
//
// Returns:
// - metainfo: Tensor containing offsets for auxiliary data

// Global storage for NCCL device communicator (persistent across dispatch/combine calls).
static ncclDevComm* g_moeDevComm = nullptr;

torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens,
torch::optional<int64_t> eplbStatsNumExperts)
{
@@ -142,10 +135,21 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
metainfo[i] = static_cast<int64_t>(offsets[i]);
}

// Synchronize among ranks
// Synchronize among ranks before NCCL initialization
cudaDeviceSynchronize();
tensorrt_llm::mpi::MpiComm::world().barrier();

// Create NCCL device communicator with 2 LSA barriers:
// - Index 0: dispatch barrier (warp 0 of last-token CTA)
// - Index 1: combine barrier (warp 0 of an elected CTA, other CTAs wait via local_token_counter)
// This is a collective call -- all EP ranks must participate.
if (g_moeDevComm == nullptr)
{
using tensorrt_llm::kernels::moe_comm::create_moe_nccl_dev_comm;
g_moeDevComm
= create_moe_nccl_dev_comm(static_cast<int>(epSize), static_cast<int>(epRank), /*num_lsa_barriers=*/2);
}
Comment on lines +138 to +151
⚠️ Potential issue | 🟡 Minor

NCCL device communicator initialization is well-guarded and documented.

The initialization sequence is sound:

  • cudaDeviceSynchronize() + MPI barrier ensures all ranks are ready before the collective NCCL init
  • The if (g_moeDevComm == nullptr) guard prevents double creation
  • The 2 LSA barrier allocation (index 0 for dispatch, index 1 for combine) is well documented

One edge case: if moeA2AInitializeOp is called with different epSize/epRank values on a subsequent call (while g_moeDevComm != nullptr), the existing communicator would be silently reused despite being configured for different parameters. The Python side (NVLinkOneSided.__init__) has workspace reuse assertions, but the C++ NCCL comm doesn't have equivalent validation.

🤖 Prompt for AI Agents
In `@cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp` around lines 138-151, the global
communicator g_moeDevComm may be reused with mismatched epSize/epRank; modify
moeA2AInitializeOp to validate that an existing g_moeDevComm was created with
the same epSize/epRank before reusing it (call-site: the if (g_moeDevComm ==
nullptr) block and the early-reuse path). If parameters differ, either log/error
and abort or recreate the communicator by destroying the old g_moeDevComm and
calling create_moe_nccl_dev_comm(static_cast<int>(epSize),
static_cast<int>(epRank), /*num_lsa_barriers=*/2); ensure the check compares the
stored communicator's configuration (add fields or accessors to g_moeDevComm
metadata if needed) and include a clear diagnostic message mentioning
epSize/epRank and the expected vs actual values.
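A possible shape of that guard is sketched below; the cached-configuration statics, the helper name, and the error message are illustrative only and not part of this PR:

// Sketch: cache how the communicator was built and validate on reuse.
static ncclDevComm* getOrCreateMoeDevComm(int epSize, int epRank)
{
    static ncclDevComm* devComm = nullptr;
    static int cachedEpSize = -1;
    static int cachedEpRank = -1;

    if (devComm == nullptr)
    {
        using tensorrt_llm::kernels::moe_comm::create_moe_nccl_dev_comm;
        devComm = create_moe_nccl_dev_comm(epSize, epRank, /*num_lsa_barriers=*/2);
        cachedEpSize = epSize;
        cachedEpRank = epRank;
    }
    else
    {
        TORCH_CHECK(cachedEpSize == epSize && cachedEpRank == epRank,
            "NCCL device communicator was created for epSize=", cachedEpSize, ", epRank=", cachedEpRank,
            " but moeA2AInitializeOp was called again with epSize=", epSize, ", epRank=", epRank);
    }
    return devComm;
}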


return metainfo;
}

@@ -186,6 +190,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
using tensorrt_llm::kernels::moe_comm::moe_a2a_prepare_dispatch_launch;
using tensorrt_llm::kernels::moe_comm::kMaxTopK;
using tensorrt_llm::kernels::moe_comm::kMaxPayloads;

@@ -320,7 +325,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
params.num_payloads = num_payloads;
std::copy(payloadDescriptors.begin(), payloadDescriptors.end(), &params.payloads[0]);

params.flag_val = reinterpret_cast<uint32_t*>(rankWorkSpacePtr + offsets[FLAG_VAL_OFFSET_INDEX]);
params.dev_comm = g_moeDevComm;
params.local_token_counter = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX]);
params.send_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[SEND_COUNTERS_OFFSET_INDEX]);
params.topk_target_ranks = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX]);
@@ -332,8 +337,6 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(

params.recv_counters[target_rank]
= reinterpret_cast<int*>(targetWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);
params.completion_flags[target_rank]
= reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
if (enableEplb)
{
params.eplb_gathered_stats[target_rank]
@@ -362,7 +365,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(

params.stream = at::cuda::getCurrentCUDAStream();

// Prepare for dispatch (zero counters/indices and increment flag_val)
// Prepare for dispatch (zero counters/indices)
moe_a2a_prepare_dispatch_launch(params);

// Launch the dispatch kernel
@@ -417,6 +420,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
{
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
using tensorrt_llm::kernels::moe_comm::moe_a2a_prepare_combine_launch;
using tensorrt_llm::kernels::moe_comm::kMaxTopK;

// Validate inputs
@@ -505,24 +509,23 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
params.elements_per_token = static_cast<int>(elementsPerToken);
params.dtype = nvDtype;

params.flag_val = reinterpret_cast<uint32_t*>(rankWorkSpacePtr + offsets[FLAG_VAL_OFFSET_INDEX]);
params.dev_comm = g_moeDevComm;
params.local_token_counter = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX]);
params.topk_target_ranks = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX]);
params.topk_send_indices = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[TOPK_SEND_INDICES_OFFSET_INDEX]);
params.recv_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);

for (int target_rank = 0; target_rank < epSize; target_rank++)
{
uint8_t* target_workspace_ptr = workspacePtr + target_rank * workspace.stride(0);
params.completion_flags[target_rank]
= reinterpret_cast<uint32_t*>(target_workspace_ptr + offsets[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX]);
params.recv_buffers[target_rank] = target_workspace_ptr + combinePayloadOffset;
}

params.stream = at::cuda::getCurrentCUDAStream();

moe_a2a_prepare_combine_launch(params);

// Launch the combine kernel
// Launch the combine kernel (includes per-CTA NCCL LSA barrier)
moe_a2a_combine_launch(params);
cudaError_t result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess, "moe_a2a_combine kernel launch failed: ", cudaGetErrorString(result));
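Taken together, the host-side call order after this change is roughly the sketch below (payload setup and workspace wiring omitted); the point is that cross-rank ordering now comes from the NCCL LSA barriers inside the kernels rather than the removed flag_val/completion_flags handshake. All names besides the launch functions and parameter structs are illustrative:

using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
using tensorrt_llm::kernels::moe_comm::moe_a2a_prepare_dispatch_launch;
using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
using tensorrt_llm::kernels::moe_comm::moe_a2a_prepare_combine_launch;
using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;

// Sketch of one MoE all-to-all round on a single EP rank.
void runMoeA2ARound(MoeA2ADispatchParams& dispatch, MoeA2ACombineParams& combine,
                    ncclDevComm const* devComm, cudaStream_t stream)
{
    dispatch.dev_comm = devComm;
    dispatch.stream = stream;
    moe_a2a_prepare_dispatch_launch(dispatch); // zeroes send_counters / local_token_counter
    moe_a2a_dispatch_launch(dispatch);         // dispatch kernel; synchronizes via LSA barrier index 0

    // ... expert computation on the dispatched tokens runs here ...

    combine.dev_comm = devComm;
    combine.stream = stream;
    moe_a2a_prepare_combine_launch(combine);
    moe_a2a_combine_launch(combine);           // combine kernel; per-CTA sync via LSA barrier index 1
}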
6 changes: 0 additions & 6 deletions tensorrt_llm/_torch/distributed/moe_alltoall.py
@@ -94,18 +94,12 @@ def _init_constants(cls):
if cls._METAINFO_INDEX is None:
thop = _tllm_internal.thop
cls._METAINFO_INDEX = {
"FLAG_VAL_OFFSET_INDEX":
int(thop.MOE_A2A_FLAG_VAL_OFFSET_INDEX),
"LOCAL_TOKEN_COUNTER_OFFSET_INDEX":
int(thop.MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX),
"SEND_COUNTERS_OFFSET_INDEX":
int(thop.MOE_A2A_SEND_COUNTERS_OFFSET_INDEX),
"RECV_COUNTERS_OFFSET_INDEX":
int(thop.MOE_A2A_RECV_COUNTERS_OFFSET_INDEX),
"DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX":
int(thop.MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX),
"COMBINE_COMPLETION_FLAGS_OFFSET_INDEX":
int(thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX),
"TOPK_TARGET_RANKS_OFFSET_INDEX":
int(thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX),
"TOPK_SEND_INDICES_OFFSET_INDEX":
@@ -59,12 +59,9 @@ class NVLinkOneSided(Communication):
_WORKSPACE: dict | None = None

# MetaInfo indices - initialized from C++ constants
FLAG_VAL_OFFSET_INDEX = None
LOCAL_TOKEN_COUNTER_OFFSET_INDEX = None
SEND_COUNTERS_OFFSET_INDEX = None
RECV_COUNTERS_OFFSET_INDEX = None
DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = None
COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None
EPLB_GATHERED_STATS_OFFSET_INDEX = None
PAYLOAD_DATA_OFFSET_INDEX = None

@@ -119,20 +116,13 @@ def calculate_required_workspace_size(
@classmethod
def _init_constants(cls):
"""Initialize constants from C++ if not already done."""
if cls.FLAG_VAL_OFFSET_INDEX is None:
if cls.LOCAL_TOKEN_COUNTER_OFFSET_INDEX is None:
thop = _tllm_internal.thop
cls.FLAG_VAL_OFFSET_INDEX = int(thop.MOE_A2A_FLAG_VAL_OFFSET_INDEX)
cls.LOCAL_TOKEN_COUNTER_OFFSET_INDEX = int(
thop.MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX
)
cls.SEND_COUNTERS_OFFSET_INDEX = int(thop.MOE_A2A_SEND_COUNTERS_OFFSET_INDEX)
cls.RECV_COUNTERS_OFFSET_INDEX = int(thop.MOE_A2A_RECV_COUNTERS_OFFSET_INDEX)
cls.DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = int(
thop.MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX
)
cls.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = int(
thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX
)
cls.EPLB_GATHERED_STATS_OFFSET_INDEX = int(
thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX
)