Skip to content

Commit a54e6a7

Browse files
committed
update
1 parent b9ec5a6 commit a54e6a7

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

include/mscclpp/proxy_channel.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <mscclpp/fifo.hpp>
99
#include <mscclpp/proxy.hpp>
1010
#include <mscclpp/semaphore.hpp>
11-
#include <unordered_map>
1211

1312
namespace mscclpp {
1413

@@ -41,10 +40,10 @@ class ProxyService : public BaseProxyService {
4140
/// @return The ID of the semaphore.
4241
SemaphoreId addSemaphore(std::shared_ptr<Connection> connection);
4342

44-
/// Add a pitch pair to the proxy service.
45-
/// @param id The ID of the semaphore.
43+
/// Add a 2D channel to the proxy service.
44+
/// @param connection The connection associated with the channel.
4645
/// @param pitch The pitch pair.
/// @return The ID of the semaphore.
47-
void addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch);
46+
SemaphoreId add2DChannel(std::shared_ptr<Connection> connection, std::pair<uint64_t, uint64_t> pitch);
4847

4948
/// Register a memory region with the proxy service.
5049
/// @param memory The memory region to register.
@@ -71,7 +70,7 @@ class ProxyService : public BaseProxyService {
7170
Communicator& communicator_;
7271
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
7372
std::vector<RegisteredMemory> memories_;
74-
std::unordered_map<SemaphoreId, std::pair<uint64_t, uint64_t>> pitches_;
73+
std::vector<std::pair<uint64_t, uint64_t>> pitches_;
7574
Proxy proxy_;
7675
int deviceNumaNode;
7776

src/proxy_channel.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Connectio
2929
return semaphores_.size() - 1;
3030
}
3131

32-
MSCCLPP_API_CPP void ProxyService::addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch) {
32+
// Registers a 2D channel: creates a Host2DeviceSemaphore for `connection`, appends it to
// semaphores_, stores `pitch` in pitches_ at the new semaphore's index, and returns that
// index as the SemaphoreId.
MSCCLPP_API_CPP SemaphoreId ProxyService::add2DChannel(std::shared_ptr<Connection> connection,
33+
std::pair<uint64_t, uint64_t> pitch) {
34+
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator_, connection));
35+
// The semaphore was just appended, so its ID is the last valid index of semaphores_.
SemaphoreId id = semaphores_.size() - 1;
36+
// pitches_ is indexed by SemaphoreId. Grow it so slot `id` exists; any newly created
// slots below `id` are filled with (0, 0).
if (id >= pitches_.size()) pitches_.resize(id + 1, std::pair<uint64_t, uint64_t>(0, 0));
3337
pitches_[id] = pitch;
38+
return id;
3439
}
3540

3641
MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {

test/mp_unit/proxy_channel_tests.cu

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections(
5858

5959
communicator->setup();
6060

61-
mscclpp::SemaphoreId cid = channelService->addSemaphore(conn);
62-
channelService->addPitch(cid, std::pair<size_t, size_t>(pitch, pitch));
61+
mscclpp::SemaphoreId cid = channelService->add2DChannel(conn, std::pair<size_t, size_t>(pitch, pitch));
6362
communicator->setup();
6463

6564
proxyChannels.emplace_back(mscclpp::deviceHandle(
@@ -77,13 +76,13 @@ __device__ size_t getTileElementOffset(int elementId, int width, int rowIndex, i
7776
}
7877

7978
__global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowIndex, int colIndex, int width,
80-
int hight, int* ret) {
79+
int height, int* ret) {
8180
DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = gChannelOneToOneTestConstProxyChans;
8281
volatile int* sendBuff = (volatile int*)buff;
8382
int nTries = 1000;
8483
int flusher = 0;
8584
size_t offset = rowIndex * pitch + colIndex * sizeof(int);
86-
size_t nElem = width * hight;
85+
size_t nElem = width * height;
8786
size_t nElemPerPitch = pitch / sizeof(int);
8887
for (int i = 0; i < nTries; i++) {
8988
if (rank == 0) {
@@ -105,7 +104,7 @@ __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowI
105104
}
106105
__syncthreads();
107106
// __threadfence_system(); // not necessary if we make sendBuff volatile
108-
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), hight);
107+
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height);
109108
}
110109
if (rank == 1) {
111110
if (threadIdx.x == 0) proxyChan.wait();
@@ -125,7 +124,7 @@ __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowI
125124
}
126125
__syncthreads();
127126
// __threadfence_system(); // not necessary if we make sendBuff volatile
128-
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), hight);
127+
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height);
129128
}
130129
}
131130
flusher++;

0 commit comments

Comments
 (0)