1 #include "BatchIndexer.hpp" 2 #include "utils/PVLog.hpp" 10 BatchIndexer::BatchIndexer(
11 std::string
const &objName,
// NOTE(review): globalBatchCount, fileCount, batchWidth and batchOffset are used in the
// body below, but their parameter declarations are not visible in this chunk; likewise
// no visible statement uses objName — confirm against the full file.
16 enum BatchMethod batchMethod,
17 bool initializeFromCheckpointFlag) {
// Store the batching configuration.
19 mGlobalBatchCount = globalBatchCount;
20 mBatchMethod = batchMethod;
// A fileCount of 0 is replaced by 1 so the later `% mFileCount` operations
// (nextIndex, specifyBatching) never divide by zero.
21 mFileCount = fileCount ? fileCount : 1;
22 mBatchWidth = batchWidth;
23 mBatchOffset = batchOffset;
24 mInitializeFromCheckpointFlag = initializeFromCheckpointFlag;
// Per-local-batch-element state, one entry per element, all zero-initialized.
// A skip amount of 0 marks "not yet specified" (see specifyBatching / initializeBatch).
25 mIndices.resize(mBatchWidth, 0);
26 mStartIndices.resize(mBatchWidth, 0);
27 mSkipAmounts.resize(mBatchWidth, 0);
// Looks up the current file index for one local batch element (via getIndex) and then
// advances that element's position by its skip amount. When the advanced position
// reaches mFileCount it is reset: to the element's start index if mWrapToStartIndex is
// set, and presumably (the else branch is not visible here) wrapped modulo mFileCount
// otherwise.
// NOTE(review): several lines of this definition are not visible in this chunk,
// including one statement right after the `if (newIndex >= mFileCount)` line, the
// closing braces and the return; it presumably returns `result`, the index from before
// advancing — confirm.
31 int BatchIndexer::nextIndex(
int localBatchIndex) {
32 int result = getIndex(localBatchIndex);
33 int newIndex = mIndices.at(localBatchIndex) + mSkipAmounts.at(localBatchIndex);
34 if (newIndex >= mFileCount) {
36 if (mWrapToStartIndex) {
37 newIndex = mStartIndices.at(localBatchIndex);
40 newIndex %= mFileCount;
// Persist the advanced position for this batch element.
43 mIndices.at(localBatchIndex) = newIndex;
47 int BatchIndexer::getIndex(
int localBatchIndex) {
48 if (mBatchMethod != RANDOM) {
49 return mIndices.at(localBatchIndex);
51 return mIndexLookupTable.at(mIndices.at(localBatchIndex));
54 void BatchIndexer::specifyBatching(
int localBatchIndex,
int startIndex,
int skipAmount) {
55 mStartIndices.at(localBatchIndex) = startIndex % mFileCount;
56 mSkipAmounts.at(localBatchIndex) = skipAmount < 1 ? 1 : skipAmount;
// Sets the start position for one local batch element according to mBatchMethod,
// using the element's global batch index (mBatchOffset + localBatchIndex), and then
// resets the element's current index to that start position.
// NOTE(review): the case labels and most statements of the switch are not visible in
// this chunk; the per-case comments below describe only the visible expressions.
59 void BatchIndexer::initializeBatch(
int localBatchIndex) {
60 int globalBatchIndex = mBatchOffset + localBatchIndex;
61 switch (mBatchMethod) {
// One case offsets the start index by the global batch index itself.
66 mStartIndices.at(localBatchIndex) + globalBatchIndex,
// Another case offsets by globalBatchIndex * (mFileCount / mGlobalBatchCount), using integer
// division. NOTE(review): mGlobalBatchCount == 0 would divide by zero here — confirm it is
// validated elsewhere.
72 mStartIndices.at(localBatchIndex)
73 + globalBatchIndex * (mFileCount / mGlobalBatchCount),
// Skip amounts start at 0 (constructor) and specifyBatching clamps them to >= 1, so a value
// < 1 means specifyBatching was never called for this element. Presumably this is a
// fatal-if style check (the call itself is not visible here).
78 mSkipAmounts.at(localBatchIndex) < 1,
79 "BatchIndexer batchMethod was set to BYSPECIFIED, but no values were specified.\n");
// Start iterating from the (possibly adjusted) start position.
82 mIndices.at(localBatchIndex) = mStartIndices.at(localBatchIndex);
// Accepts a seed for RANDOM batching.
// NOTE(review): the body of this function is not visible in this chunk; presumably it
// stores the seed in mRandomSeed (which shuffleLookupTable consumes) — confirm.
85 void BatchIndexer::setRandomSeed(
unsigned int seed) {
// Rebuilds mIndexLookupTable as a random permutation of 0 .. mFileCount-1.
// Relevant only for the RANDOM batch method. The statement guarded by the first `if` is
// not visible in this chunk (presumably an early return) — confirm.
// Each call seeds a fresh std::mt19937 with mRandomSeed and then increments the seed.
// The resulting sequence of shuffles is therefore reproducible from the initial seed, at
// least on the same standard library (std::shuffle's algorithm is implementation-defined).
95 void BatchIndexer::shuffleLookupTable() {
96 if (mBatchMethod != RANDOM) {
99 mIndexLookupTable.clear();
100 mIndexLookupTable.resize(mFileCount);
// Start from the identity mapping, then permute it in place.
101 for (
int i = 0; i < mFileCount; ++i) {
102 mIndexLookupTable.at(i) = i;
104 std::mt19937 rng(mRandomSeed++);
105 std::shuffle(mIndexLookupTable.begin(), mIndexLookupTable.end(), rng);
// Registers this indexer's checkpoint entries after calling the base-class
// CheckpointerDataInterface::registerData. It registers "FrameNumbers" (int data) and,
// only for the RANDOM batch method, "RandomSeed" (unsigned int data).
// NOTE(review): several lines are not visible in this chunk, including how `status` is
// handled and the pointers/sizes passed to registerCheckpointData — confirm against the
// full file.
108 Response::Status BatchIndexer::registerData(Checkpointer *checkpointer) {
109 auto status = CheckpointerDataInterface::registerData(checkpointer);
113 checkpointer->registerCheckpointData<
int>(
115 std::string(
"FrameNumbers"),
// The random seed is checkpointed only when the lookup table is shuffled.
120 if (mBatchMethod == RANDOM) {
121 checkpointer->registerCheckpointData<
unsigned int>(
123 std::string(
"RandomSeed"),
129 return Response::SUCCESS;
// Returns Response::SUCCESS. Judging by the name, this is called after checkpoint data has
// been read (confirm).
// NOTE(review): one statement of the body is not visible in this chunk; it is possibly
// the index-validation routine whose loop appears further below — confirm.
132 Response::Status BatchIndexer::processCheckpointRead() {
134 return Response::SUCCESS;
// If mInitializeFromCheckpointFlag is set, reads the "FrameNumbers" entry (named under
// mObjName) from the checkpointer and returns Response::SUCCESS. Otherwise it reads nothing
// and returns Response::NO_ACTION.
// NOTE(review): the meaning of the trailing `false` argument is defined by
// Checkpointer::readNamedCheckpointEntry, which is not visible here.
137 Response::Status BatchIndexer::readStateFromCheckpoint(Checkpointer *checkpointer) {
138 if (mInitializeFromCheckpointFlag) {
139 checkpointer->readNamedCheckpointEntry(mObjName,
"FrameNumbers",
false );
141 return Response::SUCCESS;
144 return Response::NO_ACTION;
// Fragment of an index-validation routine whose header is not visible in this chunk.
// For each of the mBatchWidth local batch elements it reports an error, naming this
// BatchIndexer, when an index is >= fileCount or negative (see the messages below).
149 for (
int k = 0; k < mBatchWidth; k++) {
153 "BatchIndexer \"%s\" has index %d=%d, but fileCount is only %d.\n",
160 "BatchIndexer \"%s\" has index %d=%d. Indices cannot be negative.\n",
static bool completed(Status &a)