PetaVision  Alpha
BatchIndexer.hpp
1 #ifndef BATCHINDEXER_HPP_
2 #define BATCHINDEXER_HPP_
3 
#include "checkpointing/CheckpointerDataInterface.hpp"
#include <string>
#include <utility>
#include <vector>
6 
7 namespace PV {
8 
10 
11  public:
/**
 * Selects how the indexer assigns file indices to batch elements.
 * Presumably BYSPECIFIED pairs with specifyBatching() and RANDOM with
 * shuffleLookupTable()/setRandomSeed() -- confirm in BatchIndexer.cpp.
 * The numeric values are the implicit 0..3, written out explicitly.
 */
enum BatchMethod {
   BYFILE      = 0,
   BYLIST      = 1,
   BYSPECIFIED = 2,
   RANDOM      = 3
};
13 
15  std::string const &objName,
16  int globalBatchCount,
17  int globalBatchIndex,
18  int batchWidth,
19  int fileCount,
20  enum BatchMethod batchMethod,
21  bool initializeFromCheckpointFlag);
22  int nextIndex(int localBatchIndex);
23  int getIndex(int localBatchIndex);
24  void specifyBatching(int localBatchIndex, int startIndex, int skipAmount);
25  void initializeBatch(int localBatchIndex);
26  void shuffleLookupTable();
27  void setRandomSeed(unsigned int seed);
28  void setIndices(const std::vector<int> &indices) { mIndices = indices; }
29  void setWrapToStartIndex(bool value) { mWrapToStartIndex = value; }
30  bool getWrapToStartIndex() { return mWrapToStartIndex; }
31  std::vector<int> getIndices() { return mIndices; }
32 
33  virtual Response::Status registerData(Checkpointer *checkpointer) override;
34 
35  protected:
36  virtual Response::Status processCheckpointRead() override;
37  virtual Response::Status readStateFromCheckpoint(Checkpointer *checkpointer) override;
38 
42  void checkIndices();
43 
44  private:
45  std::string mObjName;
46  int mGlobalBatchCount = 0;
47  int mFileCount = 0;
48  int mBatchWidth = 0;
49  int mBatchOffset = 0;
50  unsigned int mRandomSeed = 123456789;
51  bool mWrapToStartIndex = true;
52  std::vector<int> mIndexLookupTable;
53  std::vector<int> mIndices;
54  std::vector<int> mStartIndices;
55  std::vector<int> mSkipAmounts;
56  BatchMethod mBatchMethod;
57  bool mInitializeFromCheckpointFlag = false;
58  // mInitializeFromCheckpointFlag is a hack.
59  // BatchIndexer should load the indices from checkpoint when the InputLayer's
60  // initializeFromCheckpointFlag is true, and not when it's false.
61  // The problem is that BatchIndexer can't see the InputLayer, where the
62  // initializeFromCheckpointFlag is read.
63 };
64 }
65 
66 #endif // BATCHINDEXER_HPP_