12 #include "Communicator.hpp" 14 #include "utils/PVAssert.hpp" 15 #include "utils/PVLog.hpp" 16 #include "utils/conversions.h" 20 int Communicator::gcd(
int a,
int b) {
// Constructor. From the lines visible here it:
//   - queries this process's rank and the process count in MPI_COMM_WORLD,
//   - reads the integer arguments "NumRows", "NumColumns" and "BatchWidth",
//   - fills in a row and/or column count that was given as 0 from the
//     number of available processes,
//   - aborts if the resulting layout needs more processes than are available,
//   - creates three MPIBlock objects and performs several barriers.
// NOTE(review): this excerpt omits a number of original lines (assignment
// targets of the `new MPIBlock` expressions, some closing braces, some
// guards), so the comments below describe only what is visible.
30 Communicator::Communicator(Arguments *argumentList) {
32 MPI_Comm_rank(MPI_COMM_WORLD, &globalRank);
33 MPI_Comm_size(MPI_COMM_WORLD, &totalSize);
35 numRows = argumentList->getIntegerArgument(
"NumRows");
36 numCols = argumentList->getIntegerArgument(
"NumColumns");
37 batchWidth = argumentList->getIntegerArgument(
"BatchWidth");
// A value of 0 for any of the three arguments is treated as "not specified".
39 bool rowsDefined = numRows != 0;
40 bool colsDefined = numCols != 0;
41 bool batchDefined = batchWidth != 0;
// NOTE(review): this divides by batchWidth. How batchWidth == 0
// (batchDefined == false) is handled is not visible in this excerpt --
// confirm it is given a nonzero value before this line.
47 int procsLeft = totalSize / batchWidth;
48 if (rowsDefined && !colsDefined) {
// NOTE(review): procsLeft is an int; if numRows is also an integer type the
// division truncates before ceil() is applied, so ceil() has no effect.
// Confirm whether rounding up was intended (same pattern occurs twice below).
49 numCols = (int)ceil(procsLeft / numRows);
51 if (!rowsDefined && colsDefined) {
52 numRows = (int)ceil(procsLeft / numCols);
51 if (!rowsDefined && !colsDefined) {
// Neither dimension was given: the square root of the remaining process
// count is computed as a starting point (its use is not visible here).
55 double r = std::sqrt(procsLeft);
58 Fatal() <<
"Not enough processes left\n";
60 numCols = (int)ceil(procsLeft / numRows);
// Total number of processes the batch x rows x columns layout requires.
63 int commSize = batchWidth * numRows * numCols;
// Only global rank 0 reports the chosen layout.
66 if (globalRank == 0) {
67 InfoLog() <<
"Running with batchWidth=" << batchWidth <<
", numRows=" << numRows
68 <<
", and numCols=" << numCols <<
"\n";
// Only the "layout needs more processes than available" case is fatal;
// having more processes than needed is handled through isExtra below.
71 if (commSize > totalSize) {
72 Fatal() <<
"Number of required processes (NumRows * NumColumns * BatchWidth = " << commSize
73 <<
") should be the same as, and cannot be larger than, the number of processes " 75 << totalSize <<
")\n";
// MPIBlock built on MPI_COMM_WORLD; the variable it is assigned to is not
// visible in this excerpt (presumably globalMPIBlock -- confirm).
79 new MPIBlock(MPI_COMM_WORLD, numRows, numCols, batchWidth, numRows, numCols, batchWidth);
// Ranks at or beyond commSize are flagged as extra.
80 isExtra = (globalRank >= commSize);
82 WarnLog() <<
"Global process rank " << globalRank <<
" is extra, as only " << commSize
83 <<
" mpiProcesses are required. Process exiting\n";
89 bool requireReturn = argumentList->getBooleanArgument(
"RequireReturn");
// Between the two barriers below, global rank 0 prompts and reads stdin
// until a newline arrives. Presumably this whole section is guarded by
// requireReturn; the guard itself is not visible here -- confirm.
92 MPI_Barrier(globalCommunicator());
93 if (globalRank == 0) {
94 std::printf(
"Hit enter to begin! ");
97 while (charhit !=
'\n') {
98 charhit = std::getc(stdin);
101 MPI_Barrier(globalCommunicator());
105 MPI_Comm_size(globalCommunicator(), &globalSize);
// Second MPIBlock, on globalCommunicator(), with trailing arguments
// (numRows, numCols, 1). Assignment target not visible (presumably
// localMPIBlock, since localRank is read from it next -- confirm).
109 new MPIBlock{globalCommunicator(), numRows, numCols, batchWidth, numRows, numCols, 1};
111 localRank = localMPIBlock->
getRank();
// Third MPIBlock, with trailing arguments (1, 1, batchWidth). Assignment
// target not visible (presumably batchMPIBlock -- confirm).
114 new MPIBlock{globalCommunicator(), numRows, numCols, batchWidth, 1, 1, batchWidth};
124 MPI_Comm_size(communicator(), &localSize);
125 MPI_Comm_rank(communicator(), &tmpLocalRank);
// Sanity check: the rank MPI reports for communicator() must equal the rank
// obtained from localMPIBlock above.
127 pvAssert(tmpLocalRank == localRank);
129 if (globalSize > 0) {
132 MPI_Barrier(globalCommunicator());
// Destructor. Waits on a barrier over globalCommunicator() and then deletes
// the three MPIBlock objects owned by this communicator.
// NOTE(review): some original lines of this function are not visible in
// this excerpt.
135 Communicator::~Communicator() {
// The barrier runs before any block is deleted, while globalCommunicator()
// is still usable.
137 MPI_Barrier(globalCommunicator());
139 delete localMPIBlock;
140 delete batchMPIBlock;
141 delete globalMPIBlock;
148 int num_neighbors = 0;
155 int tags[9] = {0, 1, 2, 3, 2, 2, 3, 2, 1};
164 for (
int i = 0; i < NUM_NEIGHBORHOOD; i++) {
166 neighbors[i] = localRank;
172 "[%2d]: neighborInit: remote[%d] of %d is %d, i=%d, neighbor=%d\n",
179 #endif // DEBUG_OUTPUT 183 DebugLog().printf(
"[%2d]: neighborInit: i=%d, neighbor=%d\n", localRank, i, neighbors[i]);
184 #endif // DEBUG_OUTPUT 186 this->tags[i] = tags[i];
188 assert(this->numNeighbors == num_neighbors);
196 int Communicator::commRow(
int commId) {
return rowFromRank(commId, numRows, numCols); }
201 int Communicator::commColumn(
int commId) {
return columnFromRank(commId, numRows, numCols); }
// Returns the value batchFromRank() computes for process id `commId`, given
// this communicator's batchWidth, numRows and numCols.
206 int Communicator::commBatch(
int commId) {
207 return batchFromRank(commId, batchWidth, numRows, numCols);
214 return rankFromRowAndColumn(commRow, commColumn, numRows, numCols);
301 int nbr_id = -NORTHWEST;
303 int nbr_row = commRow - (commRow > 0);
304 int nbr_column = commColumn - (commColumn > 0);
325 int nbr_id = -NORTHEAST;
327 int nbr_row = commRow - (commRow > 0);
328 int nbr_column = commColumn + (commColumn < numCommColumns() - 1);
360 int nbr_id = -SOUTHWEST;
362 int nbr_row = commRow + (commRow < numCommRows() - 1);
363 int nbr_column = commColumn - (commColumn > 0);
384 int nbr_id = -SOUTHEAST;
386 int nbr_row = commRow + (commRow < numCommRows() - 1);
387 int nbr_column = commColumn + (commColumn < numCommColumns() - 1);
398 int row = commRow(commId);
399 int column = commColumn(commId);
401 case LOCAL:
return commId;
402 case NORTHWEST:
return northwest(row, column);
403 case NORTH:
return north(row, column);
404 case NORTHEAST:
return northeast(row, column);
405 case WEST:
return west(row, column);
406 case EAST:
return east(row, column);
407 case SOUTHWEST:
return southwest(row, column);
408 case SOUTH:
return south(row, column);
409 case SOUTHEAST:
return southeast(row, column);
410 default: ErrorLog().printf(
"neighborIndex %d: bad index\n", direction);
return -1;
// Works out a "reverse" direction code for the given process id and
// direction code, starting from 9 - direction.
// NOTE(review): most of this function's body (the switch header, several
// case labels, the edge-case adjustments and the return statements) is not
// visible in this excerpt; the comments below cover only the visible lines.
//
// commId:    process id whose row/column are looked up via commRow() and
//            commColumn().
// direction: direction code; the visible case labels include WEST, EAST
//            and SOUTH.
430 int Communicator::reverseDirection(
int commId,
int direction) {
// `neighbor` is computed on a line not shown here (presumably via
// neighborIndex(commId, direction) -- confirm). This branch handles the
// case where it equals commId itself.
432 if (neighbor == commId) {
// Initial candidate for the result. The asserts on revdir further down
// compare it against the named constants SOUTHEAST, SOUTHWEST, NORTHEAST
// and NORTHWEST.
435 int revdir = 9 - direction;
436 int col = commColumn(commId);
437 int row = commRow(commId);
443 assert(revdir == SOUTHEAST);
454 assert(commRow(commId) > 0);
458 assert(revdir == SOUTHWEST);
460 assert(col < numCols - 1);
// Special handling when commId is in the last column (body not visible).
463 if (col == numCols - 1) {
// For WEST, EAST (and SOUTH below) the visible code only asserts that
// commId is not on the corresponding edge of the grid.
468 case WEST: assert(commColumn(commId) > 0);
break;
469 case EAST: assert(commColumn(commId) < numCols - 1);
break;
471 assert(revdir == NORTHEAST);
// Special handling when commId is in the last row (body not visible).
472 if (row == numRows - 1) {
477 assert(row < numRows - 1);
481 case SOUTH: assert(commRow(commId) < numRows - 1);
break;
483 assert(revdir == NORTHWEST);
484 if (row == numRows - 1) {
485 assert(col < numCols - 1);
488 if (col == numCols - 1) {
489 assert(row < numRows - 1);
// Error report for an unrecognized direction value (presumably the switch
// default; the label is not visible here).
494 ErrorLog().printf(
"neighborIndex %d: bad index\n", direction);
bool hasWesternNeighbor(int commRow, int commColumn)
bool hasSoutheasternNeighbor(int commRow, int commColumn)
bool hasNortheasternNeighbor(int commRow, int commColumn)
int west(int commRow, int commColumn)
int northeast(int commRow, int commColumn)
bool hasSouthwesternNeighbor(int commRow, int commColumn)
bool hasSouthernNeighbor(int commRow, int commColumn)
int east(int commRow, int commColumn)
int north(int commRow, int commColumn)
int southwest(int commRow, int commColumn)
int northwest(int commRow, int commColumn)
int south(int commRow, int commColumn)
bool hasNeighbor(int neighborId)
bool hasNorthernNeighbor(int commRow, int commColumn)
int neighborIndex(int commId, int index)
bool hasEasternNeighbor(int commRow, int commColumn)
int southeast(int commRow, int commColumn)
bool hasNorthwesternNeighbor(int commRow, int commColumn)
int commIdFromRowColumn(int commRow, int commColumn)