PetaVision  Alpha
MPIBlock.cpp
1 #include "MPIBlock.hpp"
2 #include "utils/PVAssert.hpp"
3 #include "utils/PVLog.hpp"
4 #include "utils/conversions.h"
5 #include <cmath>
6 #include <stdexcept>
7 
8 namespace PV {
9 
11  MPI_Comm comm,
12  int globalNumRows,
13  int globalNumColumns,
14  int globalBatchDimension,
15  int blockNumRows,
16  int blockNumColumns,
17  int blockBatchDimension) {
18 
19  mGlobalComm = comm;
20  initGlobalDimensions(comm, globalNumRows, globalNumColumns, globalBatchDimension);
21  initBlockDimensions(blockNumRows, blockNumColumns, blockBatchDimension);
22  initBlockLocation(comm);
23  createBlockComm(comm);
24 }
25 
27  MPI_Comm comm,
28  int globalNumRows,
29  int globalNumColumns,
30  int globalBatchDimension) {
31 
32  int globalRank, numProcsAvailable;
33  MPI_Comm_rank(comm, &globalRank);
34  MPI_Comm_size(comm, &numProcsAvailable);
35 
36  mGlobalNumRows = globalNumRows;
37  mGlobalNumColumns = globalNumColumns;
38  mGlobalBatchDimension = globalBatchDimension;
39 
40  bool numRowsDefined = globalNumRows > 0;
41  bool numColumnsDefined = globalNumColumns > 0;
42  mGlobalBatchDimension = globalBatchDimension > 0 ? globalBatchDimension : 1;
43 
44  int procsLeft = numProcsAvailable / mGlobalBatchDimension;
45  if (numRowsDefined && numColumnsDefined) {
46  mGlobalNumRows = globalNumRows;
47  mGlobalNumColumns = globalNumColumns;
48  }
49  if (numRowsDefined && !numColumnsDefined) {
50  mGlobalNumRows = globalNumRows;
51  mGlobalNumColumns = (int)ceil(procsLeft / globalNumRows);
52  }
53  if (!numRowsDefined && numColumnsDefined) {
54  mGlobalNumRows = (int)ceil(procsLeft / globalNumColumns);
55  mGlobalNumColumns = globalNumRows;
56  }
57  if (!numRowsDefined && !numColumnsDefined) {
58  double r = std::sqrt(procsLeft);
59  mGlobalNumRows = (int)r;
60  FatalIf(mGlobalNumRows == 0, "Not enough processes left\n");
61  mGlobalNumColumns = (int)ceil(procsLeft / mGlobalNumRows);
62  }
63 
64  int numProcsNeeded = mGlobalBatchDimension * mGlobalNumRows * mGlobalNumColumns;
65 
66  if (globalRank == 0) {
67  FatalIf(
68  numProcsNeeded > numProcsAvailable,
69  "Number of processes required (%d) is larger than the "
70  "number of processes available (%d)\n",
71  numProcsNeeded,
72  numProcsAvailable);
73  }
74 }
75 
76 void MPIBlock::initBlockDimensions(int blockNumRows, int blockNumColumns, int blockBatchDimension) {
77  mNumRows = blockNumRows > 0 ? blockNumRows : mGlobalNumRows;
78  mNumColumns = blockNumColumns > 0 ? blockNumColumns : mGlobalNumColumns;
79  mBatchDimension = blockBatchDimension > 0 ? blockBatchDimension : mGlobalBatchDimension;
80 }
81 
82 void MPIBlock::initBlockLocation(MPI_Comm comm) {
83  MPI_Comm_rank(comm, &mGlobalRank);
84  int const globalColumnIndex = mGlobalRank % mGlobalNumColumns;
85  int const globalRowIndex = (mGlobalRank / mGlobalNumColumns) % mGlobalNumRows;
86  int const globalBatchIndex = mGlobalRank / (mGlobalNumColumns * mGlobalNumRows);
87  int checkRank = globalBatchIndex * mGlobalNumRows + globalRowIndex;
88  checkRank = checkRank * mGlobalNumColumns + globalColumnIndex;
89  pvAssert(checkRank == mGlobalRank);
90 
91  mRowIndex = globalRowIndex % mNumRows;
92  mStartRow = globalRowIndex - mRowIndex;
93 
94  mColumnIndex = globalColumnIndex % mNumColumns;
95  mStartColumn = globalColumnIndex - mColumnIndex;
96 
97  mBatchIndex = globalBatchIndex % mBatchDimension;
98  mStartBatch = globalBatchIndex - mBatchIndex;
99 }
100 
101 void MPIBlock::createBlockComm(MPI_Comm comm) {
102  int rowBlock = mStartRow / mNumRows;
103  int columnBlock = mStartColumn / mNumColumns;
104  int batchBlock = mStartBatch / mBatchDimension;
105  int cellsInGlobalRow = calcNumCells(mNumRows, mGlobalNumRows);
106  int cellsInGlobalColumn = calcNumCells(mNumColumns, mGlobalNumColumns);
107  int cellIndex =
108  rankFromRowAndColumn(rowBlock, columnBlock, cellsInGlobalRow, cellsInGlobalColumn);
109  cellIndex += batchBlock * cellsInGlobalRow * cellsInGlobalColumn;
110 
111  int cellRank = rankFromRowAndColumn(mRowIndex, mColumnIndex, mNumRows, mNumColumns);
112  cellRank += mBatchIndex * mNumRows * mNumColumns;
113  int numProcsNeeded = mGlobalBatchDimension * mGlobalNumRows * mGlobalNumColumns;
114 
115  MPI_Comm_split(comm, cellIndex, cellRank, &mComm);
116  MPI_Comm_rank(mComm, &mRank);
117  if (mRank < numProcsNeeded && mRank != cellRank) {
118  Fatal().printf("Global rank %d, cellRank %d, mRank %d\n", mGlobalRank, cellRank, mRank);
119  }
120 }
121 
122 int MPIBlock::calcNumCells(int cellSize, int overallSize) {
123  int numCells = overallSize / cellSize; // integer division
124  if (overallSize % cellSize != 0) {
125  numCells++;
126  }
127  return numCells;
128 }
129 
131  int const rowIndex,
132  int const columnIndex,
133  int const batchIndex) const {
134  if (rowIndex < 0 || rowIndex >= getNumRows() || columnIndex < 0 || columnIndex >= getNumColumns()
135  || batchIndex >= getBatchDimension()) {
136  throw std::invalid_argument("calcRankFromRowColBatch");
137  }
138  return columnIndex + getNumColumns() * (rowIndex + getNumRows() * batchIndex);
139 }
140 
141 void MPIBlock::checkRankInBounds(int const rank) const {
142  if (rank < 0 || rank >= getSize()) {
143  throw std::invalid_argument("calcRowColBatchFromRank");
144  }
145 }
146 
148  int const rank,
149  int &rowIndex,
150  int &columnIndex,
151  int &batchIndex) const {
152  checkRankInBounds(rank);
153  columnIndex = calcColumnFromRankInternal(rank);
154  rowIndex = calcRowFromRankInternal(rank);
155  batchIndex = calcBatchIndexFromRankInternal(rank);
156 }
157 
158 int MPIBlock::calcRowFromRank(int const rank) const {
159  checkRankInBounds(rank);
160  return calcRowFromRankInternal(rank);
161 }
162 
163 int MPIBlock::calcColumnFromRank(int const rank) const {
164  checkRankInBounds(rank);
165  return calcColumnFromRankInternal(rank);
166 }
167 
168 int MPIBlock::calcBatchIndexFromRank(int const rank) const {
169  checkRankInBounds(rank);
170  return calcBatchIndexFromRankInternal(rank);
171 }
172 
173 int MPIBlock::calcRowFromRankInternal(int const rank) const {
174  return (rank / getNumColumns()) % getNumRows(); // Integer division
175 }
176 
177 int MPIBlock::calcColumnFromRankInternal(int const rank) const { return rank % getNumColumns(); }
178 
179 int MPIBlock::calcBatchIndexFromRankInternal(int const rank) const {
180  return rank / (getNumColumns() * getNumRows()); // Integer division
181 }
182 
183 } // end namespace PV
void initGlobalDimensions(MPI_Comm comm, int const globalNumRows, int const globalNumColumns, int const globalBatchDimension)
Definition: MPIBlock.cpp:26
void initBlockDimensions(int const blockNumRows, int const blockNumColumns, int const globalBatchDimension)
Definition: MPIBlock.cpp:76
MPIBlock(MPI_Comm comm, int globalNumRows, int globalNumColumns, int globalBatchDimension, int blockNumRows, int blockNumColumns, int blockBatchDimension)
Definition: MPIBlock.cpp:10
int getNumColumns() const
Definition: MPIBlock.hpp:130
int getNumRows() const
Definition: MPIBlock.hpp:125
void checkRankInBounds(int const rank) const
Definition: MPIBlock.cpp:141
int calcRankFromRowColBatch(int const rowIndex, int const columnIndex, int const batchIndex) const
Definition: MPIBlock.cpp:130
int getBatchDimension() const
Definition: MPIBlock.hpp:135
void initBlockLocation(MPI_Comm comm)
Definition: MPIBlock.cpp:82
void createBlockComm(MPI_Comm comm)
Definition: MPIBlock.cpp:101
int calcBatchIndexFromRank(int const rank) const
Definition: MPIBlock.cpp:168
int calcBatchIndexFromRankInternal(int const rank) const
Definition: MPIBlock.cpp:179
int calcColumnFromRank(int const rank) const
Definition: MPIBlock.cpp:163
int calcColumnFromRankInternal(int const rank) const
Definition: MPIBlock.cpp:177
void calcRowColBatchFromRank(int const rank, int &rowIndex, int &columnIndex, int &batchIndex) const
Definition: MPIBlock.cpp:147
int getSize() const
Definition: MPIBlock.hpp:141
int calcRowFromRank(int const rank) const
Definition: MPIBlock.cpp:158
int calcRowFromRankInternal(int const rank) const
Definition: MPIBlock.cpp:173
static int calcNumCells(int cellSize, int overallSize)
Definition: MPIBlock.cpp:122