PetaVision  Alpha
Communicator.cpp
1 /*
2  * Communicator.cpp
3  */
4 
5 #include <cassert>
6 #include <cmath>
7 #include <cstdio>
8 #include <cstdlib>
9 #include <cstring>
10 #include <iostream>
11 
12 #include "Communicator.hpp"
13 #include "io/io.hpp"
14 #include "utils/PVAssert.hpp"
15 #include "utils/PVLog.hpp"
16 #include "utils/conversions.h"
17 
18 namespace PV {
19 
20 int Communicator::gcd(int a, int b) {
21  int c;
22  while (a != 0) {
23  c = a;
24  a = b % a;
25  b = c;
26  }
27  return b;
28 }
29 
30 Communicator::Communicator(Arguments *argumentList) {
31  int totalSize;
32  MPI_Comm_rank(MPI_COMM_WORLD, &globalRank);
33  MPI_Comm_size(MPI_COMM_WORLD, &totalSize);
34 
35  numRows = argumentList->getIntegerArgument("NumRows");
36  numCols = argumentList->getIntegerArgument("NumColumns");
37  batchWidth = argumentList->getIntegerArgument("BatchWidth");
38 
39  bool rowsDefined = numRows != 0;
40  bool colsDefined = numCols != 0;
41  bool batchDefined = batchWidth != 0;
42 
43  if (!batchDefined) {
44  batchWidth = 1;
45  }
46 
47  int procsLeft = totalSize / batchWidth;
48  if (rowsDefined && !colsDefined) {
49  numCols = (int)ceil(procsLeft / numRows);
50  }
51  if (!rowsDefined && colsDefined) {
52  numRows = (int)ceil(procsLeft / numCols);
53  }
54  if (!rowsDefined && !colsDefined) {
55  double r = std::sqrt(procsLeft);
56  numRows = (int)r;
57  if (numRows == 0) {
58  Fatal() << "Not enough processes left\n";
59  }
60  numCols = (int)ceil(procsLeft / numRows);
61  }
62 
63  int commSize = batchWidth * numRows * numCols;
64 
65  // For debugging
66  if (globalRank == 0) {
67  InfoLog() << "Running with batchWidth=" << batchWidth << ", numRows=" << numRows
68  << ", and numCols=" << numCols << "\n";
69  }
70 
71  if (commSize > totalSize) {
72  Fatal() << "Number of required processes (NumRows * NumColumns * BatchWidth = " << commSize
73  << ") should be the same as, and cannot be larger than, the number of processes "
74  "launched ("
75  << totalSize << ")\n";
76  }
77 
78  globalMPIBlock =
79  new MPIBlock(MPI_COMM_WORLD, numRows, numCols, batchWidth, numRows, numCols, batchWidth);
80  isExtra = (globalRank >= commSize);
81  if (isExtra) {
82  WarnLog() << "Global process rank " << globalRank << " is extra, as only " << commSize
83  << " mpiProcesses are required. Process exiting\n";
84  return;
85  }
86  // globalMPIBlock's communicator now has only useful mpi processes
87 
88  // If RequireReturn was set, wait until global root process gets keyboard input.
89  bool requireReturn = argumentList->getBooleanArgument("RequireReturn");
90  if (requireReturn) {
91  fflush(stdout);
92  MPI_Barrier(globalCommunicator());
93  if (globalRank == 0) {
94  std::printf("Hit enter to begin! ");
95  fflush(stdout);
96  int charhit = -1;
97  while (charhit != '\n') {
98  charhit = std::getc(stdin);
99  }
100  }
101  MPI_Barrier(globalCommunicator());
102  }
103 
104  // Grab globalSize now that extra processes have been exited
105  MPI_Comm_size(globalCommunicator(), &globalSize);
106 
107  // Make new local communicator
108  localMPIBlock =
109  new MPIBlock{globalCommunicator(), numRows, numCols, batchWidth, numRows, numCols, 1};
110  // Set local rank
111  localRank = localMPIBlock->getRank();
112  // Make new batch communicator
113  batchMPIBlock =
114  new MPIBlock{globalCommunicator(), numRows, numCols, batchWidth, 1, 1, batchWidth};
115 
116  //#ifdef DEBUG_OUTPUT
117  // DebugLog().printf("[%2d]: Formed resized communicator, size==%d
118  // cols==%d rows==%d\n",
119  // icRank, icSize, numCols, numRows);
120  //#endif // DEBUG_OUTPUT
121 
122  // Grab local rank and check for errors
123  int tmpLocalRank;
124  MPI_Comm_size(communicator(), &localSize);
125  MPI_Comm_rank(communicator(), &tmpLocalRank);
126  // This should be equiv
127  pvAssert(tmpLocalRank == localRank);
128 
129  if (globalSize > 0) {
130  neighborInit();
131  }
132  MPI_Barrier(globalCommunicator());
133 }
134 
135 Communicator::~Communicator() {
136 #ifdef PV_USE_MPI
137  MPI_Barrier(globalCommunicator());
138 #endif
139  delete localMPIBlock;
140  delete batchMPIBlock;
141  delete globalMPIBlock;
142 }
143 
148  int num_neighbors = 0;
149  int num_borders = 0;
150 
151  // initialize neighbor and border lists
152  // (local borders and remote neighbors form the complete neighborhood)
153 
154  this->numNeighbors = numberOfNeighbors();
155  int tags[9] = {0, 1, 2, 3, 2, 2, 3, 2, 1};
156  // NW and SE corners have tag 1; edges have tag 2; NE and SW corners have
157  // tag 3.
158  // In the top row of processes in the hypercolumn, a process is both the
159  // northeast and east neighbor of the process to its left. If there is only
160  // one row, a process is the northeast, east, and southeast neighbor of the
161  // process to its left. The numbering of tags ensures that the
162  // MPI_Send/MPI_Irecv pairs can be distinguished.
163 
164  for (int i = 0; i < NUM_NEIGHBORHOOD; i++) {
165  int n = neighborIndex(localRank, i);
166  neighbors[i] = localRank; // default neighbor is self
167  if (n >= 0) {
168  neighbors[i] = n;
169  num_neighbors++;
170 #ifdef DEBUG_OUTPUT
171  DebugLog().printf(
172  "[%2d]: neighborInit: remote[%d] of %d is %d, i=%d, neighbor=%d\n",
173  localRank,
174  num_neighbors - 1,
175  this->numNeighbors,
176  n,
177  i,
178  neighbors[i]);
179 #endif // DEBUG_OUTPUT
180  }
181  else {
182 #ifdef DEBUG_OUTPUT
183  DebugLog().printf("[%2d]: neighborInit: i=%d, neighbor=%d\n", localRank, i, neighbors[i]);
184 #endif // DEBUG_OUTPUT
185  }
186  this->tags[i] = tags[i];
187  }
188  assert(this->numNeighbors == num_neighbors);
189 
190  return 0;
191 }
192 
196 int Communicator::commRow(int commId) { return rowFromRank(commId, numRows, numCols); }
197 
201 int Communicator::commColumn(int commId) { return columnFromRank(commId, numRows, numCols); }
202 
206 int Communicator::commBatch(int commId) {
207  return batchFromRank(commId, batchWidth, numRows, numCols);
208 }
209 
213 int Communicator::commIdFromRowColumn(int commRow, int commColumn) {
214  return rankFromRowAndColumn(commRow, commColumn, numRows, numCols);
215 }
216 
221 bool Communicator::hasNeighbor(int neighbor) {
222  int nbrIdx = neighborIndex(localRank, neighbor);
223  return nbrIdx >= 0;
224 }
225 
230 bool Communicator::hasNorthwesternNeighbor(int row, int column) {
231  return (hasNorthernNeighbor(row, column) || hasWesternNeighbor(row, column));
232 }
233 
238 bool Communicator::hasNorthernNeighbor(int row, int column) { return row > 0; }
239 
244 bool Communicator::hasNortheasternNeighbor(int row, int column) {
245  return (hasNorthernNeighbor(row, column) || hasEasternNeighbor(row, column));
246 }
247 
252 bool Communicator::hasWesternNeighbor(int row, int column) { return column > 0; }
253 
258 bool Communicator::hasEasternNeighbor(int row, int column) { return column < numCommColumns() - 1; }
259 
264 bool Communicator::hasSouthwesternNeighbor(int row, int column) {
265  return (hasSouthernNeighbor(row, column) || hasWesternNeighbor(row, column));
266 }
267 
272 bool Communicator::hasSouthernNeighbor(int row, int column) { return row < numCommRows() - 1; }
273 
278 bool Communicator::hasSoutheasternNeighbor(int row, int column) {
279  return (hasSouthernNeighbor(row, column) || hasEasternNeighbor(row, column));
280 }
281 
286  int n = 1 + hasNorthwesternNeighbor(commRow(), commColumn())
287  + hasNorthernNeighbor(commRow(), commColumn())
288  + hasNortheasternNeighbor(commRow(), commColumn())
289  + hasWesternNeighbor(commRow(), commColumn())
290  + hasEasternNeighbor(commRow(), commColumn())
291  + hasSouthwesternNeighbor(commRow(), commColumn())
292  + hasSouthernNeighbor(commRow(), commColumn())
293  + hasSoutheasternNeighbor(commRow(), commColumn());
294  return n;
295 }
296 
300 int Communicator::northwest(int commRow, int commColumn) {
301  int nbr_id = -NORTHWEST;
302  if (hasNorthwesternNeighbor(commRow, commColumn)) {
303  int nbr_row = commRow - (commRow > 0);
304  int nbr_column = commColumn - (commColumn > 0);
305  nbr_id = commIdFromRowColumn(nbr_row, nbr_column);
306  }
307  return nbr_id;
308 }
309 
313 int Communicator::north(int commRow, int commColumn) {
314  int nbr_id = -NORTH;
315  if (hasNorthernNeighbor(commRow, commColumn)) {
316  nbr_id = commIdFromRowColumn(commRow - 1, commColumn);
317  }
318  return nbr_id;
319 }
320 
324 int Communicator::northeast(int commRow, int commColumn) {
325  int nbr_id = -NORTHEAST;
326  if (hasNortheasternNeighbor(commRow, commColumn)) {
327  int nbr_row = commRow - (commRow > 0);
328  int nbr_column = commColumn + (commColumn < numCommColumns() - 1);
329  nbr_id = commIdFromRowColumn(nbr_row, nbr_column);
330  }
331  return nbr_id;
332 }
333 
337 int Communicator::west(int commRow, int commColumn) {
338  int nbr_id = -WEST;
339  if (hasWesternNeighbor(commRow, commColumn)) {
340  nbr_id = commIdFromRowColumn(commRow, commColumn - 1);
341  }
342  return nbr_id;
343 }
344 
348 int Communicator::east(int commRow, int commColumn) {
349  int nbr_id = -EAST;
350  if (hasEasternNeighbor(commRow, commColumn)) {
351  nbr_id = commIdFromRowColumn(commRow, commColumn + 1);
352  }
353  return nbr_id;
354 }
355 
359 int Communicator::southwest(int commRow, int commColumn) {
360  int nbr_id = -SOUTHWEST;
361  if (hasSouthwesternNeighbor(commRow, commColumn)) {
362  int nbr_row = commRow + (commRow < numCommRows() - 1);
363  int nbr_column = commColumn - (commColumn > 0);
364  nbr_id = commIdFromRowColumn(nbr_row, nbr_column);
365  }
366  return nbr_id;
367 }
368 
372 int Communicator::south(int commRow, int commColumn) {
373  int nbr_id = -SOUTH;
374  if (hasSouthernNeighbor(commRow, commColumn)) {
375  nbr_id = commIdFromRowColumn(commRow + 1, commColumn);
376  }
377  return nbr_id;
378 }
379 
383 int Communicator::southeast(int commRow, int commColumn) {
384  int nbr_id = -SOUTHEAST;
385  if (hasSoutheasternNeighbor(commRow, commColumn)) {
386  int nbr_row = commRow + (commRow < numCommRows() - 1);
387  int nbr_column = commColumn + (commColumn < numCommColumns() - 1);
388  nbr_id = commIdFromRowColumn(nbr_row, nbr_column);
389  }
390  return nbr_id;
391 }
392 
397 int Communicator::neighborIndex(int commId, int direction) {
398  int row = commRow(commId);
399  int column = commColumn(commId);
400  switch (direction) {
401  case LOCAL: /* local */ return commId;
402  case NORTHWEST: /* northwest */ return northwest(row, column);
403  case NORTH: /* north */ return north(row, column);
404  case NORTHEAST: /* northeast */ return northeast(row, column);
405  case WEST: /* west */ return west(row, column);
406  case EAST: /* east */ return east(row, column);
407  case SOUTHWEST: /* southwest */ return southwest(row, column);
408  case SOUTH: /* south */ return south(row, column);
409  case SOUTHEAST: /* southeast */ return southeast(row, column);
410  default: ErrorLog().printf("neighborIndex %d: bad index\n", direction); return -1;
411  }
412 }
413 
414 /*
415  * In a send/receive exchange, when rank A makes an MPI send to its neighbor in
416  * direction x,
417  * that neighbor must make a complementary MPI receive call. To get the tags
418  * correct,
419  * the receiver needs to know the direction that the sender was using in
420  * determining which
421  * process to send to.
422  *
423  * Thus, if every process does an MPI send in each direction, to the process of
424  * rank
425  * neighborIndex(icRank,direction) with tag[direction],
426  * every process must also do an MPI receive in each direction, to the process
427  * of rank
428  * neighborIndex(icRank,direction) with tag[reverseDirection(icRank,direction)].
429  */
430 int Communicator::reverseDirection(int commId, int direction) {
431  int neighbor = neighborIndex(commId, direction);
432  if (neighbor == commId) {
433  return -1;
434  }
435  int revdir = 9 - direction; // Correct unless at an edge of the MPI quilt
436  int col = commColumn(commId);
437  int row = commRow(commId);
438  switch (direction) {
439  case LOCAL:
440  assert(0); // Should have neighbor==commId, so should have already returned
441  break;
442  case NORTHWEST: /* northwest */
443  assert(revdir == SOUTHEAST);
444  if (row == 0) {
445  assert(col > 0);
446  revdir = NORTHEAST;
447  }
448  if (col == 0) {
449  assert(row > 0);
450  revdir = SOUTHWEST;
451  }
452  break;
453  case NORTH: /* north */
454  assert(commRow(commId) > 0); // If row==0, there is no north neighbor so
455  // should have already returned.
456  break;
457  case NORTHEAST: /* northeast */
458  assert(revdir == SOUTHWEST);
459  if (row == 0) {
460  assert(col < numCols - 1);
461  revdir = NORTHWEST;
462  }
463  if (col == numCols - 1) {
464  assert(row > 0);
465  revdir = SOUTHEAST;
466  }
467  break;
468  case WEST: /* west */ assert(commColumn(commId) > 0); break;
469  case EAST: /* east */ assert(commColumn(commId) < numCols - 1); break;
470  case SOUTHWEST: /* southwest */
471  assert(revdir == NORTHEAST);
472  if (row == numRows - 1) {
473  assert(col > 0);
474  revdir = SOUTHEAST;
475  }
476  if (col == 0) {
477  assert(row < numRows - 1);
478  revdir = NORTHWEST;
479  }
480  break;
481  case SOUTH: /* south */ assert(commRow(commId) < numRows - 1); break;
482  case SOUTHEAST: /* southeast */
483  assert(revdir == NORTHWEST);
484  if (row == numRows - 1) {
485  assert(col < numCols - 1);
486  revdir = SOUTHWEST;
487  }
488  if (col == numCols - 1) {
489  assert(row < numRows - 1);
490  revdir = NORTHEAST;
491  }
492  break;
493  default:
494  ErrorLog().printf("neighborIndex %d: bad index\n", direction);
495  revdir = -1;
496  break;
497  }
498  return revdir;
499 }
500 
501 // The following Communicator methods related to border exchange were moved to
502 // the BorderExchange class in utils/BorderExchange.{c,h}pp Feb 6, 2017.
503 // newDatatypes
504 // freeDatatypes
505 // exchange
506 // wait
507 // recvOffset
508 // sendOffset
509 
510 } // end namespace PV
bool hasWesternNeighbor(int commRow, int commColumn)
bool hasSoutheasternNeighbor(int commRow, int commColumn)
bool hasNortheasternNeighbor(int commRow, int commColumn)
int west(int commRow, int commColumn)
int northeast(int commRow, int commColumn)
bool hasSouthwesternNeighbor(int commRow, int commColumn)
bool hasSouthernNeighbor(int commRow, int commColumn)
int east(int commRow, int commColumn)
int north(int commRow, int commColumn)
int getRank() const
Definition: MPIBlock.hpp:100
int southwest(int commRow, int commColumn)
int northwest(int commRow, int commColumn)
int south(int commRow, int commColumn)
bool hasNeighbor(int neighborId)
bool hasNorthernNeighbor(int commRow, int commColumn)
int neighborIndex(int commId, int index)
bool hasEasternNeighbor(int commRow, int commColumn)
int southeast(int commRow, int commColumn)
bool hasNorthwesternNeighbor(int commRow, int commColumn)
int commIdFromRowColumn(int commRow, int commColumn)