5 #include "BorderExchange.hpp" 6 #include "include/pv_common.h" 7 #include "utils/PVAssert.hpp" 8 #include "utils/conversions.h" 12 BorderExchange::BorderExchange(MPIBlock
const &mpiBlock,
PVLayerLoc const &loc) {
// Keeps the address of the caller's MPIBlock instead of a copy, so the referenced
// object has to outlive this BorderExchange (it is dereferenced later, e.g. for getRank()).
// NOTE(review): the remaining constructor statements are not visible in this listing.
13 mMPIBlock = &mpiBlock;
// Destructor: its only action is a call to freeDatatypes().
19 BorderExchange::~BorderExchange() { freeDatatypes(); }
// Creates and commits one MPI_FLOAT vector datatype for each named slot of mDatatypes
// (LOCAL plus the eight compass directions), after resizing it to NUM_NEIGHBORHOOD.
// NOTE(review): the assignments to `count` sit on lines that are not visible in this
// listing, so the row counts used by each MPI_Type_vector call cannot be read from here.
21 void BorderExchange::newDatatypes() {
23 int count, blocklength, stride;
24 mDatatypes.resize(NUM_NEIGHBORHOOD);
// Border sizes are taken from mLayerLoc.halo, one per side.
26 int const leftBorder = mLayerLoc.halo.lt;
27 int const rightBorder = mLayerLoc.halo.rt;
28 int const bottomBorder = mLayerLoc.halo.dn;
29 int const topBorder = mLayerLoc.halo.up;
31 int const nf = mLayerLoc.nf;
// LOCAL: blocks of nf * nx floats; the stride spans nx plus the left and right borders.
// No visible line reassigns `stride`, so the visible calls below all reuse this value.
34 blocklength = nf * mLayerLoc.nx;
35 stride = nf * (mLayerLoc.nx + leftBorder + rightBorder);
38 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[LOCAL]);
39 MPI_Type_commit(&mDatatypes[LOCAL]);
// Northern trio: block length is the left border, nx, or the right border respectively.
44 blocklength = nf * leftBorder;
45 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[NORTHWEST]);
46 MPI_Type_commit(&mDatatypes[NORTHWEST]);
49 blocklength = nf * mLayerLoc.nx;
50 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[NORTH]);
51 MPI_Type_commit(&mDatatypes[NORTH]);
54 blocklength = nf * rightBorder;
55 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[NORTHEAST]);
56 MPI_Type_commit(&mDatatypes[NORTHEAST]);
// WEST / EAST: block length is the left / right border.
61 blocklength = nf * leftBorder;
62 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[WEST]);
63 MPI_Type_commit(&mDatatypes[WEST]);
66 blocklength = nf * rightBorder;
67 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[EAST]);
68 MPI_Type_commit(&mDatatypes[EAST]);
// Southern trio: same block lengths as the northern trio.
73 blocklength = nf * leftBorder;
74 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[SOUTHWEST]);
75 MPI_Type_commit(&mDatatypes[SOUTHWEST]);
78 blocklength = nf * mLayerLoc.nx;
79 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[SOUTH]);
80 MPI_Type_commit(&mDatatypes[SOUTH]);
83 blocklength = nf * rightBorder;
84 MPI_Type_vector(count, blocklength, stride, MPI_FLOAT, &mDatatypes[SOUTHEAST]);
85 MPI_Type_commit(&mDatatypes[SOUTHEAST]);
// Visits every element of mDatatypes by reference.
// NOTE(review): the loop body is not visible in this listing; presumably each datatype
// is released there (e.g. with MPI_Type_free) — confirm against the full source.
91 void BorderExchange::freeDatatypes() {
93 for (
auto &d : mDatatypes) {
// Sizes `neighbors` to NUM_NEIGHBORHOOD (new entries start at -1) and then walks every index.
100 void BorderExchange::initNeighbors() {
101 neighbors.resize(NUM_NEIGHBORHOOD, -1);
107 for (
int i = 0; i < NUM_NEIGHBORHOOD; i++) {
// On this visible path the entry is set to this process's own rank.
// NOTE(review): the statements between the loop header and this assignment, including
// whatever condition selects it, are not visible in this listing.
114 neighbors[i] = mMPIBlock->
getRank();
// Border exchange entry point operating on `data`, with MPI requests collected in `req`.
// NOTE(review): the MPI send/receive calls themselves and the bodies of the `if`
// statements are not visible in this listing; only the loop headers, the self-rank
// tests and the tag expressions are.
119 void BorderExchange::exchange(
float *data, std::vector<MPI_Request> &req) {
121 PVHalo const &halo = mLayerLoc.halo;
// True when all four halo widths are zero (presumably nothing to exchange — the
// branch body is not shown).
122 if (halo.lt == 0 && halo.rt == 0 && halo.dn == 0 && halo.up == 0) {
// First pass over indices 1 .. NUM_NEIGHBORHOOD-1 (index 0 is never visited).
// Tags in this pass are built from mTags[revDir].
128 for (
int n = 1; n < NUM_NEIGHBORHOOD; n++) {
// Tests whether the neighbor in direction n is this very process.
129 if (neighbors[n] == mMPIBlock->
getRank())
132 auto sz = req.size();
140 exchangeCounter * 16 + mTags[revDir],
// Second pass over the same index range; tags here are built from mTags[n].
145 for (
int n = 1; n < NUM_NEIGHBORHOOD; n++) {
146 if (neighbors[n] == mMPIBlock->
getRank())
149 auto sz = req.size();
156 exchangeCounter * 16 + mTags[n],
// Advance the counter (a static member, so shared by every BorderExchange), wrapping
// from 2047 back to 1024. The tag base exchangeCounter * 16 therefore stays in
// [16384, 32752].
// NOTE(review): with mTags values up to 35 the largest tag is 32787, which is above the
// 32767 minimum that MPI guarantees for MPI_TAG_UB — confirm the target MPI permits it.
161 exchangeCounter = (exchangeCounter == 2047) ? 1024 : exchangeCounter + 1;
// Waits on every request in `req` with a single MPI_Waitall call; per-request statuses
// are discarded (MPI_STATUSES_IGNORE) and the call's return code is kept in `status`.
// NOTE(review): what happens to `status` and `req` afterwards is not visible in this listing.
167 int BorderExchange::wait(std::vector<MPI_Request> &req) {
168 int status = MPI_Waitall(req.size(), req.data(), MPI_STATUSES_IGNORE);
// NOTE(review): fragment — the function signature and the switch header are not visible
// in this listing.
// Reduce commId to its position inside one numRows x numColumns grid of ranks, then
// split that position into row and column.
181 int rankRowColumn = commId % (numRows * numColumns);
182 int row = rowFromRank(rankRowColumn, numRows, numColumns);
183 int column = columnFromRank(rankRowColumn, numRows, numColumns);
// LOCAL returns commId unchanged; each compass direction is delegated to its helper.
186 case LOCAL:
return commId;
187 case NORTHWEST: neighborRank = northwest(row, column, numRows, numColumns);
break;
188 case NORTH: neighborRank = north(row, column, numRows, numColumns);
break;
189 case NORTHEAST: neighborRank = northeast(row, column, numRows, numColumns);
break;
190 case WEST: neighborRank = west(row, column, numRows, numColumns);
break;
191 case EAST: neighborRank = east(row, column, numRows, numColumns);
break;
192 case SOUTHWEST: neighborRank = southwest(row, column, numRows, numColumns);
break;
193 case SOUTH: neighborRank = south(row, column, numRows, numColumns);
break;
194 case SOUTHEAST: neighborRank = southeast(row, column, numRows, numColumns);
break;
// Any other direction value trips the assertion.
195 default: pvAssert(0);
break;
// The helpers are given grid-relative coordinates, so a non-negative result is shifted
// by the first rank of commId's grid (commId - rankRowColumn). Negative results are
// left as they are.
197 if (neighborRank >= 0) {
198 int rankBatchStart = commId - rankRowColumn;
199 neighborRank += rankBatchStart;
// Rank of the northwestern neighbor of (row, column).
// The boolean terms (row > 0) and (column > 0) subtract one only where a step is possible,
// so on the top edge the result is the western neighbor and on the left edge the northern one.
// NOTE(review): the value returned when hasNorthwesternNeighbor() is false is not visible
// in this listing.
204 int BorderExchange::northwest(
int row,
int column,
int numRows,
int numColumns) {
205 if (hasNorthwesternNeighbor(row, column, numRows, numColumns)) {
206 int const neighborRow = row - (row > 0);
207 int const neighborColumn = column - (column > 0);
208 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the neighbor one row up in the same column, when hasNorthernNeighbor() allows it.
// NOTE(review): the value returned otherwise is not visible in this listing.
215 int BorderExchange::north(
int row,
int column,
int numRows,
int numColumns) {
216 if (hasNorthernNeighbor(row, column, numRows, numColumns)) {
217 int const neighborRow = row - 1;
218 int const neighborColumn = column;
219 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the northeastern neighbor of (row, column).
// Each coordinate moves by one only where that move stays inside the grid, so on the top
// edge the result is the eastern neighbor and on the right edge the northern one.
// NOTE(review): the value returned when hasNortheasternNeighbor() is false is not visible
// in this listing.
226 int BorderExchange::northeast(
int row,
int column,
int numRows,
int numColumns) {
227 if (hasNortheasternNeighbor(row, column, numRows, numColumns)) {
228 int const neighborRow = row - (row > 0);
229 int const neighborColumn = column + (column < numColumns - 1);
230 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the neighbor one column to the left in the same row, when hasWesternNeighbor()
// allows it.
// NOTE(review): the value returned otherwise is not visible in this listing.
237 int BorderExchange::west(
int row,
int column,
int numRows,
int numColumns) {
238 if (hasWesternNeighbor(row, column, numRows, numColumns)) {
239 int const neighborRow = row;
240 int const neighborColumn = column - 1;
241 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the neighbor one column to the right in the same row, when hasEasternNeighbor()
// allows it.
// NOTE(review): the value returned otherwise is not visible in this listing.
248 int BorderExchange::east(
int row,
int column,
int numRows,
int numColumns) {
249 if (hasEasternNeighbor(row, column, numRows, numColumns)) {
250 int const neighborRow = row;
251 int const neighborColumn = column + 1;
252 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the southwestern neighbor of (row, column).
// Each coordinate moves by one only where that move stays inside the grid, so on the bottom
// edge the result is the western neighbor and on the left edge the southern one.
// NOTE(review): the value returned when hasSouthwesternNeighbor() is false is not visible
// in this listing.
259 int BorderExchange::southwest(
int row,
int column,
int numRows,
int numColumns) {
260 if (hasSouthwesternNeighbor(row, column, numRows, numColumns)) {
261 int const neighborRow = row + (row < numRows - 1);
262 int const neighborColumn = column - (column > 0);
263 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the neighbor one row down in the same column, when hasSouthernNeighbor() allows it.
// NOTE(review): the value returned otherwise is not visible in this listing.
270 int BorderExchange::south(
int row,
int column,
int numRows,
int numColumns) {
271 if (hasSouthernNeighbor(row, column, numRows, numColumns)) {
272 int const neighborRow = row + 1;
273 int const neighborColumn = column;
274 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// Rank of the southeastern neighbor of (row, column).
// Each coordinate moves by one only where that move stays inside the grid, so on the bottom
// edge the result is the eastern neighbor and on the right edge the southern one.
// NOTE(review): the value returned when hasSoutheasternNeighbor() is false is not visible
// in this listing.
281 int BorderExchange::southeast(
int row,
int column,
int numRows,
int numColumns) {
282 if (hasSoutheasternNeighbor(row, column, numRows, numColumns)) {
283 int const neighborRow = row + (row < numRows - 1);
284 int const neighborColumn = column + (column < numColumns - 1);
285 return rankFromRowAndColumn(neighborRow, neighborColumn, numRows, numColumns);
// True unless (row, column) is the top-left corner: a position on the top or left edge
// still counts, because northwest() then slides along that edge.
// numRows and numColumns are not used by the visible expression.
292 bool BorderExchange::hasNorthwesternNeighbor(
int row,
int column,
int numRows,
int numColumns) {
293 return (row > 0) || (column > 0);
// Guard consulted by north() before it steps to row - 1.
// NOTE(review): the return expression is not visible in this listing; presumably it
// tests row > 0 — confirm against the full source.
296 bool BorderExchange::hasNorthernNeighbor(
int row,
int column,
int numRows,
int numColumns) {
// True unless (row, column) is the top-right corner: a position on the top or right edge
// still counts, because northeast() then slides along that edge.
300 bool BorderExchange::hasNortheasternNeighbor(
int row,
int column,
int numRows,
int numColumns) {
301 return (row > 0) || (column < numColumns - 1);
// Guard consulted by west() before it steps to column - 1.
// NOTE(review): the return expression is not visible in this listing; presumably it
// tests column > 0 — confirm against the full source.
304 bool BorderExchange::hasWesternNeighbor(
int row,
int column,
int numRows,
int numColumns) {
// True when the column is not the last one (column < numColumns - 1).
// row and numRows are not used by the visible expression.
308 bool BorderExchange::hasEasternNeighbor(
int row,
int column,
int numRows,
int numColumns) {
309 return column < numColumns - 1;
// True unless (row, column) is the bottom-left corner: a position on the bottom or left
// edge still counts, because southwest() then slides along that edge.
312 bool BorderExchange::hasSouthwesternNeighbor(
int row,
int column,
int numRows,
int numColumns) {
313 return (row < numRows - 1) || (column > 0);
// True when the row is not the last one (row < numRows - 1).
// column and numColumns are not used by the visible expression.
316 bool BorderExchange::hasSouthernNeighbor(
int row,
int column,
int numRows,
int numColumns) {
317 return row < numRows - 1;
// True unless (row, column) is the bottom-right corner: a position on the bottom or right
// edge still counts, because southeast() then slides along that edge.
320 bool BorderExchange::hasSoutheasternNeighbor(
int row,
int column,
int numRows,
int numColumns) {
321 return (row < numRows - 1) || (column < numColumns - 1);
// NOTE(review): fragment — the signature, the switch header, most case labels and all
// return statements are not visible in this listing, so the comments below describe
// only the individual visible statements.
// Special-cases a neighbor that is this process itself.
326 if (neighbor == commId) {
// Mirror the direction index around the middle of the 1..8 range.
329 int revdir = 9 - direction;
// Position of commId within one numRows x numCols grid of ranks.
333 int rankRowColumn = commId % (numRows * numCols);
334 int row = rowFromRank(rankRowColumn, numRows, numCols);
335 int col = columnFromRank(rankRowColumn, numRows, numCols);
// The assertions below check the mirrored index against a named direction and check
// that row/col are consistent with the edge being handled.
341 pvAssert(revdir == SOUTHEAST);
356 pvAssert(revdir == SOUTHWEST);
358 pvAssert(col < numCols - 1);
361 if (col == numCols - 1) {
366 case WEST: pvAssert(col > 0);
break;
367 case EAST: pvAssert(col < numCols - 1);
break;
369 pvAssert(revdir == NORTHEAST);
370 if (row == numRows - 1) {
375 pvAssert(row < numRows - 1);
379 case SOUTH: pvAssert(row < numRows - 1);
break;
381 pvAssert(revdir == NORTHWEST);
382 if (row == numRows - 1) {
383 pvAssert(col < numCols - 1);
386 if (col == numCols - 1) {
387 pvAssert(row < numRows - 1);
// NOTE(review): fragment — the function signature and the switch header are not visible
// in this listing.
// Computes, per direction, an offset built from the left/top border sizes, the layer
// size (nx, ny) and the extended strides sx and sy, and returns it as std::size_t.
400 const int nx = mLayerLoc.nx;
401 const int ny = mLayerLoc.ny;
402 const int leftBorder = mLayerLoc.halo.lt;
// The top border is the upper halo width (halo.up), as in newDatatypes() and
// sendOffset(). Reading halo.dn here shifted every offset that uses sy * topBorder
// whenever the upper and lower halo widths differ.
403 const int topBorder = mLayerLoc.halo.up;
405 const int sx = strideXExtended(&mLayerLoc);
406 const int sy = strideYExtended(&mLayerLoc);
// Column part: 0 for the west side, leftBorder for the middle, leftBorder + nx for the
// east side. Row part: 0 for the north side, topBorder for the middle, topBorder + ny
// for the south side.
411 case LOCAL: offset = sx * leftBorder + sy * topBorder;
break;
412 case NORTHWEST: offset = 0;
break;
413 case NORTH: offset = sx * leftBorder;
break;
414 case NORTHEAST: offset = sx * leftBorder + sx * nx;
break;
415 case WEST: offset = sy * topBorder;
break;
416 case EAST: offset = sx * leftBorder + sx * nx + sy * topBorder;
break;
417 case SOUTHWEST: offset = sy * (topBorder + ny);
break;
418 case SOUTH: offset = sx * leftBorder + sy * (topBorder + ny);
break;
419 case SOUTHEAST: offset = sx * leftBorder + sx * nx + sy * (topBorder + ny);
break;
424 return (std::size_t)offset;
// NOTE(review): fragment — the signature, the switch header, several case labels and the
// definitions of row, col, numRows and numCols are not visible in this listing.
// Computes, per direction, an offset built from the border sizes, the layer size and
// the extended strides, and returns it as std::size_t.
432 const size_t nx = mLayerLoc.nx;
433 const size_t ny = mLayerLoc.ny;
434 const size_t leftBorder = mLayerLoc.halo.lt;
435 const size_t topBorder = mLayerLoc.halo.up;
437 const size_t sx = strideXExtended(&mLayerLoc);
438 const size_t sy = strideYExtended(&mLayerLoc);
// Which sides of this process have a neighbor in the process grid. The flags are used
// below as 0/1 multipliers: `flag * border` adds the border only when the flag is true,
// and `!flag * border` (parsed as (!flag) * border) adds it only when the flag is false.
445 bool hasNorthNeighbor = row > 0;
446 bool hasWestNeighbor = col > 0;
447 bool hasEastNeighbor = col < numCols - 1;
448 bool hasSouthNeighbor = row < numRows - 1;
452 case LOCAL: offset = sx * leftBorder + sy * topBorder;
break;
454 offset = sx * hasWestNeighbor * leftBorder + sy * hasNorthNeighbor * topBorder;
456 case NORTH: offset = sx * leftBorder + sy * topBorder;
break;
458 offset = sx * (nx + !hasEastNeighbor * leftBorder) + sy * hasNorthNeighbor * topBorder;
460 case WEST: offset = sx * leftBorder + sy * topBorder;
break;
// NOTE(review): EAST and SOUTH use nx / ny without adding the left / top border; that
// equals (border + size - opposite border) only when opposite halo widths are equal —
// confirm this is an intended assumption.
461 case EAST: offset = sx * nx + sy * topBorder;
break;
463 offset = sx * hasWestNeighbor * leftBorder + sy * (ny + !hasSouthNeighbor * topBorder);
465 case SOUTH: offset = sx * leftBorder + sy * ny;
break;
// This visible statement returns directly instead of assigning to `offset` like the
// other cases.
467 return sx * (nx + !hasEastNeighbor * leftBorder)
468 + sy * (ny + !hasSouthNeighbor * topBorder);
474 return (std::size_t)offset;
// Per-direction tag components, indexed by direction and added to exchangeCounter * 16
// in exchange(). The table is symmetric under d -> 9 - d (entries 1/8, 2/7, 3/6 and 4/5
// are equal), so mTags[n] and mTags[9 - n] always yield the same value.
483 std::vector<int>
const BorderExchange::mTags = {0, 33, 34, 35, 34, 34, 35, 34, 33};
// Static counter used as the tag base in exchange(); starts at 1024 and is cycled there
// between 1024 and 2047.
485 int BorderExchange::exchangeCounter = 1024;
int getNumColumns() const
int reverseDirection(int commId, int direction)
std::size_t recvOffset(int direction)
std::size_t sendOffset(int direction)
int neighborIndex(int commId, int direction)
int getColumnIndex() const
int getBatchDimension() const