PetaVision  Alpha
BatchIndexer.cpp
1 #include "BatchIndexer.hpp"
2 #include "utils/PVLog.hpp"
3 #include <algorithm>
4 #include <random>
5 
6 namespace PV {
7 
8 // This takes in the global batch index of local batch 0 for the second argument.
9 // Should this be the value of commBatch() instead?
10 BatchIndexer::BatchIndexer(
11  std::string const &objName,
12  int globalBatchCount,
13  int batchOffset,
14  int batchWidth,
15  int fileCount,
16  enum BatchMethod batchMethod,
17  bool initializeFromCheckpointFlag) {
18  mObjName = objName;
19  mGlobalBatchCount = globalBatchCount;
20  mBatchMethod = batchMethod;
21  mFileCount = fileCount ? fileCount : 1;
22  mBatchWidth = batchWidth;
23  mBatchOffset = batchOffset;
24  mInitializeFromCheckpointFlag = initializeFromCheckpointFlag;
25  mIndices.resize(mBatchWidth, 0);
26  mStartIndices.resize(mBatchWidth, 0);
27  mSkipAmounts.resize(mBatchWidth, 0);
28  shuffleLookupTable();
29 }
30 
31 int BatchIndexer::nextIndex(int localBatchIndex) {
32  int result = getIndex(localBatchIndex);
33  int newIndex = mIndices.at(localBatchIndex) + mSkipAmounts.at(localBatchIndex);
34  if (newIndex >= mFileCount) {
35  shuffleLookupTable();
36  if (mWrapToStartIndex) {
37  newIndex = mStartIndices.at(localBatchIndex);
38  }
39  else {
40  newIndex %= mFileCount;
41  }
42  }
43  mIndices.at(localBatchIndex) = newIndex;
44  return newIndex;
45 }
46 
47 int BatchIndexer::getIndex(int localBatchIndex) {
48  if (mBatchMethod != RANDOM) {
49  return mIndices.at(localBatchIndex);
50  }
51  return mIndexLookupTable.at(mIndices.at(localBatchIndex));
52 }
53 
54 void BatchIndexer::specifyBatching(int localBatchIndex, int startIndex, int skipAmount) {
55  mStartIndices.at(localBatchIndex) = startIndex % mFileCount;
56  mSkipAmounts.at(localBatchIndex) = skipAmount < 1 ? 1 : skipAmount;
57 }
58 
59 void BatchIndexer::initializeBatch(int localBatchIndex) {
60  int globalBatchIndex = mBatchOffset + localBatchIndex;
61  switch (mBatchMethod) {
62  case RANDOM:
63  case BYFILE:
64  specifyBatching(
65  localBatchIndex,
66  mStartIndices.at(localBatchIndex) + globalBatchIndex,
67  mGlobalBatchCount);
68  break;
69  case BYLIST:
70  specifyBatching(
71  localBatchIndex,
72  mStartIndices.at(localBatchIndex)
73  + globalBatchIndex * (mFileCount / mGlobalBatchCount),
74  1);
75  break;
76  case BYSPECIFIED:
77  FatalIf(
78  mSkipAmounts.at(localBatchIndex) < 1,
79  "BatchIndexer batchMethod was set to BYSPECIFIED, but no values were specified.\n");
80  break;
81  }
82  mIndices.at(localBatchIndex) = mStartIndices.at(localBatchIndex);
83 }
84 
85 void BatchIndexer::setRandomSeed(unsigned int seed) {
86  mRandomSeed = seed;
87  shuffleLookupTable();
88 }
89 
90 // This clears the current file index lookup table and fills it with
91 // randomly ordered ints from 0 to mFileCount. The random seed is
92 // incremented so the next time this is called it will result in a new order.
93 // Two objects with BatchIndexers with the same seed will randomize the order
94 // in the same manner.
95 void BatchIndexer::shuffleLookupTable() {
96  if (mBatchMethod != RANDOM) {
97  return;
98  }
99  mIndexLookupTable.clear();
100  mIndexLookupTable.resize(mFileCount);
101  for (int i = 0; i < mFileCount; ++i) {
102  mIndexLookupTable.at(i) = i;
103  }
104  std::mt19937 rng(mRandomSeed++);
105  std::shuffle(mIndexLookupTable.begin(), mIndexLookupTable.end(), rng);
106 }
107 
108 Response::Status BatchIndexer::registerData(Checkpointer *checkpointer) {
109  auto status = CheckpointerDataInterface::registerData(checkpointer);
110  if (!Response::completed(status)) {
111  return status;
112  }
113  checkpointer->registerCheckpointData<int>(
114  mObjName,
115  std::string("FrameNumbers"),
116  mIndices.data(),
117  mIndices.size(),
118  false /*do not broadcast*/,
119  false /*not constant*/);
120  if (mBatchMethod == RANDOM) {
121  checkpointer->registerCheckpointData<unsigned int>(
122  mObjName,
123  std::string("RandomSeed"),
124  &mRandomSeed,
125  1,
126  false /*do not broadcast*/,
127  false /*not constant*/);
128  }
129  return Response::SUCCESS;
130 }
131 
132 Response::Status BatchIndexer::processCheckpointRead() {
133  checkIndices();
134  return Response::SUCCESS;
135 }
136 
137 Response::Status BatchIndexer::readStateFromCheckpoint(Checkpointer *checkpointer) {
138  if (mInitializeFromCheckpointFlag) {
139  checkpointer->readNamedCheckpointEntry(mObjName, "FrameNumbers", false /*not constant*/);
140  checkIndices();
141  return Response::SUCCESS;
142  }
143  else {
144  return Response::NO_ACTION;
145  }
146 }
147 
149  for (int k = 0; k < mBatchWidth; k++) {
150  int n = getIndex(k);
151  FatalIf(
152  n >= mFileCount,
153  "BatchIndexer \"%s\" has index %d=%d, but fileCount is only %d.\n",
154  mObjName.c_str(),
155  k,
156  n,
157  mFileCount);
158  FatalIf(
159  n < 0,
160  "BatchIndexer \"%s\" has index %d=%d. Indices cannot be negative.\n",
161  mObjName.c_str(),
162  k,
163  n);
164  }
165 }
166 
167 } // end namespace PV
static bool completed(Status &a)
Definition: Response.hpp:49