PetaVision  Alpha
AdaptiveTimeScaleController.cpp
1 /*
2  * AdaptiveTimeScaleController.cpp
3  *
4  * Created on: Aug 18, 2016
5  * Author: pschultz
6  */
7 
8 #include "AdaptiveTimeScaleController.hpp"
9 #include "arch/mpi/mpi.h"
10 #include "include/pv_common.h"
11 #include "io/FileStream.hpp"
12 #include "io/fileio.hpp"
13 #include "utils/PVLog.hpp"
14 
15 namespace PV {
16 
17 AdaptiveTimeScaleController::AdaptiveTimeScaleController(
18  char const *name,
19  int batchWidth,
20  double baseMax,
21  double baseMin,
22  double tauFactor,
23  double growthFactor,
24  bool writeTimeScaleFieldnames,
25  Communicator *communicator) {
26  mName = strdup(name);
27  mBatchWidth = batchWidth;
28  mBaseMax = baseMax;
29  mBaseMin = baseMin;
30  mTauFactor = tauFactor;
31  mGrowthFactor = growthFactor;
32  mWriteTimeScaleFieldnames = writeTimeScaleFieldnames;
33  mCommunicator = communicator;
34 
35  mTimeScaleInfo.mTimeScale.assign(mBatchWidth, mBaseMin);
36  mTimeScaleInfo.mTimeScaleMax.assign(mBatchWidth, mBaseMax);
37  mTimeScaleInfo.mTimeScaleTrue.assign(mBatchWidth, -1.0);
38  mOldTimeScale.assign(mBatchWidth, mBaseMin);
39  mOldTimeScaleTrue.assign(mBatchWidth, -1.0);
40 }
41 
42 AdaptiveTimeScaleController::~AdaptiveTimeScaleController() { free(mName); }
43 
44 Response::Status AdaptiveTimeScaleController::registerData(Checkpointer *checkpointer) {
45  auto ptr = std::make_shared<CheckpointEntryTimeScaleInfo>(
46  mName, "timescaleinfo", checkpointer->getMPIBlock(), &mTimeScaleInfo);
47  checkpointer->registerCheckpointEntry(ptr, false /*not constant*/);
48  return Response::SUCCESS;
49 }
50 
51 std::vector<double> AdaptiveTimeScaleController::calcTimesteps(
52  double timeValue,
53  std::vector<double> const &rawTimeScales) {
54  mOldTimeScaleInfo = mTimeScaleInfo;
55  mTimeScaleInfo.mTimeScaleTrue = rawTimeScales;
56  for (int b = 0; b < mBatchWidth; b++) {
57  double E_dt = mTimeScaleInfo.mTimeScaleTrue[b];
58  double E_0 = mOldTimeScaleInfo.mTimeScaleTrue[b];
59  double dE_dt_scaled = (E_0 - E_dt) / mTimeScaleInfo.mTimeScale[b];
60 
61  if (E_dt == E_0) {
62  continue;
63  }
64 
65  if ((dE_dt_scaled < 0.0) || (E_0 <= 0) || (E_dt <= 0)) {
66  mTimeScaleInfo.mTimeScale[b] = mBaseMin;
67  mTimeScaleInfo.mTimeScaleMax[b] = mBaseMax;
68  }
69  else {
70  double tau_eff_scaled = E_0 / dE_dt_scaled;
71 
72  // dt := mTimeScaleMaxBase * tau_eff
73  mTimeScaleInfo.mTimeScale[b] = mTauFactor * tau_eff_scaled;
74  if (mTimeScaleInfo.mTimeScale[b] >= mTimeScaleInfo.mTimeScaleMax[b]) {
75  mTimeScaleInfo.mTimeScale[b] = mTimeScaleInfo.mTimeScaleMax[b];
76  mTimeScaleInfo.mTimeScaleMax[b] = (1 + mGrowthFactor) * mTimeScaleInfo.mTimeScaleMax[b];
77  }
78  }
79  }
80  return mTimeScaleInfo.mTimeScale;
81 }
82 
83 void AdaptiveTimeScaleController::writeTimestepInfo(
84  double timeValue,
85  std::vector<PrintStream *> &streams) {
86  for (int b = 0; b < mBatchWidth; b++) {
87  auto stream = *streams.at(b);
88  if (mWriteTimeScaleFieldnames) {
89  stream.printf("sim_time = %f\n", timeValue);
90  stream.printf(
91  "\tbatch = %d, timeScale = %10.8f, timeScaleTrue = %10.8f",
92  b,
93  mTimeScaleInfo.mTimeScale[b],
94  mTimeScaleInfo.mTimeScaleTrue[b]);
95  }
96  else {
97  stream.printf("%f, ", timeValue);
98  stream.printf(
99  "%d, %10.8f, %10.8f",
100  b,
101  mTimeScaleInfo.mTimeScale[b],
102  mTimeScaleInfo.mTimeScaleTrue[b]);
103  }
104  if (mWriteTimeScaleFieldnames) {
105  stream.printf(", timeScaleMax = %10.8f\n", mTimeScaleInfo.mTimeScaleMax[b]);
106  }
107  else {
108  stream.printf(", %10.8f\n", mTimeScaleInfo.mTimeScaleMax[b]);
109  }
110  stream.flush();
111  }
112 }
113 
114 void CheckpointEntryTimeScaleInfo::write(
115  std::string const &checkpointDirectory,
116  double simTime,
117  bool verifyWritesFlag) const {
118  if (getMPIBlock()->getRank() == 0) {
119  int batchWidth = (int)mTimeScaleInfoPtr->mTimeScale.size();
120  std::string path = generatePath(checkpointDirectory, "bin");
121  FileStream fileStream{path.c_str(), std::ios_base::out, verifyWritesFlag};
122  for (int b = 0; b < batchWidth; b++) {
123  fileStream.write(&mTimeScaleInfoPtr->mTimeScale.at(b), sizeof(double));
124  fileStream.write(&mTimeScaleInfoPtr->mTimeScaleTrue.at(b), sizeof(double));
125  fileStream.write(&mTimeScaleInfoPtr->mTimeScaleMax.at(b), sizeof(double));
126  }
127  path = generatePath(checkpointDirectory, "txt");
128  FileStream txtFileStream{path.c_str(), std::ios_base::out, verifyWritesFlag};
129  int kb0 = getMPIBlock()->getBatchIndex() * batchWidth;
130  for (std::size_t b = 0; b < batchWidth; b++) {
131  txtFileStream << "batch index = " << b + kb0 << "\n";
132  txtFileStream << "time = " << simTime << "\n";
133  txtFileStream << "timeScale = " << mTimeScaleInfoPtr->mTimeScale[b] << "\n";
134  txtFileStream << "timeScaleTrue = " << mTimeScaleInfoPtr->mTimeScaleTrue[b] << "\n";
135  txtFileStream << "timeScaleMax = " << mTimeScaleInfoPtr->mTimeScaleMax[b] << "\n";
136  }
137  }
138 }
139 
140 void CheckpointEntryTimeScaleInfo::remove(std::string const &checkpointDirectory) const {
141  deleteFile(checkpointDirectory, "bin");
142  deleteFile(checkpointDirectory, "txt");
143 }
144 
145 void CheckpointEntryTimeScaleInfo::read(std::string const &checkpointDirectory, double *simTimePtr)
146  const {
147  int batchWidth = (int)mTimeScaleInfoPtr->mTimeScale.size();
148  if (getMPIBlock()->getRank() == 0) {
149  std::string path = generatePath(checkpointDirectory, "bin");
150  FileStream fileStream{path.c_str(), std::ios_base::in, false};
151  for (std::size_t b = 0; b < batchWidth; b++) {
152  fileStream.read(&mTimeScaleInfoPtr->mTimeScale.at(b), sizeof(double));
153  fileStream.read(&mTimeScaleInfoPtr->mTimeScaleTrue.at(b), sizeof(double));
154  fileStream.read(&mTimeScaleInfoPtr->mTimeScaleMax.at(b), sizeof(double));
155  }
156  }
157  MPI_Bcast(
158  mTimeScaleInfoPtr->mTimeScale.data(), batchWidth, MPI_DOUBLE, 0, getMPIBlock()->getComm());
159  MPI_Bcast(
160  mTimeScaleInfoPtr->mTimeScaleTrue.data(),
161  batchWidth,
162  MPI_DOUBLE,
163  0,
164  getMPIBlock()->getComm());
165  MPI_Bcast(
166  mTimeScaleInfoPtr->mTimeScaleMax.data(),
167  batchWidth,
168  MPI_DOUBLE,
169  0,
170  getMPIBlock()->getComm());
171 }
172 
173 } /* namespace PV */
MPI_Comm getComm() const
Definition: MPIBlock.hpp:90
int getBatchIndex() const
Definition: MPIBlock.hpp:171