diff --git a/.azure-pipelines/templates/integration-test.yaml b/.azure-pipelines/templates/integration-test.yaml
index ea0668be2..32fcffcd0 100644
--- a/.azure-pipelines/templates/integration-test.yaml
+++ b/.azure-pipelines/templates/integration-test.yaml
@@ -224,7 +224,9 @@ steps:
       export PATH=/usr/local/mpi/bin:\$PATH; \
       export LD_LIBRARY_PATH=/root/mscclpp/build:\$LD_LIBRARY_PATH; \
       cd /root/mscclpp; \
-      ./build/test/perf/fifo_test"'
+      ./build/test/perf/fifo_test; \
+      echo \"mpirun --allow-run-as-root -np 2 ./build/test/perf/fifo_test_multi_gpu_data_transfer\"; \
+      mpirun --allow-run-as-root -np 2 ./build/test/perf/fifo_test_multi_gpu_data_transfer"'
       kill $CHILD_PID
     workingDirectory: '$(System.DefaultWorkingDirectory)'
diff --git a/test/perf/CMakeLists.txt b/test/perf/CMakeLists.txt
index 56f9e8c2e..2f875c36a 100644
--- a/test/perf/CMakeLists.txt
+++ b/test/perf/CMakeLists.txt
@@ -40,3 +40,4 @@ endfunction()

 # Add FIFO test
 add_perf_test_executable(fifo_test "framework.cc;fifo_test.cu")
+add_perf_test_executable(fifo_test_multi_gpu_data_transfer "framework.cc;fifo_test_multi_gpu_data_transfer.cu")
diff --git a/test/perf/fifo_test_multi_gpu_data_transfer.cu b/test/perf/fifo_test_multi_gpu_data_transfer.cu
new file mode 100644
index 000000000..95444476e
--- /dev/null
+++ b/test/perf/fifo_test_multi_gpu_data_transfer.cu
@@ -0,0 +1,411 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <getopt.h>
#include <mpi.h>

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/numa.hpp>
#include <mscclpp/port_channel.hpp>

#include "framework.hpp"

using namespace mscclpp::test;

__constant__ mscclpp::PortChannelDeviceHandle gPortChannel;

// Kernels for bidirectional data transfer
__global__ void kernelPutData(int* sendBuffer, mscclpp::PortChannelDeviceHandle portHandle, int numElements, int rank) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Only thread 0 issues the put and then signals completion
  if (tid == 0) {
    portHandle.put(0, 0, numElements * sizeof(int));
    portHandle.signal();
  }
}

__global__ void kernelGetData(int* recvBuffer, mscclpp::PortChannelDeviceHandle portHandle, int numElements, int rank,
                              int expectedValue) {
  // Only block 0 participates: the sender signals the semaphore exactly once,
  // so a single wait() must order all validation. Threads in other blocks
  // could otherwise read recvBuffer before the transfer has completed.
  if (blockIdx.x != 0) return;

  int tid = threadIdx.x;

  // Wait for signal from sender - only thread 0 waits
  if (tid == 0) {
    portHandle.wait();
  }

  __shared__ int errorCount;
  if (tid == 0) {
    errorCount = 0;
  }
  // This barrier also orders the wait() above before any validation below
  __syncthreads();

  int localErrors = 0;

  // Each thread validates a portion of the received data
  for (int i = tid; i < numElements; i += blockDim.x) {
    if (recvBuffer[i] != expectedValue) {
      localErrors++;
    }
  }

  // Accumulate errors from all threads
  if (localErrors > 0) {
    atomicAdd(&errorCount, localErrors);
  }

  // Report validation results from thread 0
  __syncthreads();
  if (tid == 0) {
    if (errorCount == 0) {
      printf("GPU%d: Data validation PASSED - all %d elements correct (expected value: %d)\n", rank, numElements,
             expectedValue);
    } else {
      printf("GPU%d: Data validation FAILED - %d errors found out of %d elements (expected value: %d)\n", rank,
             errorCount, numElements, expectedValue);
    }
    assert(errorCount == 0);
  }
}
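// Ordering guarantee the kernel pair above relies on: put() enqueues the copy
// request on the proxy FIFO and signal() is queued behind it, so the proxy
// raises the remote semaphore only after issuing the transfer. wait()
// returning on the receiver therefore means the data has landed.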
static void setupCuda(int& cudaDevice, int& numaNode) {
  utils::CUDA_CHECK(cudaGetDevice(&cudaDevice));
  numaNode = mscclpp::getDeviceNumaNode(cudaDevice);
  mscclpp::numaBind(numaNode);
}

std::tuple<double, double, int> runDataTransferKernelVariant(cudaStream_t stream, int numParallel, int rank,
                                                             mscclpp::PortChannelDeviceHandle portChannelHandle,
                                                             int* sendBuffer, int* recvBuffer, int numElements) {
  int threadsPerBlock = std::min(numParallel, 512);
  int threadBlocks = (numParallel + threadsPerBlock - 1) / threadsPerBlock;
  threadBlocks = std::max(1, threadBlocks);  // Ensure at least 1 block

  // Benchmark
  utils::Timer timer;
  timer.start();

  // Launch put and get together on each GPU for bidirectional transfer:
  // GPU0 sends 1s and expects 2s from GPU1; GPU1 sends 2s and expects 1s
  if (rank == 0 || rank == 1) {
    int expectedValue = (rank == 0) ? 2 : 1;

    kernelPutData<<<threadBlocks, threadsPerBlock, 0, stream>>>(sendBuffer, portChannelHandle, numElements, rank);
    utils::CUDA_CHECK(cudaGetLastError());

    kernelGetData<<<threadBlocks, threadsPerBlock, 0, stream>>>(recvBuffer, portChannelHandle, numElements, rank,
                                                                expectedValue);
    utils::CUDA_CHECK(cudaGetLastError());
  }

  utils::CUDA_CHECK(cudaStreamSynchronize(stream));

  timer.stop();

  const int totalElements = numElements;
  double throughput = totalElements / timer.elapsedSeconds();
  double duration_us = timer.elapsedMicroseconds();

  utils::CUDA_CHECK(cudaDeviceSynchronize());

  return {throughput, duration_us, totalElements};
}

void runDataTransferTestVariant(cudaStream_t stream, int numParallel, nlohmann::ordered_json& combinedMetrics, int rank,
                                mscclpp::PortChannelDeviceHandle portChannelHandle, int* sendBuffer, int* recvBuffer,
                                int numElements) {
  // Run simultaneous bidirectional data transfer
  printf("=== Running simultaneous bidirectional GPU0 ↔ GPU1 transfer ===\n");
  auto [throughput, duration, totalElements] =
      runDataTransferKernelVariant(stream, numParallel, rank, portChannelHandle, sendBuffer, recvBuffer, numElements);

  auto formatThroughput = [](double thru) {
    return double(int(thru * 10)) / 10.0;  // Round to 1 decimal place
  };

  std::string prefix = "p" + std::to_string(numParallel) + "_";
  combinedMetrics[prefix + "data_throughput_elements_per_sec"] = formatThroughput(throughput);
  combinedMetrics[prefix + "data_duration_us"] = duration;
  combinedMetrics[prefix + "total_elements"] = totalElements;
  combinedMetrics[prefix + "bandwidth_GB_per_sec"] =
      formatThroughput((totalElements * sizeof(int)) / (duration / 1e6) / 1e9);
}

struct FifoTestConfig {
  int fifoSize;
  std::vector<int> parallelismLevels;

  // Constructor with default parallelism levels
  FifoTestConfig(int size, const std::vector<int>& parallel = {1, 2, 4, 8, 16})
      : fifoSize(size), parallelismLevels(parallel) {}
};
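// Example: {512, {1, 8, 64, 256, 512}} exercises a 512-entry FIFO at five
// parallelism levels, while a bare FifoTestConfig(1) would fall back to the
// default levels {1, 2, 4, 8, 16}.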
void runDataTransferTest(const FifoTestConfig& config, const mscclpp::test::TestContext& context) {
  int rank = context.rank;
  int worldSize = context.size;
  auto communicator = context.communicator;
  auto bootstrap = context.bootstrap;

  if (config.fifoSize <= 0) {
    throw std::invalid_argument("FIFO size must be positive");
  }
  if (config.parallelismLevels.empty()) {
    throw std::invalid_argument("At least one parallelism level must be specified");
  }

  // Set the device for this process
  utils::CUDA_CHECK(cudaSetDevice(rank));

  // Define buffer size and allocate memory
  const int nElem = 1024;
  const int halfElements = nElem / 2;
  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();

  // Split buffer into send and receive halves
  int* sendBuffer = buff.get();
  int* recvBuffer = buff.get() + halfElements;

  // Initialize send buffer with test data
  int initValue = rank + 1;  // GPU0 uses 1, GPU1 uses 2
  std::vector<int> hostBuffer(halfElements, initValue);
  utils::CUDA_CHECK(cudaMemcpy(sendBuffer, hostBuffer.data(), halfElements * sizeof(int), cudaMemcpyHostToDevice));

  // Initialize receive buffer to zero
  utils::CUDA_CHECK(cudaMemset(recvBuffer, 0, halfElements * sizeof(int)));

  // Setup transport
  mscclpp::TransportFlags transport = mscclpp::Transport::CudaIpc;
  std::vector<std::shared_ptr<mscclpp::Connection>> connections;
  if (worldSize > 1) {
    for (int i = 0; i < worldSize; i++) {
      if (i == rank) {
        continue;
      }
      // Use different IB transports for different ranks
      std::vector<mscclpp::Transport> ibTransports{mscclpp::Transport::IB0, mscclpp::Transport::IB1};
      mscclpp::Transport selectedTransport = ibTransports[rank % ibTransports.size()];
      transport |= selectedTransport;
      connections.push_back(communicator->connect(selectedTransport, i).get());
    }
  }

  // Wait for all connections to be established
  bootstrap->barrier();

  // Create and start proxy service with the specified FIFO size
  auto proxyService = std::make_shared<mscclpp::ProxyService>(config.fifoSize);
  proxyService->startProxy();

  // Register send buffer memory (first half)
  mscclpp::RegisteredMemory sendBufRegMem =
      communicator->registerMemory(sendBuffer, halfElements * sizeof(int), transport);

  // Register receive buffer memory (second half)
  mscclpp::RegisteredMemory recvBufRegMem =
      communicator->registerMemory(recvBuffer, halfElements * sizeof(int), transport);

  // Exchange memory with other ranks
  std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteSendMemFutures(worldSize);
  std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRecvMemFutures(worldSize);

  for (int r = 0; r < worldSize; r++) {
    if (r == rank) {
      continue;
    }
    // Send our buffer info to other ranks
    communicator->sendMemory(sendBufRegMem, r, 0);  // tag 0 for send buffer
    communicator->sendMemory(recvBufRegMem, r, 1);  // tag 1 for recv buffer

    // Receive other ranks' buffer info
    remoteSendMemFutures[r] = communicator->recvMemory(r, 0);
    remoteRecvMemFutures[r] = communicator->recvMemory(r, 1);
  }
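  // Tag convention for the exchange above: tag 0 carries send-buffer metadata
  // and tag 1 carries recv-buffer metadata, so each rank can later address the
  // peer's receive half directly when building the port channel.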
  // Allocate and setup local semaphore flag
  uint64_t* localSemaphoreFlag;
  utils::CUDA_CHECK(cudaMalloc(&localSemaphoreFlag, sizeof(uint64_t)));
  utils::CUDA_CHECK(cudaMemset(localSemaphoreFlag, 0, sizeof(uint64_t)));

  // Register semaphore flag
  auto localFlagRegmem = communicator->registerMemory(localSemaphoreFlag, sizeof(uint64_t), transport);

  int cudaDevice, numaNode;
  setupCuda(cudaDevice, numaNode);

  cudaStream_t stream;
  utils::CUDA_CHECK(cudaStreamCreate(&stream));

  // Create test name with parallelism range
  std::string testName = "FifoDataTransferTest_Size" + std::to_string(config.fifoSize) + "_Parallel";

  // Add parallelism range to test name (e.g., "Parallel1-16")
  if (!config.parallelismLevels.empty()) {
    testName += std::to_string(config.parallelismLevels.front());
    if (config.parallelismLevels.size() > 1) {
      testName += "-" + std::to_string(config.parallelismLevels.back());

      // If parallelism levels have non-standard steps, fall back to a generic name
      if (config.parallelismLevels.size() > 2 &&
          (config.parallelismLevels[1] != 2 * config.parallelismLevels[0] || config.parallelismLevels.size() > 3)) {
        testName = "FifoDataTransferTest_Size" + std::to_string(config.fifoSize) + "_ParallelCustom";
      }
    }
  }

  // Print test configuration
  if (utils::isMainRank()) {
    std::stringstream ss;
    ss << "Running FIFO test with size=" << config.fifoSize << ", parallelism_levels=[";
    for (size_t i = 0; i < config.parallelismLevels.size(); ++i) {
      if (i > 0) ss << ",";
      ss << config.parallelismLevels[i];
    }
    ss << "]";
    std::cout << ss.str() << std::endl;
  }

  nlohmann::ordered_json combinedMetrics;

  // Prepare variables for the test variant
  mscclpp::SemaphoreId semaphoreId = 0;
  std::shared_ptr<mscclpp::Connection> connection = nullptr;
  mscclpp::RegisteredMemory remoteFlagRegMem = localFlagRegmem;
  mscclpp::PortChannelDeviceHandle portChannelHandle;

  if (worldSize >= 2 && !connections.empty()) {
    int peerRank = (rank == 0) ? 1 : 0;
    int connIndex = peerRank < rank ? peerRank : peerRank - 1;
    if (connIndex < static_cast<int>(connections.size())) {
      connection = connections[connIndex];
      semaphoreId = proxyService->buildAndAddSemaphore(*communicator, connection);
      // Setup port channel to copy from our send buffer to the remote's receive buffer
      auto portChannel =
          proxyService->portChannel(semaphoreId, proxyService->addMemory(remoteRecvMemFutures[peerRank].get()),
                                    proxyService->addMemory(sendBufRegMem));
      portChannelHandle = portChannel.deviceHandle();
      utils::CUDA_CHECK(
          cudaMemcpyToSymbol(gPortChannel, &portChannelHandle, sizeof(portChannelHandle), 0, cudaMemcpyHostToDevice));
    }
  }

  for (int numParallel : config.parallelismLevels) {
    // Synchronize before each test iteration
    MPI_Barrier(MPI_COMM_WORLD);

    runDataTransferTestVariant(stream, numParallel, combinedMetrics, rank, portChannelHandle, sendBuffer, recvBuffer,
                               halfElements);

    // Synchronize after each test iteration
    MPI_Barrier(MPI_COMM_WORLD);
  }

  std::map<std::string, std::string> testParams;
  testParams["fifo_size"] = std::to_string(config.fifoSize);
  testParams["elements_per_gpu"] = std::to_string(halfElements);

  // Add parallelism levels to test parameters
  std::stringstream parallelismStream;
  for (size_t i = 0; i < config.parallelismLevels.size(); ++i) {
    if (i > 0) parallelismStream << ",";
    parallelismStream << config.parallelismLevels[i];
  }
  testParams["parallelism_levels"] = parallelismStream.str();

  utils::recordResult(testName, "fifo_data_transfer", combinedMetrics, testParams);

  // Cleanup
  utils::CUDA_CHECK(cudaStreamDestroy(stream));
  utils::CUDA_CHECK(cudaFree(localSemaphoreFlag));

  proxyService->stopProxy();
}

void runAllDataTransferTests(const mscclpp::test::TestContext& context) {
  // clang-format off
  std::vector<FifoTestConfig> configs = {
    {1,   {1}},
    {128, {1, 8, 64, 128}},
    {512, {1, 8, 64, 256, 512}},
  };
  // clang-format on

  for (const auto& config : configs) {
    runDataTransferTest(config, context);
  }
}
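// Typical invocation (see the CI pipeline change above):
//   mpirun --allow-run-as-root -np 2 ./build/test/perf/fifo_test_multi_gpu_data_transfer
// One rank drives one GPU; rank 0 and rank 1 exchange buffer halves.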
static void printUsage(char* argv0) {
  std::stringstream ss;
  ss << "Usage: " << argv0 << " [OPTIONS]\n"
     << "\n"
     << "Options:\n"
     << "  -o, --output-format FORMAT  Output format: human or json (default: human)\n"
     << "  -f, --output-file FILE      JSON output file path (default: report.jsonl)\n"
     << "  -v, --verbose               Increase verbosity\n"
     << "  -h, --help                  Show this help message\n";
  std::cout << ss.str();
}

int main(int argc, char* argv[]) {
  std::string outputFormat = "human";
  std::string outputFile = "report.jsonl";
  bool verbose = false;

  static struct option longOptions[] = {{"output-format", required_argument, 0, 'o'},
                                        {"output-file", required_argument, 0, 'f'},
                                        {"verbose", no_argument, 0, 'v'},
                                        {"help", no_argument, 0, 'h'},
                                        {0, 0, 0, 0}};

  int c;
  while ((c = getopt_long(argc, argv, "o:f:vh", longOptions, nullptr)) != -1) {
    switch (c) {
      case 'o':
        outputFormat = optarg;
        break;
      case 'f':
        outputFile = optarg;
        break;
      case 'v':
        verbose = true;
        break;
      case 'h':
        printUsage(argv[0]);
        return 0;
      default:
        printUsage(argv[0]);
        return 1;
    }
  }

  std::vector<std::tuple<std::string, std::string, std::function<void(const TestContext&)>>> tests = {
      {"AllDataTransferTests", "Data transfer tests with multiple configurations", runAllDataTransferTests}};

  int result = utils::runMultipleTests(argc, argv, tests);

  if (utils::isMainRank()) {
    if (outputFormat == "json") {
      utils::writeResultsToFile(outputFile);
    } else {
      utils::printResults(verbose);
    }
  }

  utils::cleanupMPI();

  return result;
}
diff --git a/test/perf/framework.cc b/test/perf/framework.cc
index 85f7abd81..6df157d35 100644
--- a/test/perf/framework.cc
+++ b/test/perf/framework.cc
@@ -145,9 +145,8 @@ void cudaCheck(cudaError_t err, const char* file, int line) {
   }
 }

-int runMultipleTests(
-    int argc, char* argv[],
-    const std::vector<std::tuple<std::string, std::string, std::function<void(int, int, int)>>>& tests) {
+int runMultipleTests(int argc, char* argv[],
+                     const std::vector<std::tuple<std::string, std::string, TestFunction>>& tests) {
   int totalResult = 0;

   // Initialize MPI once for all tests
@@ -159,10 +158,47 @@
   int size = getMPISize();
   int local_rank = rank;  // For simplicity, assume local_rank = rank

+  // Check if any test needs TestContext
+  bool needsTestContext = false;
+  for (const auto& test : tests) {
+    const TestFunction& testFunction = std::get<2>(test);
+    if (std::holds_alternative<std::function<void(const TestContext&)>>(testFunction)) {
+      needsTestContext = true;
+      break;
+    }
+  }
+
+  // Only create communicator and bootstrap if needed
+  std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
+  std::shared_ptr<mscclpp::Communicator> comm;
+  TestContext context;
+
+  if (needsTestContext) {
+    bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, size);
+    mscclpp::UniqueId id;
+    if (isMainProcess()) {
+      id = mscclpp::TcpBootstrap::createUniqueId();
+    }
+    MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
+    bootstrap->initialize(id);
+
+    cudaSetDevice(rank);
+    comm = std::make_shared<mscclpp::Communicator>(bootstrap);
+    bootstrap->barrier();
+
+    // Initialize test context
+    context.rank = rank;
+    context.size = size;
+    context.local_rank = local_rank;
+    context.communicator = comm;
+    context.bootstrap = bootstrap;
+  }
+
   for (const auto& test : tests) {
     const std::string& testName = std::get<0>(test);
     const std::string& testDescription = std::get<1>(test);
-    const std::function<void(int, int, int)>& testFunction = std::get<2>(test);
+    const TestFunction& testFunction = std::get<2>(test);

     if (rank == 0) {
       std::cout << "Running test: " << testName << std::endl;
@@ -171,12 +207,17 @@
       }
     }

-    // Don't clear results - accumulate them for all tests in the same file
-    // g_results.clear();  // Commented out to accumulate results
-
     try {
-      // Run the individual test function with MPI information
-      testFunction(rank, size, local_rank);
+      // Run the appropriate test function based on its type
+      if (std::holds_alternative<std::function<void(int, int, int)>>(testFunction)) {
+        // Legacy API
+        const auto& legacyFunction = std::get<std::function<void(int, int, int)>>(testFunction);
+        legacyFunction(rank, size, local_rank);
+      } else {
+        // New API - pass TestContext
+        const auto& newFunction = std::get<std::function<void(const TestContext&)>>(testFunction);
+        newFunction(context);
+      }
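+      // Note: std::visit could express this dispatch too; the explicit
+      // holds_alternative/get form keeps the two calling conventions obvious.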

       // Synchronize before moving to next test
       MPI_Barrier(MPI_COMM_WORLD);
@@ -189,20 +230,41 @@
     }
   }

-  // Don't cleanup MPI here - let the caller handle it
-  // finalizeMPI();
-
   } catch (const std::exception& e) {
     if (g_mpi_rank == 0) {
       std::cerr << "Error: " << e.what() << std::endl;
     }
-    finalizeMPI();
+    return 1;
   }

   return totalResult;
 }

+int runMultipleTests(
+    int argc, char* argv[],
+    const std::vector<std::tuple<std::string, std::string, std::function<void(int, int, int)>>>& tests) {
+  // Convert to the unified format
+  std::vector<std::tuple<std::string, std::string, TestFunction>> unifiedTests;
+  for (const auto& test : tests) {
+    unifiedTests.emplace_back(std::get<0>(test), std::get<1>(test), TestFunction(std::get<2>(test)));
+  }
+
+  return runMultipleTests(argc, argv, unifiedTests);
+}
+
+int runMultipleTests(
+    int argc, char* argv[],
+    const std::vector<std::tuple<std::string, std::string, std::function<void(const TestContext&)>>>& tests) {
+  // Convert to the unified format
+  std::vector<std::tuple<std::string, std::string, TestFunction>> unifiedTests;
+  for (const auto& test : tests) {
+    unifiedTests.emplace_back(std::get<0>(test), std::get<1>(test), TestFunction(std::get<2>(test)));
+  }
+
+  return runMultipleTests(argc, argv, unifiedTests);
+}
+
 }  // namespace utils
 }  // namespace test
 }  // namespace mscclpp
diff --git a/test/perf/framework.hpp b/test/perf/framework.hpp
index e9b8c31f5..d66bfa503 100644
--- a/test/perf/framework.hpp
+++ b/test/perf/framework.hpp
@@ -10,15 +10,30 @@
 #include <functional>
 #include <map>
 #include <memory>
+#include <mscclpp/core.hpp>
 #include <nlohmann/json.hpp>
 #include <sstream>
 #include <string>
 #include <tuple>
+#include <variant>
 #include <vector>

 namespace mscclpp {
 namespace test {

+// Test context structure containing MPI and MSCCLPP objects
+struct TestContext {
+  int rank;
+  int size;
+  int local_rank;
+  std::shared_ptr<mscclpp::Communicator> communicator;
+  std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
+};
+
 // Test result structure
 struct TestResult {
   std::string test_name;
@@ -33,11 +48,24 @@ struct TestResult {

 // Simple utility functions for testing
 namespace utils {

+// Test function variant type
+using TestFunction = std::variant<std::function<void(int, int, int)>,      // Legacy API
+                                  std::function<void(const TestContext&)>  // New API
+                                  >;
+
 // Test execution utilities
 int runMultipleTests(
     int argc, char* argv[],
     const std::vector<std::tuple<std::string, std::string, std::function<void(int, int, int)>>>& tests);

+int runMultipleTests(
+    int argc, char* argv[],
+    const std::vector<std::tuple<std::string, std::string, std::function<void(const TestContext&)>>>& tests);
+
+// Unified test execution API
+int runMultipleTests(int argc, char* argv[],
+                     const std::vector<std::tuple<std::string, std::string, TestFunction>>& tests);
+
 // MPI management
 void initializeMPI(int argc, char* argv[]);
 void cleanupMPI();