diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
index 99a5e8c413d..4e109888907 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
@@ -20,8 +20,12 @@
 #include "tensorrt_llm/common/vec_dtypes.cuh"
 #include "tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h"
 #include "tensorrt_llm/kernels/quantization.cuh"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"
 #include
 #include
+#include <nccl.h>
+#include <nccl_device.h>
 #include

 TRTLLM_NAMESPACE_BEGIN
@@ -34,10 +38,6 @@ using tensorrt_llm::common::launchWithPdlWhenEnabled;
 #define ENABLE_DEBUG_PRINT 0
 #define DISABLE_SYNC_FOR_PROFILING 0

-#ifndef DISABLE_TIMEOUT
-#define DISABLE_TIMEOUT 0
-#endif
-
 // Macros for concise launch-time specialization
 #define SWITCH_BOOL(flag, NAME, ...)                                                                                   \
     if (flag)                                                                                                          \
@@ -147,13 +147,6 @@ using tensorrt_llm::common::launchWithPdlWhenEnabled;
         __VA_ARGS__                                                                                                    \
     }

-#if DISABLE_TIMEOUT
-#define check_timeout(s) false
-#else
-// 300 * 2000 MHz - should be high enough on any GPU but will prevent a hang
-#define check_timeout(s) ((clock64() - (s)) > (300ll * 2000ll * 1000ll * 1000ll))
-#endif
-
 // ============================================================================
 // Helper Functions for Expert-to-Rank Mapping
 // ============================================================================
@@ -344,8 +337,7 @@ __device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token,
     }
 }

-__global__ void moeA2APrepareDispatchKernel(
-    int* send_counters, int* local_token_counter, int ep_size, uint32_t* flag_val_ptr)
+__global__ void moeA2APrepareDispatchKernel(int* send_counters, int* local_token_counter, int ep_size)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     cudaGridDependencySynchronize();
     cudaTriggerProgrammaticLaunchCompletion();
 #endif
@@ -357,12 +349,10 @@ __global__ void moeA2APrepareDispatchKernel(
     {
         send_counters[idx] = 0;
     }
-    // Zero local_token_counter and increment flag_val
+    // Zero local_token_counter
     if (idx == 0)
     {
         *local_token_counter = 0;
-        // Increment flag_val for this dispatch round
-        *flag_val_ptr = *flag_val_ptr + 1;
     }
 }

@@ -394,6 +384,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
     else
     {
         // Threads that do not have a token to process should return.
+        // TODO: Not needed for one-block-per-token policy. If one-warp-per-token is deprecated, remove this check.
         if (local_token_idx >= local_num_tokens)
             return;

@@ -504,6 +495,12 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [

     if (is_last_token)
     {
+        // Reset local_token_counter to 0 so that the combine kernel can reuse it.
+        if (lane_id == 0)
+        {
+            *ptrs.local_token_counter = 0;
+        }
+
         // Store send_counters to recv_counters
 #pragma unroll 1 // No unroll as one iter is typically enough
         for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize)
@@ -528,54 +525,20 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
         }

 #if !DISABLE_SYNC_FOR_PROFILING
-        uint32_t expected_value = *ptrs.flag_val;
-
+// Issue a release fence to ensure that the data is visible to other ranks after the synchronization.
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
         // .acquire and .release qualifiers for fence instruction require sm_90 or higher.
asm volatile("fence.release.sys;"); #else asm volatile("fence.acq_rel.sys;"); #endif -#pragma unroll 1 // No unroll as one iter is typically enough - for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) - { - uint32_t* flag_addr = &ptrs.completion_flags[target_rank][rank_id]; - asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); - -#if ENABLE_DEBUG_PRINT - printf("dispatch: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, expected_value, - target_rank); -#endif - } -#pragma unroll 1 // No unroll - for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) - { - bool flag_set = false; - auto s = clock64(); - do - { - uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; - uint32_t flag_value; - // Acquire load to ensure visibility of peer's release-store - asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); -#if ENABLE_DEBUG_PRINT - printf( - "combine: ---Rank %d received completion flag from rank %d, flag_value: %d, expected_value: " - "%d, address: %p\n", - rank_id, peer_rank, flag_value, expected_value, flag_ptr); -#endif - flag_set = flag_value == expected_value; - } while (!flag_set && !check_timeout(s)); - - if (__builtin_expect(!flag_set, 0)) - { - printf("dispatch: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id, - peer_rank); - asm volatile("trap;"); - return; - } - } + // Use NCCL LSA barrier for inter-rank synchronization. + // Barrier index 0 is reserved for dispatch. + // Only the last-token warp (one warp across all CTAs) participates as a coop. + ncclCoopWarp coop = ncclCoopWarp(); + ncclLsaBarrierSession barrier(coop, *ptrs.dev_comm, ncclTeamTagLsa{}, /*index=*/0, /*multimem=*/true); + barrier.sync(coop, cuda::memory_order_relaxed); #endif } } @@ -584,7 +547,7 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params) { launchWithPdlWhenEnabled("moeA2APrepareDispatchKernel", moeA2APrepareDispatchKernel, 1, params.ep_size, 0, - params.stream, params.send_counters, params.local_token_counter, params.ep_size, params.flag_val); + params.stream, params.send_counters, params.local_token_counter, params.ep_size); } // ============================================================================ @@ -621,12 +584,8 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) } } - // Copy completion flag pointers - for (int i = 0; i < params.ep_size; i++) - { - kernel_ptrs.completion_flags[i] = params.completion_flags[i]; - } - kernel_ptrs.flag_val = params.flag_val; + // NCCL device communicator for LSA barrier + kernel_ptrs.dev_comm = params.dev_comm; // Copy communication tracking pointers kernel_ptrs.send_counters = params.send_counters; @@ -1003,19 +962,13 @@ __device__ void vectorized_combine( // Copy payload to recv buffer using vectorized copy; supports warp/block token mapping template __global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t const* payload_bytes, - int bytes_per_token, int ep_size, int max_tokens_per_rank, uint32_t* flag_val_ptr, int const* recv_counters) + int bytes_per_token, int ep_size, int max_tokens_per_rank, int const* recv_counters) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); #endif - if (blockIdx.x == 0 && threadIdx.x == 0) - { - // Increment flag_val for this combine round - 
-        *flag_val_ptr = *flag_val_ptr + 1;
-    }
-
     if (payload_bytes == nullptr)
     {
         return;
     }
@@ -1053,92 +1006,73 @@ __global__ void moeA2ACombineKernel(
     const CombineKernelPointers ptrs, // Combine-specific struct, src_data_ptrs[0] is output
     int max_tokens_per_rank, int elements_per_token, int local_num_tokens, int rank_id, int ep_size)
 {
-    int local_token_idx = ThreadingPolicy::token_idx();
-    int const size_per_token = elements_per_token * sizeof(T);
-
-    if (local_num_tokens == 0)
-    {
-        // Special case: If local_num_tokens == 0,
-        // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
-        // Other threads should return.
-        if (local_token_idx > 0)
-            return;
-    }
-    else
-    {
-        // Threads that do not have a token to process should return.
-        if (local_token_idx >= local_num_tokens)
-            return;
-    }
-
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    cudaGridDependencySynchronize();
-    cudaTriggerProgrammaticLaunchCompletion();
-#endif
-
 #if !DISABLE_SYNC_FOR_PROFILING
-    // In-kernel readiness synchronization at start of combine:
-    // - One warp signals readiness to all peers with current flag_val.
-    // - The first warp of each block waits for all peers' readiness (equality), then __syncthreads.
-    bool is_first_warp = threadIdx.x / warpSize == 0;
-    if (is_first_warp)
-    {
-        int lane_id = threadIdx.x % warpSize;
-        uint32_t expected_value = *ptrs.flag_val;
-        if (blockIdx.x == 0)
+    // The first warp in each CTA performs the synchronization.
+    {
+        bool is_first_warp = threadIdx.x / warpSize == 0;
+
+        // local_token_counter is reused as a 3-state flag:
+        //    0 → no CTA elected yet
+        //   -1 → one CTA elected, barrier in progress
+        //    1 → barrier done, peer data is visible
+        // The first CTA to arrive wins the election (CAS 0→-1) and performs the
+        // NCCL inter-rank barrier (index 1 for combine). All other CTAs spin until the flag becomes 1.
+        if (is_first_warp)
         {
-#pragma unroll 1 // No unroll
-            for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
+            int lane_id = threadIdx.x % warpSize;
+            bool elected = false;
+            if (lane_id == 0)
             {
-                uint32_t* flag_addr = &ptrs.completion_flags[peer_rank][rank_id];
-                asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value));
-#if ENABLE_DEBUG_PRINT
-                printf("combine: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, expected_value,
-                    peer_rank);
-#endif
+                int old = atomicCAS(ptrs.local_token_counter, 0, -1);
+                elected = (old == 0);
             }
-        }
+            elected = __shfl_sync(0xffffffff, elected, 0);

-#pragma unroll 1 // No unroll
-        for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize)
-        {
-            bool flag_set = false;
-            auto s = clock64();
-            do
+            if (elected)
             {
-                uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank];
-                uint32_t flag_value;
-                // Acquire load to ensure visibility of peer's release-store
-                asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr));
-#if ENABLE_DEBUG_PRINT
-                printf(
-                    "combine: ---Rank %d received completion flag from rank %d, flag_value: %d, expected_value: "
-                    "%d, "
-                    "address: %p\n",
-                    rank_id, peer_rank, flag_value, expected_value, flag_ptr);
-#endif
-                flag_set = flag_value == expected_value;
-            } while (!flag_set && !check_timeout(s));
-
-            if (__builtin_expect(!flag_set, 0))
+                ncclCoopWarp coop;
+                ncclLsaBarrierSession barrier(coop, *ptrs.dev_comm, ncclTeamTagLsa{}, /*index=*/1);
+                barrier.sync(coop, cuda::memory_order_relaxed);
+                // Signal all other blocks that the inter-rank barrier is done.
+                if (lane_id == 0)
+                {
+                    asm volatile("st.relaxed.gpu.s32 [%0], %1;" ::"l"(ptrs.local_token_counter), "r"(1));
+                }
+            }
+            else
             {
-                printf("combine: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id, peer_rank);
-                asm volatile("trap;");
-                return;
+                // Wait for the elected CTA to complete the inter-rank barrier.
+                if (lane_id == 0)
+                {
+                    int local_token_counter_val;
+                    do
+                    {
+                        asm volatile("ld.relaxed.gpu.s32 %0, [%1];"
+                                     : "=r"(local_token_counter_val)
+                                     : "l"(ptrs.local_token_counter));
+                    } while (local_token_counter_val != 1);
+                }
             }
-        }
+
+// Acquire system-scope visibility so this block sees peer data.
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
-        // .acquire and .release qualifiers for fence instruction require sm_90 or higher.
-        asm volatile("fence.acquire.sys;");
+            asm volatile("fence.acquire.sys;");
 #else
-        asm volatile("fence.acq_rel.sys;");
+            asm volatile("fence.acq_rel.sys;");
 #endif
+        }
+        // Synchronize the rest of the warps in the CTA with the first warp.
+        __syncthreads();
     }
-    __syncthreads();
 #endif

-    if (local_num_tokens == 0)
+    int local_token_idx = ThreadingPolicy::token_idx();
+    int const size_per_token = elements_per_token * sizeof(T);
+
+    // Threads that do not have a token to process should return.
+    // TODO: Not needed for one-block-per-token policy. If one-warp-per-token is deprecated, remove this check.
+    if (local_token_idx >= local_num_tokens)
         return;

     // Get output location for this token (using src_data_ptrs[0] as output)
@@ -1178,7 +1112,7 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params)
     auto kernel_fn = params.one_block_per_token ? moeA2APrepareCombineKernel : moeA2APrepareCombineKernel;
     launchWithPdlWhenEnabled("moeA2APrepareCombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream,
-        recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size, params.max_tokens_per_rank, params.flag_val,
+        recv_buffer_bytes, payload_bytes, bytes_per_token, params.ep_size, params.max_tokens_per_rank,
         params.recv_counters);
 }

@@ -1197,16 +1131,20 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
     // Configure kernel launch
     int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
     int const kWarpsPerBlock = kBlockSize / 32; // warpSize
-    int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
-    int grid_size_block = params.local_num_tokens;
-    // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
-    if (grid_size_warp == 0)
+
+    int grid_size;
+    if (params.one_block_per_token)
     {
-        grid_size_warp = 1;
+        grid_size = params.local_num_tokens;
     }
-    if (grid_size_block == 0)
+    else
     {
-        grid_size_block = 1;
+        grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
+    }
+    // Even if local_num_tokens is 0, we need at least 1 block for the barrier.
+    if (grid_size == 0)
+    {
+        grid_size = 1;
     }

     // Prepare kernel pointers struct for combine
@@ -1221,25 +1159,20 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
         kernel_ptrs.recv_buffers[rank][0] = params.recv_buffers[rank];
     }

-    // Copy completion flag pointers
-    for (int i = 0; i < params.ep_size; i++)
-    {
-        kernel_ptrs.completion_flags[i] = params.completion_flags[i];
-    }
-    kernel_ptrs.flag_val = params.flag_val;
+    // NCCL device communicator and intra-rank barrier flag
+    kernel_ptrs.dev_comm = params.dev_comm;
+    kernel_ptrs.local_token_counter = params.local_token_counter;

     // Copy communication tracking pointers
     kernel_ptrs.topk_target_ranks = params.topk_target_ranks;
     kernel_ptrs.topk_send_indices = params.topk_send_indices;

-    int grid = params.one_block_per_token ? grid_size_block : grid_size_warp;
-
     // Launch appropriate kernel with compact macros
     SWITCH_DTYPE(params.dtype, TKernelType, {
         SWITCH_POLICY(params.one_block_per_token, Policy, {
             SWITCH_TOP_K(params.top_k, TOP_K, {
                 auto kernel_fn = moeA2ACombineKernel<TKernelType, Policy, TOP_K>;
-                launchWithPdlWhenEnabled("moeA2ACombineKernel", kernel_fn, grid, kBlockSize, 0, params.stream,
+                launchWithPdlWhenEnabled("moeA2ACombineKernel", kernel_fn, grid_size, kBlockSize, 0, params.stream,
                     kernel_ptrs, params.max_tokens_per_rank, params.elements_per_token, params.local_num_tokens,
                     params.ep_rank, params.ep_size);
             });
@@ -1283,6 +1216,65 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
         stream, expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
 }

+// ============================================================================
+// NCCL DevComm Lifecycle Helpers
+// ============================================================================
+
+// Static storage for the NCCL communicator used by the DevComm.
+// The NCCL comm is kept alive because ncclDevCommDestroy needs it.
+static ncclComm_t g_moeNcclComm = nullptr;
+
+ncclDevComm* create_moe_nccl_dev_comm(int ep_size, int ep_rank, int num_lsa_barriers)
+{
+    // 1. Create NCCL communicator via MPI bootstrap (collective call)
+    ncclUniqueId ncclId;
+    if (ep_rank == 0)
+    {
+        TLLM_NCCL_CHECK(ncclGetUniqueId(&ncclId));
+    }
+    tensorrt_llm::mpi::MpiComm::world().bcastValue(ncclId, 0);
+
+    ncclComm_t ncclComm;
+    TLLM_NCCL_CHECK(ncclCommInitRank(&ncclComm, ep_size, ncclId, ep_rank));
+
+    // 2. Create device communicator with LSA barrier support (collective call)
+    ncclDevCommRequirements reqs = {};
+    reqs.lsaBarrierCount = num_lsa_barriers;
+    reqs.lsaMultimem = true;
+
+    ncclDevComm hostDevComm;
+    TLLM_NCCL_CHECK(ncclDevCommCreate(ncclComm, &reqs, &hostDevComm));
+
+    // 3. Copy to device memory so kernels can access it via pointer
+    ncclDevComm* deviceDevComm = nullptr;
+    TLLM_CUDA_CHECK(cudaMalloc(&deviceDevComm, sizeof(ncclDevComm)));
+    TLLM_CUDA_CHECK(cudaMemcpy(deviceDevComm, &hostDevComm, sizeof(ncclDevComm), cudaMemcpyHostToDevice));
+
+    // Store the NCCL comm so we can clean up later
+    g_moeNcclComm = ncclComm;
+
+    return deviceDevComm;
+}
+
+void destroy_moe_nccl_dev_comm(ncclDevComm* dev_comm)
+{
+    if (dev_comm != nullptr)
+    {
+        // Copy back to host for ncclDevCommDestroy
+        ncclDevComm hostDevComm;
+        cudaMemcpy(&hostDevComm, dev_comm, sizeof(ncclDevComm), cudaMemcpyDeviceToHost);
+
+        if (g_moeNcclComm != nullptr)
+        {
+            ncclDevCommDestroy(g_moeNcclComm, &hostDevComm);
+            ncclCommDestroy(g_moeNcclComm);
+            g_moeNcclComm = nullptr;
+        }
+
+        cudaFree(dev_comm);
+    }
+}
+
 } // namespace kernels::moe_comm

 TRTLLM_NAMESPACE_END
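
The synchronization pattern the dispatch kernel now relies on is small enough to show in isolation. The following is a minimal sketch, not part of the patch: it assumes the same NCCL device API (nccl_device.h) and sm_90 PTX fences used above, the kernel name and payload handling are placeholders, and barrier index 0 must match one of the LSA barriers reserved at ncclDevCommCreate time.

#include <cuda/atomic>
#include <nccl_device.h>

// One warp per rank joins the barrier. The release fence makes all earlier
// stores to symmetric memory visible to peers; peers observe them once their
// own barrier sync returns.
__global__ void lsaBarrierSketch(ncclDevComm const* dev_comm)
{
    // ... write this rank's payload into the peers' recv buffers ...

    asm volatile("fence.release.sys;"); // sm_90+; pairs with fence.acquire.sys on the reader side

    if (blockIdx.x == 0 && threadIdx.x < warpSize)
    {
        ncclCoopWarp coop = ncclCoopWarp();
        ncclLsaBarrierSession barrier(coop, *dev_comm, ncclTeamTagLsa{}, /*index=*/0, /*multimem=*/true);
        barrier.sync(coop, cuda::memory_order_relaxed);
    }
}

Compared with the old completion-flag scheme, the barrier needs no per-round flag value and no timeout heuristic: NCCL owns the flag memory and its reuse across rounds.
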
diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
index 942b2424bbc..ad9f6dffc27 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
@@ -20,6 +20,11 @@
 #include
 #include

+// Forward declarations for NCCL device API types
+struct ncclComm;
+typedef struct ncclComm* ncclComm_t;
+struct ncclDevComm;
+
 TRTLLM_NAMESPACE_BEGIN

 namespace kernels::moe_comm
@@ -47,10 +52,8 @@ struct DispatchKernelPointers
     void* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers
     int payload_bytes_per_token[kMaxPayloads];   // Bytes per token for each payload

-    // Completion flags for synchronization
-    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
-                                           // rank has signaled the target rank
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    // NCCL device communicator for LSA barrier synchronization
+    ncclDevComm const* dev_comm;

     // Local aux data pointers
     int* send_counters; // [ep_size] How many tokens have been sent to each target rank
@@ -74,10 +77,12 @@ struct CombineKernelPointers
     void* src_data_ptrs[kMaxPayloads];                 // src_data_ptrs[0] is output
     void const* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers (const)

-    // Completion flags for synchronization
-    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
-                                           // rank has signaled the target rank
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    // NCCL device communicator for LSA barrier synchronization
+    ncclDevComm const* dev_comm;
+
+    // Atomic counter reused for intra-rank barrier notification.
+    // The elected CTA sets it to 1 after the NCCL barrier; the other CTAs spin on it.
+    int* local_token_counter;

     // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
     int const* topk_target_ranks; // target rank per k, -1 for duplicates
@@ -107,8 +112,10 @@ struct MoeA2ADispatchParams
     int num_payloads;                         // Number of different payload types
     PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors

+    // NCCL device communicator for LSA barrier synchronization
+    ncclDevComm const* dev_comm;
+
     // Local aux data
-    uint32_t* flag_val;       // The value of the flag for this round (stored on the local rank)
     int* local_token_counter; // Atomic counter for completed tokens on this rank
     int* send_counters;       // [ep_size] atomic counters - tracks tokens sent to each target rank
     int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
@@ -118,8 +125,6 @@ struct MoeA2ADispatchParams

     // Distributed aux data and recv buffers
     int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has [ep_size] counters
-    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
-                                           // rank has signaled the target rank
     void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload

     // Optional: Statistics for EPLB
@@ -134,7 +139,7 @@ struct MoeA2ADispatchParams
 // Dispatch kernels
 void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params);

-// Prepare for dispatch: zero send_counters, local_token_counter and increment flag_val
+// Prepare for dispatch: zero send_counters and local_token_counter
 void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params);

 // Combine phase parameters
@@ -160,8 +165,11 @@ struct MoeA2ACombineParams
     int elements_per_token;   // Number of elements per token
     nvinfer1::DataType dtype; // Data type for proper summation

+    // NCCL device communicator for LSA barrier synchronization
+    ncclDevComm const* dev_comm;
+
     // Local aux data
-    uint32_t* flag_val; // The value of the flag for this round (stored on the local rank)
+    int* local_token_counter; // Reused as a flag for intra-rank synchronization during combine
     int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), target rank
                             // per k, -1 for duplicates
     int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, top_k]), dst index
@@ -169,9 +177,7 @@ struct MoeA2ACombineParams
     int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target

     // Distributed aux data and recv buffers
-    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
-                                           // rank has signaled the target rank
-    void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)
+    void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload)

     // CUDA stream
     cudaStream_t stream;
@@ -182,6 +188,14 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params);

 void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);

+// NCCL DevComm lifecycle helpers (defined in .cu, callable from host code).
+// create_moe_nccl_dev_comm creates an NCCL communicator + device communicator with the
+// specified number of LSA barriers. The returned pointer is a DEVICE pointer to ncclDevComm
+// that can be passed directly to kernel parameters.
+// This is a collective call -- all ranks in the group must call it.
+ncclDevComm* create_moe_nccl_dev_comm(int ep_size, int ep_rank, int num_lsa_barriers);
+void destroy_moe_nccl_dev_comm(ncclDevComm* dev_comm);
+
 // Sanitize expert IDs for invalid tokens
 // expert_ids: [ep_size, max_tokens_per_rank, top_k] (int32)
 // recv_counters: [ep_size] (int32), number of valid tokens per source
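
The local_token_counter protocol documented above is independent of NCCL and can be read as a self-contained pattern. Below is a distilled sketch (not the kernel itself): inter_rank_barrier is a stand-in for the ncclLsaBarrierSession sync, and the flag is assumed to be pre-zeroed, as the dispatch kernel guarantees for the real counter.

__device__ void inter_rank_barrier()
{
    // Placeholder: the real kernel runs ncclLsaBarrierSession::sync() here.
}

__global__ void electedCtaBarrierSketch(int* flag) // 0 = idle, -1 = in flight, 1 = done
{
    if (threadIdx.x / warpSize == 0) // first warp of every CTA
    {
        int lane_id = threadIdx.x % warpSize;
        bool elected = false;
        if (lane_id == 0)
        {
            elected = atomicCAS(flag, 0, -1) == 0; // exactly one CTA wins the election
        }
        elected = __shfl_sync(0xffffffff, elected, 0);
        if (elected)
        {
            inter_rank_barrier();
            if (lane_id == 0)
            {
                // Publish completion to the other CTAs on this GPU.
                asm volatile("st.relaxed.gpu.s32 [%0], %1;" ::"l"(flag), "r"(1));
            }
        }
        else if (lane_id == 0)
        {
            int v;
            do // spin until the elected CTA publishes completion
            {
                asm volatile("ld.relaxed.gpu.s32 %0, [%1];" : "=r"(v) : "l"(flag));
            } while (v != 1);
        }
        asm volatile("fence.acquire.sys;"); // sm_90+: make peers' data visible to this CTA
    }
    __syncthreads(); // release the CTA's remaining warps
}

The point of the election is that the inter-rank barrier is paid once per rank instead of once per CTA, which matters for the one-block-per-token policy where the combine grid can be thousands of CTAs.
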
diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h b/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h
index 76f083fde44..2ec88a72c6a 100644
--- a/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h
+++ b/cpp/tensorrt_llm/thop/moeAlltoAllMeta.h
@@ -33,19 +33,14 @@ namespace moe_comm
 // Enum for indexing into moe_a2a_metainfo tensor
 enum MoeA2AMetaInfoIndex : int64_t
 {
-    FLAG_VAL_OFFSET_INDEX = 0,
-    LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1,
-    SEND_COUNTERS_OFFSET_INDEX = 2,
-    RECV_COUNTERS_OFFSET_INDEX = 3,
-    // Dispatch completion flags offset
-    DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4,
-    // Combine completion flags offset
-    COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5,
-    TOPK_TARGET_RANKS_OFFSET_INDEX = 6,
-    TOPK_SEND_INDICES_OFFSET_INDEX = 7,
-    EPLB_GATHERED_STATS_OFFSET_INDEX = 8,
-    PAYLOAD_DATA_OFFSET_INDEX = 9,
-    NUM_METAINFO_FIELDS = 10
+    LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 0,
+    SEND_COUNTERS_OFFSET_INDEX = 1,
+    RECV_COUNTERS_OFFSET_INDEX = 2,
+    TOPK_TARGET_RANKS_OFFSET_INDEX = 3,
+    TOPK_SEND_INDICES_OFFSET_INDEX = 4,
+    EPLB_GATHERED_STATS_OFFSET_INDEX = 5,
+    PAYLOAD_DATA_OFFSET_INDEX = 6,
+    NUM_METAINFO_FIELDS = 7
 };

 using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;
@@ -53,12 +48,9 @@ using MoeA2ADataOffsets = std::array<int64_t, NUM_METAINFO_FIELDS>;
 inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs()
 {
     return {
-        {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX},
         {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX},
         {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX},
         {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX},
-        {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX},
-        {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX},
         {"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX},
         {"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX},
         {"MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX", EPLB_GATHERED_STATS_OFFSET_INDEX},
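
With the flag fields gone, each rank's slice of the workspace holds only the seven fields enumerated above, located via the offsets that calculateOffsets (below) writes into the metainfo tensor. Reading a field back always follows the same pattern; a sketch with an invented helper name (`rank_base` must point at the start of one rank's slice):

#include <cstdint>

// Resolve a workspace field from its metainfo offset.
inline int* fieldAsInt32(uint8_t* rank_base, int64_t const* metainfo, int64_t index)
{
    return reinterpret_cast<int*>(rank_base + metainfo[index]);
}

// e.g. int* send_counters = fieldAsInt32(rank_base, metainfo, SEND_COUNTERS_OFFSET_INDEX);
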
diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
index d81ae4e3990..f468e2451b7 100644
--- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
+++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp
@@ -25,6 +25,9 @@
 #include
 #include

+// Forward declaration for the NCCL device communicator (the full type comes from the
+// NCCL headers included by moeAlltoAllKernels.cu)
+struct ncclDevComm;
+
 TRTLLM_NAMESPACE_BEGIN

 namespace torch_ext
@@ -52,10 +55,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
     MoeA2ADataOffsets offsets;
     size_t offset = 0;

-    // flag_val
-    offsets[FLAG_VAL_OFFSET_INDEX] = offset;
-    offset += SIZEOF_INT32;
-
     // local_token_counter
     offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX] = offset;
     offset += SIZEOF_INT32;
@@ -68,16 +67,6 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
     offsets[RECV_COUNTERS_OFFSET_INDEX] = offset;
     offset += epSize * SIZEOF_INT32;

-    // dispatch completion flags
-    offset = alignOffset(offset, CACHELINE_ALIGNMENT);
-    offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX] = offset;
-    offset += epSize * SIZEOF_INT32;
-
-    // combine completion flags
-    offset = alignOffset(offset, CACHELINE_ALIGNMENT);
-    offsets[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX] = offset;
-    offset += epSize * SIZEOF_INT32;
-
     // topk_target_ranks: [maxNumTokens, kMaxTopK]
     offset = alignOffset(offset, CACHELINE_ALIGNMENT);
     offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset;
@@ -103,7 +92,7 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
 }

 // Initialize auxiliary data in workspace
-// This function sets up the initial values for flag_val and completion_flags
+// This function sets up the initial workspace values and creates the NCCL device communicator
 //
 // Inputs:
 //   - workspace: [ep_size, size_per_rank] unified virtual memory workspace
@@ -114,6 +103,10 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens, int eplbStatsNu
 //
 // Returns:
 //   - metainfo: Tensor containing offsets for auxiliary data
+
+// Global storage for NCCL device communicator (persistent across dispatch/combine calls).
+static ncclDevComm* g_moeDevComm = nullptr;
+
 torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens,
     torch::optional<int64_t> eplbStatsNumExperts)
 {
@@ -142,10 +135,21 @@ torch::Tensor moeA2AInitializeOp(torch::Tensor const& workspace, int64_t epRank,
         metainfo[i] = static_cast<int64_t>(offsets[i]);
     }

-    // Synchronize among ranks
+    // Synchronize among ranks before NCCL initialization
     cudaDeviceSynchronize();
     tensorrt_llm::mpi::MpiComm::world().barrier();

+    // Create NCCL device communicator with 2 LSA barriers:
+    //   - Index 0: dispatch barrier (warp 0 of last-token CTA)
+    //   - Index 1: combine barrier (warp 0 of an elected CTA, other CTAs wait via local_token_counter)
+    // This is a collective call -- all EP ranks must participate.
+    if (g_moeDevComm == nullptr)
+    {
+        using tensorrt_llm::kernels::moe_comm::create_moe_nccl_dev_comm;
+        g_moeDevComm
+            = create_moe_nccl_dev_comm(static_cast<int>(epSize), static_cast<int>(epRank), /*num_lsa_barriers=*/2);
+    }
+
     return metainfo;
 }

@@ -186,6 +190,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
     using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
     using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
     using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
+    using tensorrt_llm::kernels::moe_comm::moe_a2a_prepare_dispatch_launch;
     using tensorrt_llm::kernels::moe_comm::kMaxTopK;
     using tensorrt_llm::kernels::moe_comm::kMaxPayloads;

@@ -320,7 +325,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
     params.num_payloads = num_payloads;
     std::copy(payloadDescriptors.begin(), payloadDescriptors.end(), &params.payloads[0]);

-    params.flag_val = reinterpret_cast<uint32_t*>(rankWorkSpacePtr + offsets[FLAG_VAL_OFFSET_INDEX]);
+    params.dev_comm = g_moeDevComm;
     params.local_token_counter = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX]);
     params.send_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[SEND_COUNTERS_OFFSET_INDEX]);
     params.topk_target_ranks = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX]);
@@ -332,8 +337,6 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
         params.recv_counters[target_rank]
             = reinterpret_cast<int*>(targetWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);
-        params.completion_flags[target_rank]
-            = reinterpret_cast<uint32_t*>(targetWorkSpacePtr + offsets[DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]);
         if (enableEplb)
         {
             params.eplb_gathered_stats[target_rank]
@@ -362,7 +365,7 @@ std::tuple<std::vector<torch::Tensor>, int64_t, torch::Tensor> moeA2ADispatchOp(
     params.stream = at::cuda::getCurrentCUDAStream();

-    // Prepare for dispatch (zero counters/indices and increment flag_val)
+    // Prepare for dispatch (zero counters/indices)
     moe_a2a_prepare_dispatch_launch(params);

     // Launch the dispatch kernel
@@ -417,6 +420,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
 {
     using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
     using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
+    using tensorrt_llm::kernels::moe_comm::moe_a2a_prepare_combine_launch;
     using tensorrt_llm::kernels::moe_comm::kMaxTopK;

     // Validate inputs
@@ -505,7 +509,8 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
     params.elements_per_token = static_cast<int>(elementsPerToken);
     params.dtype = nvDtype;

-    params.flag_val = reinterpret_cast<uint32_t*>(rankWorkSpacePtr + offsets[FLAG_VAL_OFFSET_INDEX]);
+    params.dev_comm = g_moeDevComm;
+    params.local_token_counter = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[LOCAL_TOKEN_COUNTER_OFFSET_INDEX]);
     params.topk_target_ranks = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[TOPK_TARGET_RANKS_OFFSET_INDEX]);
     params.topk_send_indices = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[TOPK_SEND_INDICES_OFFSET_INDEX]);
     params.recv_counters = reinterpret_cast<int const*>(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);
@@ -513,8 +518,6 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
     for (int target_rank = 0; target_rank < epSize; target_rank++)
     {
         uint8_t* target_workspace_ptr = workspacePtr + target_rank * workspace.stride(0);
-        params.completion_flags[target_rank]
-            = reinterpret_cast<uint32_t*>(target_workspace_ptr + offsets[COMBINE_COMPLETION_FLAGS_OFFSET_INDEX]);
         params.recv_buffers[target_rank] = target_workspace_ptr + combinePayloadOffset;
     }
@@ -522,7 +525,7 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke

     moe_a2a_prepare_combine_launch(params);

-    // Launch the combine kernel
+    // Launch the combine kernel (includes the in-kernel NCCL LSA barrier)
     moe_a2a_combine_launch(params);
     cudaError_t result = cudaGetLastError();
     TORCH_CHECK(result == cudaSuccess, "moe_a2a_combine kernel launch failed: ", cudaGetErrorString(result));
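
The op-level lifecycle implied by these changes: moeA2AInitializeOp runs once per process and collectively creates the device communicator, every subsequent dispatch/combine picks the cached g_moeDevComm up through params.dev_comm, and teardown funnels into the helper pair declared in moeAlltoAllKernels.h. The patch does not yet wire up a teardown hook; a hedged sketch of what one could look like (the function name is invented, and it would live next to g_moeDevComm in moeAlltoAllOp.cpp):

void moeA2AShutdown()
{
    using tensorrt_llm::kernels::moe_comm::destroy_moe_nccl_dev_comm;
    if (g_moeDevComm != nullptr)
    {
        // Frees the device-side copy and destroys both NCCL communicators.
        destroy_moe_nccl_dev_comm(g_moeDevComm);
        g_moeDevComm = nullptr;
    }
}
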
diff --git a/tensorrt_llm/_torch/distributed/moe_alltoall.py b/tensorrt_llm/_torch/distributed/moe_alltoall.py
index cc593499c7c..0825cd942d7 100644
--- a/tensorrt_llm/_torch/distributed/moe_alltoall.py
+++ b/tensorrt_llm/_torch/distributed/moe_alltoall.py
@@ -94,18 +94,12 @@ def _init_constants(cls):
         if cls._METAINFO_INDEX is None:
             thop = _tllm_internal.thop
             cls._METAINFO_INDEX = {
-                "FLAG_VAL_OFFSET_INDEX":
-                int(thop.MOE_A2A_FLAG_VAL_OFFSET_INDEX),
                 "LOCAL_TOKEN_COUNTER_OFFSET_INDEX":
                 int(thop.MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX),
                 "SEND_COUNTERS_OFFSET_INDEX":
                 int(thop.MOE_A2A_SEND_COUNTERS_OFFSET_INDEX),
                 "RECV_COUNTERS_OFFSET_INDEX":
                 int(thop.MOE_A2A_RECV_COUNTERS_OFFSET_INDEX),
-                "DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX":
-                int(thop.MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX),
-                "COMBINE_COMPLETION_FLAGS_OFFSET_INDEX":
-                int(thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX),
                 "TOPK_TARGET_RANKS_OFFSET_INDEX":
                 int(thop.MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX),
                 "TOPK_SEND_INDICES_OFFSET_INDEX":
diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
index df5b834e289..b25698b1889 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py
@@ -59,12 +59,9 @@ class NVLinkOneSided(Communication):
     _WORKSPACE: dict | None = None

     # MetaInfo indices - initialized from C++ constants
-    FLAG_VAL_OFFSET_INDEX = None
     LOCAL_TOKEN_COUNTER_OFFSET_INDEX = None
     SEND_COUNTERS_OFFSET_INDEX = None
     RECV_COUNTERS_OFFSET_INDEX = None
-    DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = None
-    COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = None
     EPLB_GATHERED_STATS_OFFSET_INDEX = None
     PAYLOAD_DATA_OFFSET_INDEX = None

@@ -119,20 +116,13 @@ def calculate_required_workspace_size(
     @classmethod
     def _init_constants(cls):
         """Initialize constants from C++ if not already done."""
-        if cls.FLAG_VAL_OFFSET_INDEX is None:
+        if cls.LOCAL_TOKEN_COUNTER_OFFSET_INDEX is None:
             thop = _tllm_internal.thop
-            cls.FLAG_VAL_OFFSET_INDEX = int(thop.MOE_A2A_FLAG_VAL_OFFSET_INDEX)
             cls.LOCAL_TOKEN_COUNTER_OFFSET_INDEX = int(
                 thop.MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX
             )
             cls.SEND_COUNTERS_OFFSET_INDEX = int(thop.MOE_A2A_SEND_COUNTERS_OFFSET_INDEX)
             cls.RECV_COUNTERS_OFFSET_INDEX = int(thop.MOE_A2A_RECV_COUNTERS_OFFSET_INDEX)
-            cls.DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = int(
-                thop.MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX
-            )
-            cls.COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = int(
-                thop.MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX
-            )
             cls.EPLB_GATHERED_STATS_OFFSET_INDEX = int(
                 thop.MOE_A2A_EPLB_GATHERED_STATS_OFFSET_INDEX
             )
diff --git a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py
index ddca5fb8ab0..38ec0427750 100644
--- a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py
+++ b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py
@@ -278,22 +278,6 @@ def run_moe_a2a_dispatch_single_rank(ep_size, all_num_tokens, top_k,
         expert_id_payload_index=expert_id_payload_index,
         eplb_local_stats=eplb_local_stats)

-    # Verify completion flags after dispatch
-    completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[
-        "DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX"]].item()
-    completion_flags = moe_a2a.workspace[
-        rank, completion_flags_offset:completion_flags_offset +
-        ep_size * 4].view(torch.int32).cpu()
-    flag_val_offset = moe_a2a.metainfo[
-        MoeAlltoAll._METAINFO_INDEX["FLAG_VAL_OFFSET_INDEX"]].item()
-    expected_flag_val = moe_a2a.workspace[rank,
-                                          flag_val_offset:flag_val_offset +
-                                          4].view(torch.int32).cpu()
-
-    assert torch.all(completion_flags == expected_flag_val), (
-        f"Rank {rank} completion flags: {completion_flags}, expected flag val: {expected_flag_val}"
-    )
-
     # Read counters and compact routing tensors from workspace
     send_counters_offset = moe_a2a.metainfo[
         MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"]].item()
@@ -736,21 +720,6 @@ def run_moe_a2a_dispatch_moe_combine_single_rank(ep_size, all_num_tokens, top_k,

     combined_output = moe_a2a.combine(hidden_states_recv, max_num_tokens)

-    # Verify completion flags after combine
-    completion_flags_offset = moe_a2a.metainfo[MoeAlltoAll._METAINFO_INDEX[
-        "COMBINE_COMPLETION_FLAGS_OFFSET_INDEX"]].item()
-    completion_flags_ptr = moe_a2a.workspace[
-        rank, completion_flags_offset:completion_flags_offset + ep_size * 4]
-    completion_flags = completion_flags_ptr.view(torch.int32).cpu()
-    flag_val_offset = moe_a2a.metainfo[
-        MoeAlltoAll._METAINFO_INDEX["FLAG_VAL_OFFSET_INDEX"]].item()
-    expected_flag_val = moe_a2a.workspace[rank,
-                                          flag_val_offset:flag_val_offset +
-                                          4].view(torch.int32).cpu()
-    assert torch.all(completion_flags == expected_flag_val), (
-        f"Rank {rank} completion flags: {completion_flags}, expected flag val: {expected_flag_val}"
-    )
-
     # Return results for verification
     return (
         token_selected_experts.cpu(),