PetaVision  Alpha
CheckpointEntryWeightPvp.cpp
/*
 * CheckpointEntryWeightPvp.cpp
 *
 * Created on Sep 27, 2016
 * Author: Pete Schultz
 */
7 
#include "CheckpointEntryWeightPvp.hpp"
#include "io/WeightsFileIO.hpp"
#include "io/fileio.hpp"
#include "structures/Buffer.hpp"
#include "utils/BufferUtilsMPI.hpp"
#include "utils/BufferUtilsPvp.hpp"
#include <algorithm>
#include <limits>
#include <memory>
15 
16 namespace PV {
17 
18 void CheckpointEntryWeightPvp::initialize(Weights *weights, bool compressFlag) {
19  mWeights = weights;
20  mCompressFlag = compressFlag;
21 }
22 
23 void CheckpointEntryWeightPvp::write(
24  std::string const &checkpointDirectory,
25  double simTime,
26  bool verifyWritesFlag) const {
27  std::string path(checkpointDirectory);
28  path.append("/").append(getName()).append(".pvp");
29  FileStream *fileStream = nullptr;
30  if (getMPIBlock()->getRank() == 0) {
31  fileStream = new FileStream(path.c_str(), std::ios_base::out, verifyWritesFlag);
32  }
33 
34  WeightsFileIO weightFileIO(fileStream, getMPIBlock(), mWeights);
35  weightFileIO.writeWeights(simTime, mCompressFlag);
36  delete fileStream;
37 }
38 
39 void CheckpointEntryWeightPvp::read(std::string const &checkpointDirectory, double *simTimePtr)
40  const {
41  // Need to clear weights before reading because reading weights is increment-add, not assignment.
42  int const numArbors = mWeights->getNumArbors();
43  for (int arbor = 0; arbor < numArbors; arbor++) {
44  int const nxp = mWeights->getPatchSizeX();
45  int const nyp = mWeights->getPatchSizeY();
46  int const nfp = mWeights->getPatchSizeF();
47  int const numPatches = mWeights->getNumDataPatches();
48 
49  std::size_t const numWeightsInArbor = (std::size_t)(numPatches * nxp * nyp * nfp);
50  float *weightData = mWeights->getData(arbor);
51 
52  memset(weightData, 0, numWeightsInArbor * sizeof(*weightData));
53  }
54 
55  std::string path(checkpointDirectory);
56  path.append("/").append(getName()).append(".pvp");
57  FileStream *fileStream = nullptr;
58  if (getMPIBlock()->getRank() == 0) {
59  fileStream = new FileStream(path.c_str(), std::ios_base::in, false);
60  }
61 
62  WeightsFileIO weightFileIO(fileStream, getMPIBlock(), mWeights);
63  double simTime = weightFileIO.readWeights(0 /*frameNumber*/);
64  if (simTimePtr) {
65  *simTimePtr = simTime;
66  }
67  delete fileStream;
68 }
69 
70 void CheckpointEntryWeightPvp::remove(std::string const &checkpointDirectory) const {
71  deleteFile(checkpointDirectory, "pvp");
72 }
73 
74 } // end namespace PV
float * getData(int arbor)
Definition: Weights.cpp:196
int getPatchSizeX() const
Definition: Weights.hpp:219
int getNumDataPatches() const
Definition: Weights.hpp:174
int getNumArbors() const
Definition: Weights.hpp:151
int getPatchSizeY() const
Definition: Weights.hpp:222
int getPatchSizeF() const
Definition: Weights.hpp:225