1 #include "MPIBlock.hpp" 2 #include "utils/PVAssert.hpp" 3 #include "utils/PVLog.hpp" 4 #include "utils/conversions.h" 14 int globalBatchDimension,
// (fragment) MPIBlock constructor parameter list: after the communicator it
// takes the global grid extents (rows, columns, batch) followed by the
// per-block extents (blockNumRows, blockNumColumns, blockBatchDimension).
// NOTE(review): the beginning of this signature and the constructor body are
// not visible in this excerpt.
17 int blockBatchDimension) {
30 int globalBatchDimension) {
// initGlobalDimensions (fragment): determines mGlobalNumRows,
// mGlobalNumColumns and mGlobalBatchDimension for the whole communicator.
// A non-positive row or column count means "derive it from the number of
// available processes"; a non-positive batch dimension is treated as 1.
32 int globalRank, numProcsAvailable;
33 MPI_Comm_rank(comm, &globalRank);
34 MPI_Comm_size(comm, &numProcsAvailable);
// Provisional values: every branch of the shape inference below reassigns
// both mGlobalNumRows and mGlobalNumColumns, and mGlobalBatchDimension is
// reassigned (clamped) a few lines further down.
36 mGlobalNumRows = globalNumRows;
37 mGlobalNumColumns = globalNumColumns;
38 mGlobalBatchDimension = globalBatchDimension;
40 bool numRowsDefined = globalNumRows > 0;
41 bool numColumnsDefined = globalNumColumns > 0;
42 mGlobalBatchDimension = globalBatchDimension > 0 ? globalBatchDimension : 1;
// Processes available to each batch element. Integer division: any
// remainder numProcsAvailable % mGlobalBatchDimension is not counted.
44 int procsLeft = numProcsAvailable / mGlobalBatchDimension;
45 if (numRowsDefined && numColumnsDefined) {
46 mGlobalNumRows = globalNumRows;
47 mGlobalNumColumns = globalNumColumns;
49 if (numRowsDefined && !numColumnsDefined) {
50 mGlobalNumRows = globalNumRows;
51 mGlobalNumColumns = (int)ceil(procsLeft / globalNumRows);
53 if (!numRowsDefined && numColumnsDefined) {
54 mGlobalNumRows = (int)ceil(procsLeft / globalNumColumns);
55 mGlobalNumColumns = globalNumRows;
57 if (!numRowsDefined && !numColumnsDefined) {
58 double r = std::sqrt(procsLeft);
59 mGlobalNumRows = (int)r;
60 FatalIf(mGlobalNumRows == 0,
"Not enough processes left\n");
61 mGlobalNumColumns = (int)ceil(procsLeft / mGlobalNumRows);
// Total number of processes the resolved grid occupies.
64 int numProcsNeeded = mGlobalBatchDimension * mGlobalNumRows * mGlobalNumColumns;
// Only global rank 0 evaluates the check below, which fires when the grid
// needs more processes than the communicator provides.
// NOTE(review): the opening of the checking call (presumably FatalIf) and its
// trailing format arguments are not visible in this excerpt.
66 if (globalRank == 0) {
68 numProcsNeeded > numProcsAvailable,
69 "Number of processes required (%d) is larger than the " 70 "number of processes available (%d)\n",
77 mNumRows = blockNumRows > 0 ? blockNumRows : mGlobalNumRows;
78 mNumColumns = blockNumColumns > 0 ? blockNumColumns : mGlobalNumColumns;
79 mBatchDimension = blockBatchDimension > 0 ? blockBatchDimension : mGlobalBatchDimension;
83 MPI_Comm_rank(comm, &mGlobalRank);
84 int const globalColumnIndex = mGlobalRank % mGlobalNumColumns;
85 int const globalRowIndex = (mGlobalRank / mGlobalNumColumns) % mGlobalNumRows;
86 int const globalBatchIndex = mGlobalRank / (mGlobalNumColumns * mGlobalNumRows);
87 int checkRank = globalBatchIndex * mGlobalNumRows + globalRowIndex;
88 checkRank = checkRank * mGlobalNumColumns + globalColumnIndex;
89 pvAssert(checkRank == mGlobalRank);
91 mRowIndex = globalRowIndex % mNumRows;
92 mStartRow = globalRowIndex - mRowIndex;
94 mColumnIndex = globalColumnIndex % mNumColumns;
95 mStartColumn = globalColumnIndex - mColumnIndex;
97 mBatchIndex = globalBatchIndex % mBatchDimension;
98 mStartBatch = globalBatchIndex - mBatchIndex;
102 int rowBlock = mStartRow / mNumRows;
103 int columnBlock = mStartColumn / mNumColumns;
104 int batchBlock = mStartBatch / mBatchDimension;
105 int cellsInGlobalRow =
calcNumCells(mNumRows, mGlobalNumRows);
106 int cellsInGlobalColumn =
calcNumCells(mNumColumns, mGlobalNumColumns);
108 rankFromRowAndColumn(rowBlock, columnBlock, cellsInGlobalRow, cellsInGlobalColumn);
109 cellIndex += batchBlock * cellsInGlobalRow * cellsInGlobalColumn;
111 int cellRank = rankFromRowAndColumn(mRowIndex, mColumnIndex, mNumRows, mNumColumns);
112 cellRank += mBatchIndex * mNumRows * mNumColumns;
113 int numProcsNeeded = mGlobalBatchDimension * mGlobalNumRows * mGlobalNumColumns;
115 MPI_Comm_split(comm, cellIndex, cellRank, &mComm);
116 MPI_Comm_rank(mComm, &mRank);
117 if (mRank < numProcsNeeded && mRank != cellRank) {
118 Fatal().printf(
"Global rank %d, cellRank %d, mRank %d\n", mGlobalRank, cellRank, mRank);
// calcNumCells(cellSize, overallSize) (fragment): number of cellSize-sized
// cells along a dimension of length overallSize (integer division).
// The branch below runs when the division leaves a remainder. Its body is not
// visible in this excerpt; presumably it reports an error or rounds up, so
// confirm which.
123 int numCells = overallSize / cellSize;
124 if (overallSize % cellSize != 0) {
// (fragment) tail of calcRankFromRowColBatch(rowIndex, columnIndex, batchIndex) const.
132 int const columnIndex,
133 int const batchIndex)
const {
// Throws std::invalid_argument on a condition not visible in this excerpt.
// Presumably this is an out-of-range row/column/batch index; confirm.
136 throw std::invalid_argument(
"calcRankFromRowColBatch");
// (fragment) rank validation: a rank outside [0, getSize()) is rejected
// with std::invalid_argument.
// NOTE(review): the enclosing function is not visible here (possibly
// checkRankInBounds). If so, the message naming "calcRowColBatchFromRank" is
// misleading for other callers; confirm.
142 if (rank < 0 || rank >=
getSize()) {
143 throw std::invalid_argument(
"calcRowColBatchFromRank");
// (fragment) end of a parameter list whose last parameter is the reference
// out-parameter batchIndex. This presumably belongs to
// calcRowColBatchFromRank(rank, rowIndex, columnIndex, batchIndex) const; confirm.
151 int &batchIndex)
const {
// NOTE(review): the lines below list MPIBlock member signatures with no bodies
// and no terminating semicolons. They are not valid C++ as written; they
// probably belong (with semicolons) in MPIBlock.hpp. Confirm.
// NOTE(review): initBlockDimensions' third parameter is listed here as
// globalBatchDimension, but the statement assigning mBatchDimension
// (presumably initBlockDimensions' body) reads blockBatchDimension. Confirm the
// definition's parameter name.
void initGlobalDimensions(MPI_Comm comm, int const globalNumRows, int const globalNumColumns, int const globalBatchDimension)
void initBlockDimensions(int const blockNumRows, int const blockNumColumns, int const globalBatchDimension)
MPIBlock(MPI_Comm comm, int globalNumRows, int globalNumColumns, int globalBatchDimension, int blockNumRows, int blockNumColumns, int blockBatchDimension)
int getNumColumns() const
void checkRankInBounds(int const rank) const
int calcRankFromRowColBatch(int const rowIndex, int const columnIndex, int const batchIndex) const
int getBatchDimension() const
void initBlockLocation(MPI_Comm comm)
void createBlockComm(MPI_Comm comm)
int calcBatchIndexFromRank(int const rank) const
int calcBatchIndexFromRankInternal(int const rank) const
int calcColumnFromRank(int const rank) const
int calcColumnFromRankInternal(int const rank) const
void calcRowColBatchFromRank(int const rank, int &rowIndex, int &columnIndex, int &batchIndex) const
int calcRowFromRank(int const rank) const
int calcRowFromRankInternal(int const rank) const
static int calcNumCells(int cellSize, int overallSize)