// Opens the BATCHINDEXER_HPP_ include guard, pulls in the checkpointing data
// interface header, and defines the unscoped enum BatchMethod with the
// enumerators BYFILE, BYLIST, BYSPECIFIED and RANDOM (implicit values 0..3).
// NOTE(review): what each batching mode does is not visible in this header;
// confirm the semantics against the implementation before documenting them.
// NOTE(review): the directives and the enum appear collapsed onto one line with
// stray leading numbers (1, 2, 4, 12) -- this looks like an extraction artifact;
// each preprocessor directive needs to start its own line.
1 #ifndef BATCHINDEXER_HPP_ 2 #define BATCHINDEXER_HPP_ 4 #include "checkpointing/CheckpointerDataInterface.hpp" 12 enum BatchMethod { BYFILE, BYLIST, BYSPECIFIED, RANDOM };
15 std::string
const &objName,
20 enum BatchMethod batchMethod,
21 bool initializeFromCheckpointFlag);
/**
 * Declaration only; the definition is not part of this header.
 * Takes a local batch index and returns an int.
 * NOTE(review): presumably advances the position of that batch element and
 * returns the resulting index -- confirm against the implementation.
 */
int nextIndex(int localBatchIndex);
/**
 * Declaration only; the definition is not part of this header.
 * Takes a local batch index and returns an int.
 * NOTE(review): presumably returns the current index of that batch element
 * without advancing it -- confirm against the implementation.
 */
int getIndex(int localBatchIndex);
/**
 * Declaration only; the definition is not part of this header.
 * Accepts a local batch index together with a start index and a skip amount.
 * NOTE(review): presumably records these in mStartIndices / mSkipAmounts for
 * the given batch element -- TODO confirm against the implementation.
 */
void specifyBatching(int localBatchIndex, int startIndex, int skipAmount);
/**
 * Declaration only; the definition is not part of this header.
 * Takes the local batch index of the element to initialize.
 * NOTE(review): what state is (re)initialized is not visible here -- confirm
 * against the implementation.
 */
void initializeBatch(int localBatchIndex);
/**
 * Declaration only; the definition is not part of this header.
 * NOTE(review): presumably permutes mIndexLookupTable using mRandomSeed --
 * confirm against the implementation.
 */
void shuffleLookupTable();
/**
 * Declaration only; the definition is not part of this header.
 * Takes an unsigned seed value.
 * NOTE(review): presumably stores it in mRandomSeed; whether it also
 * reshuffles the lookup table is not visible here -- confirm.
 */
void setRandomSeed(unsigned int seed);
28 void setIndices(
const std::vector<int> &indices) { mIndices = indices; }
29 void setWrapToStartIndex(
bool value) { mWrapToStartIndex = value; }
30 bool getWrapToStartIndex() {
return mWrapToStartIndex; }
31 std::vector<int> getIndices() {
return mIndices; }
33 virtual Response::Status registerData(
Checkpointer *checkpointer)
override;
36 virtual Response::Status processCheckpointRead()
override;
37 virtual Response::Status readStateFromCheckpoint(
Checkpointer *checkpointer)
override;
46 int mGlobalBatchCount = 0;
50 unsigned int mRandomSeed = 123456789;
51 bool mWrapToStartIndex =
true;
52 std::vector<int> mIndexLookupTable;
53 std::vector<int> mIndices;
54 std::vector<int> mStartIndices;
55 std::vector<int> mSkipAmounts;
56 BatchMethod mBatchMethod;
57 bool mInitializeFromCheckpointFlag =
false;
66 #endif // BATCHINDEXER_HPP_