// AdaptiveTimeScaleController constructor (in this copy of the file its opening line is
// fused with the #include directives on the next line).
// Stores the arguments visible here (batchWidth, tauFactor, growthFactor,
// writeTimeScaleFieldnames, communicator) in the corresponding members and sizes every
// per-batch vector to mBatchWidth entries.
// NOTE(review): the first part of the parameter list and whatever sets mName, mBaseMax
// and mBaseMin are not visible in this excerpt -- confirm against the full file.
8 #include "AdaptiveTimeScaleController.hpp" 9 #include "arch/mpi/mpi.h" 10 #include "include/pv_common.h" 11 #include "io/FileStream.hpp" 12 #include "io/fileio.hpp" 13 #include "utils/PVLog.hpp" 17 AdaptiveTimeScaleController::AdaptiveTimeScaleController(
24 bool writeTimeScaleFieldnames,
25 Communicator *communicator) {
27 mBatchWidth = batchWidth;
30 mTauFactor = tauFactor;
31 mGrowthFactor = growthFactor;
32 mWriteTimeScaleFieldnames = writeTimeScaleFieldnames;
33 mCommunicator = communicator;
// Every batch element starts at the minimum time scale; mTimeScaleMax starts at mBaseMax
// and is the value calcTimesteps() clamps mTimeScale against.
35 mTimeScaleInfo.mTimeScale.assign(mBatchWidth, mBaseMin);
36 mTimeScaleInfo.mTimeScaleMax.assign(mBatchWidth, mBaseMax);
// -1.0 is non-positive, so the first calcTimesteps() call sees E_0 <= 0 and takes the
// reset branch for every batch element.
37 mTimeScaleInfo.mTimeScaleTrue.assign(mBatchWidth, -1.0);
// NOTE(review): mOldTimeScale / mOldTimeScaleTrue are initialized here but are not
// referenced by any other code visible in this excerpt (calcTimesteps() uses
// mOldTimeScaleInfo instead) -- confirm whether they are still needed.
38 mOldTimeScale.assign(mBatchWidth, mBaseMin);
39 mOldTimeScaleTrue.assign(mBatchWidth, -1.0);
42 AdaptiveTimeScaleController::~AdaptiveTimeScaleController() { free(mName); }
44 Response::Status AdaptiveTimeScaleController::registerData(Checkpointer *checkpointer) {
45 auto ptr = std::make_shared<CheckpointEntryTimeScaleInfo>(
46 mName,
"timescaleinfo", checkpointer->getMPIBlock(), &mTimeScaleInfo);
47 checkpointer->registerCheckpointEntry(ptr,
false );
48 return Response::SUCCESS;
// Computes the next time scale for every batch element from the newly supplied raw values
// and returns a copy of the updated mTimeScaleInfo.mTimeScale vector.
// rawTimeScales becomes the new mTimeScaleInfo.mTimeScaleTrue and is then indexed from 0
// to mBatchWidth-1 with operator[], i.e. without any size check.
// NOTE(review): at least one line of the parameter list, plus the braces/else that
// structure the loop body, are not visible in this excerpt -- confirm against the full file.
51 std::vector<double> AdaptiveTimeScaleController::calcTimesteps(
53 std::vector<double>
const &rawTimeScales) {
// Snapshot the previous state so the previous "true" value can be compared with the new one.
54 mOldTimeScaleInfo = mTimeScaleInfo;
55 mTimeScaleInfo.mTimeScaleTrue = rawTimeScales;
56 for (
int b = 0; b < mBatchWidth; b++) {
// E_dt is the value supplied on this call, E_0 the value from the previous call.
57 double E_dt = mTimeScaleInfo.mTimeScaleTrue[b];
58 double E_0 = mOldTimeScaleInfo.mTimeScaleTrue[b];
// Decrease since the previous call, divided by the time scale currently stored for b.
59 double dE_dt_scaled = (E_0 - E_dt) / mTimeScaleInfo.mTimeScale[b];
// Fall back to the base values when dE_dt_scaled is negative or when either sample is
// non-positive.
65 if ((dE_dt_scaled < 0.0) || (E_0 <= 0) || (E_dt <= 0)) {
66 mTimeScaleInfo.mTimeScale[b] = mBaseMin;
67 mTimeScaleInfo.mTimeScaleMax[b] = mBaseMax;
// NOTE(review): the statements below read like the else-arm of the test above (the
// closing brace and `else` are missing from this excerpt). dE_dt_scaled == 0 is not caught
// by that test, so this division can be by zero -- presumably relying on IEEE-754 infinity
// being capped by the clamp below; confirm.
70 double tau_eff_scaled = E_0 / dE_dt_scaled;
73 mTimeScaleInfo.mTimeScale[b] = mTauFactor * tau_eff_scaled;
// Clamp to the current maximum; each time the clamp is hit the maximum itself is
// multiplied by (1 + mGrowthFactor) for subsequent calls.
74 if (mTimeScaleInfo.mTimeScale[b] >= mTimeScaleInfo.mTimeScaleMax[b]) {
75 mTimeScaleInfo.mTimeScale[b] = mTimeScaleInfo.mTimeScaleMax[b];
76 mTimeScaleInfo.mTimeScaleMax[b] = (1 + mGrowthFactor) * mTimeScaleInfo.mTimeScaleMax[b];
80 return mTimeScaleInfo.mTimeScale;
// Prints the current time-scale values of every batch element, one stream per element:
// batch element b is written through streams.at(b), which throws std::out_of_range if
// streams holds fewer than mBatchWidth entries.
// When mWriteTimeScaleFieldnames is set the values are printed with field names
// ("sim_time = ...", "timeScale = ..."); otherwise only comma-separated values are printed.
// NOTE(review): timeValue is used below but its declaration, and several other lines of
// this function, are not visible in this excerpt -- confirm against the full file.
83 void AdaptiveTimeScaleController::writeTimestepInfo(
85 std::vector<PrintStream *> &streams) {
86 for (
int b = 0; b < mBatchWidth; b++) {
// NOTE(review): plain `auto` deduces a value type, so this copy-initializes a PrintStream
// from *streams.at(b) instead of binding a reference to it; `auto &` may have been
// intended -- confirm.
87 auto stream = *streams.at(b);
// Labeled form: simulation time on its own line, then the named fields.
88 if (mWriteTimeScaleFieldnames) {
89 stream.printf(
"sim_time = %f\n", timeValue);
91 "\tbatch = %d, timeScale = %10.8f, timeScaleTrue = %10.8f",
93 mTimeScaleInfo.mTimeScale[b],
94 mTimeScaleInfo.mTimeScaleTrue[b]);
// Unlabeled form: the same quantities as bare comma-separated values.
97 stream.printf(
"%f, ", timeValue);
101 mTimeScaleInfo.mTimeScale[b],
102 mTimeScaleInfo.mTimeScaleTrue[b]);
// timeScaleMax ends the record in both forms and supplies the trailing newline.
104 if (mWriteTimeScaleFieldnames) {
105 stream.printf(
", timeScaleMax = %10.8f\n", mTimeScaleInfo.mTimeScaleMax[b]);
108 stream.printf(
", %10.8f\n", mTimeScaleInfo.mTimeScaleMax[b]);
// Writes the time-scale state referenced by mTimeScaleInfoPtr into checkpointDirectory.
// Only the process for which getMPIBlock()->getRank() is 0 writes anything. Two files are
// produced, at generatePath(checkpointDirectory, "bin") and
// generatePath(checkpointDirectory, "txt"); verifyWritesFlag is forwarded to both
// FileStream constructors.
// NOTE(review): simTime and kb0 are used below but their declarations are not visible in
// this excerpt -- confirm against the full file.
114 void CheckpointEntryTimeScaleInfo::write(
115 std::string
const &checkpointDirectory,
117 bool verifyWritesFlag)
const {
118 if (getMPIBlock()->getRank() == 0) {
// The number of batch elements is taken from the size of the mTimeScale vector.
119 int batchWidth = (int)mTimeScaleInfoPtr->mTimeScale.size();
120 std::string path = generatePath(checkpointDirectory,
"bin");
121 FileStream fileStream{path.c_str(), std::ios_base::out, verifyWritesFlag};
// Binary layout: for each batch element, three raw doubles in the order
// timeScale, timeScaleTrue, timeScaleMax. read() consumes them in the same order.
122 for (
int b = 0; b < batchWidth; b++) {
123 fileStream.write(&mTimeScaleInfoPtr->mTimeScale.at(b),
sizeof(double));
124 fileStream.write(&mTimeScaleInfoPtr->mTimeScaleTrue.at(b),
sizeof(double));
125 fileStream.write(&mTimeScaleInfoPtr->mTimeScaleMax.at(b),
sizeof(double));
// Text file: the same values as labeled "name = value" lines.
127 path = generatePath(checkpointDirectory,
"txt");
128 FileStream txtFileStream{path.c_str(), std::ios_base::out, verifyWritesFlag};
// NOTE(review): b is std::size_t while batchWidth is int, so this comparison mixes
// unsigned and signed operands; the binary loop above uses int for the same index.
130 for (std::size_t b = 0; b < batchWidth; b++) {
// The printed batch index is the local index b offset by kb0.
131 txtFileStream <<
"batch index = " << b + kb0 <<
"\n";
132 txtFileStream <<
"time = " << simTime <<
"\n";
133 txtFileStream <<
"timeScale = " << mTimeScaleInfoPtr->mTimeScale[b] <<
"\n";
134 txtFileStream <<
"timeScaleTrue = " << mTimeScaleInfoPtr->mTimeScaleTrue[b] <<
"\n";
135 txtFileStream <<
"timeScaleMax = " << mTimeScaleInfoPtr->mTimeScaleMax[b] <<
"\n";
140 void CheckpointEntryTimeScaleInfo::remove(std::string
const &checkpointDirectory)
const {
141 deleteFile(checkpointDirectory,
"bin");
142 deleteFile(checkpointDirectory,
"txt");
145 void CheckpointEntryTimeScaleInfo::read(std::string
const &checkpointDirectory,
double *simTimePtr)
147 int batchWidth = (int)mTimeScaleInfoPtr->mTimeScale.size();
148 if (getMPIBlock()->getRank() == 0) {
149 std::string path = generatePath(checkpointDirectory,
"bin");
150 FileStream fileStream{path.c_str(), std::ios_base::in,
false};
151 for (std::size_t b = 0; b < batchWidth; b++) {
152 fileStream.read(&mTimeScaleInfoPtr->mTimeScale.at(b),
sizeof(double));
153 fileStream.read(&mTimeScaleInfoPtr->mTimeScaleTrue.at(b),
sizeof(double));
154 fileStream.read(&mTimeScaleInfoPtr->mTimeScaleMax.at(b),
sizeof(double));
158 mTimeScaleInfoPtr->mTimeScale.data(), batchWidth, MPI_DOUBLE, 0, getMPIBlock()->
getComm());
160 mTimeScaleInfoPtr->mTimeScaleTrue.data(),
166 mTimeScaleInfoPtr->mTimeScaleMax.data(),
int getBatchIndex() const