6 #include "InputLayer.hpp" 7 #include "columns/RandomSeed.hpp" 8 #include "utils/BufferUtilsMPI.hpp" 14 InputLayer::InputLayer(
const char *name, HyPerCol *hc) { initialize(name, hc); }
16 InputLayer::~InputLayer() {
17 delete mBorderExchanger;
18 delete mTimestampStream;
21 int InputLayer::initialize(
const char *name, HyPerCol *hc) {
// Allocates the layer's data structures: delegates to
// HyPerLayer::allocateDataStructures() and, when mNeedInputRegionsPointer is
// set, resizes mInputRegionsAllBatchElements to getNumExtendedAllBatches()
// entries. The visible path returns Response::SUCCESS.
// NOTE(review): the statements between the base-class call and the
// mNeedInputRegionsPointer check are missing from this listing; presumably
// `status` is inspected there - confirm against the full source.
26 Response::Status InputLayer::allocateDataStructures() {
27 auto status = HyPerLayer::allocateDataStructures();
31 if (mNeedInputRegionsPointer) {
32 mInputRegionsAllBatchElements.resize(getNumExtendedAllBatches());
34 return Response::SUCCESS;
// Creates mBatchIndexer and configures it for every batch element handled here.
// Asserts that an MPI block exists and that this process has rank 0 in it.
// For each b in [0, blockBatchCount) the indexer receives the start and skip
// frame values stored at position (batchOffset + b), where
// batchOffset = nbatch * getMPIBlock()->getStartBatch(), and initializeBatch(b)
// is called. The indexer's random seed is set to the global initial seed plus
// mRandomSeed.
// NOTE(review): the definitions of mpiGlobalCount and blockBatchCount, and most
// of the BatchIndexer constructor arguments, sit on lines that are missing from
// this listing; they are not described here.
37 void InputLayer::initializeBatchIndexer() {
39 pvAssert(getMPIBlock());
40 pvAssert(getMPIBlock()->getRank() == 0);
41 int localBatchCount = getLayerLoc()->nbatch;
44 int globalBatchCount = localBatchCount * mpiGlobalCount;
45 int batchOffset = localBatchCount * getMPIBlock()->
getStartBatch();
48 mBatchIndexer = std::unique_ptr<BatchIndexer>(
56 initializeFromCheckpointFlag));
57 for (
int b = 0; b < blockBatchCount; ++b) {
58 mBatchIndexer->specifyBatching(
59 b, mStartFrameIndex.at(batchOffset + b), mSkipFrameIndex.at(batchOffset + b));
60 mBatchIndexer->initializeBatch(b);
// Offsetting by mRandomSeed makes the seed depend on this layer's parameter as
// well as on the global initial seed.
62 mBatchIndexer->setRandomSeed(RandomSeed::instance()->getInitialSeed() + mRandomSeed);
// Returns true when mDisplayPeriod is positive and false otherwise.
// updateState() uses this as its gate.
// NOTE(review): one original line between the header and the return is
// missing from this listing (possibly a comment) - confirm nothing else runs.
68 bool InputLayer::readyForNextFile() {
70 return mDisplayPeriod > 0;
// Per-timestep update of the layer.
//   time - simulation time; written into the timestamp record
//   dt   - timestep; not used by any statement visible here
// When readyForNextFile() is true and mTimestampStream is non-null, a text
// record (layer name, time, batch element b + kb0, BatchIndexer index) is
// formatted with 15-digit precision for the batch elements, written to
// mTimestampStream and flushed. The visible path returns Response::SUCCESS.
// NOTE(review): several original lines are missing from this listing, including
// the definition of blockBatchCount, the loop's closing brace and whatever
// follows the timestamp block, so only the visible statements are described.
73 Response::Status InputLayer::updateState(
double time,
double dt) {
74 if (readyForNextFile()) {
// Timestamp logging is optional: skipped when no stream is open.
77 if (mTimestampStream) {
78 std::ostringstream outStrStream;
79 outStrStream.precision(15);
80 int kb0 = getLayerLoc()->kb0;
82 for (
int b = 0; b < blockBatchCount; ++b) {
// NOTE(review): `index` is not referenced by any visible statement; the stream
// expression below calls getIndex(b) a second time instead.
83 int index = mBatchIndexer->getIndex(b);
84 outStrStream <<
"[" << getName() <<
"] time: " << time <<
", batch element: " << b + kb0
85 <<
", index: " << mBatchIndexer->getIndex(b) <<
"," 88 size_t len = outStrStream.str().length();
// Exactly `len` bytes of the formatted text are written, then the stream is flushed.
89 mTimestampStream->write(outStrStream.str().c_str(), len);
90 mTimestampStream->flush();
96 return Response::SUCCESS;
100 if (getMPIBlock()->getRank() == 0) {
101 int displayPeriodIndex = std::floor(timef / (mDisplayPeriod * dt));
102 if (displayPeriodIndex % mJitterChangeInterval == 0) {
103 for (
int b = 0; b < mRandomShiftX.size(); b++) {
104 mRandomShiftX[b] = -mMaxShiftX + (mRNG() % (2 * mMaxShiftX + 1));
105 mRandomShiftY[b] = -mMaxShiftY + (mRNG() % (2 * mMaxShiftY + 1));
107 mMirrorFlipX[b] = mXFlipToggle ? !mMirrorFlipX[b] : (mRNG() % 100) > 50;
110 mMirrorFlipY[b] = mYFlipToggle ? !mMirrorFlipY[b] : (mRNG() % 100) > 50;
116 int localNBatch = getLayerLoc()->nbatch;
118 for (
int b = 0; b < localNBatch; b++) {
119 if (getMPIBlock()->
getRank() == 0) {
120 int blockBatchElement = b + localNBatch * m;
121 int inputIndex = mBatchIndexer->getIndex(blockBatchElement);
123 int width = mInputData.at(b).getWidth();
124 int height = mInputData.at(b).getHeight();
125 int features = mInputData.at(b).getFeatures();
127 int const N = mInputRegion.at(b).getTotalElements();
128 for (
int k = 0; k < N; k++) {
129 mInputRegion.at(b).set(k, 1.0f);
136 cropToMPIBlock(mInputData.at(b));
137 cropToMPIBlock(mInputRegion.at(b));
152 int blockBatchCount = getLayerLoc()->nbatch * getMPIBlock()->
getBatchDimension();
153 for (
int b = 0; b < blockBatchCount; b++) {
154 mBatchIndexer->nextIndex(b);
161 if (procBatchIndex != 0 and procBatchIndex != mpiBatchIndex) {
165 PVHalo const *halo = &loc->halo;
166 int activityWidth, activityHeight, activityLeft, activityTop;
167 if (mUseInputBCflag) {
168 activityWidth = loc->nx + halo->lt + halo->rt;
169 activityHeight = loc->ny + halo->up + halo->dn;
174 activityWidth = loc->nx;
175 activityHeight = loc->ny;
176 activityLeft = halo->lt;
177 activityTop = halo->up;
182 if (getMPIBlock()->getRank() == 0) {
183 dataBuffer = mInputData.at(localBatchIndex);
184 regionBuffer = mInputRegion.at(localBatchIndex);
187 dataBuffer.resize(activityWidth, activityHeight, loc->nf);
188 regionBuffer.resize(activityWidth, activityHeight, loc->nf);
190 BufferUtils::scatter<float>(getMPIBlock(), dataBuffer, loc->nx, loc->ny, mpiBatchIndex, 0);
191 BufferUtils::scatter<float>(getMPIBlock(), regionBuffer, loc->nx, loc->ny, mpiBatchIndex, 0);
192 if (procBatchIndex != mpiBatchIndex) {
199 float *activityBuffer = &getActivity()[localBatchIndex * getNumExtended()];
200 for (
int n = 0; n < getNumExtended(); ++n) {
201 activityBuffer[n] = mPadValue;
204 for (
int y = 0; y < activityHeight; ++y) {
205 for (
int x = 0; x < activityWidth; ++x) {
206 for (
int f = 0; f < numFeatures; ++f) {
207 int activityIndex = kIndex(
211 loc->nx + halo->lt + halo->rt,
212 loc->ny + halo->up + halo->dn,
214 if (regionBuffer.at(x, y, f) > 0.0f) {
215 activityBuffer[activityIndex] = dataBuffer.at(x, y, f);
220 if (mNeedInputRegionsPointer) {
221 float *inputRegionBuffer =
222 &getInputRegionsAllBatchElements()[localBatchIndex * getNumExtended()];
223 for (
int y = 0; y < activityHeight; ++y) {
224 for (
int x = 0; x < activityWidth; ++x) {
225 for (
int f = 0; f < numFeatures; ++f) {
226 int activityIndex = kIndex(
230 loc->nx + halo->lt + halo->rt,
231 loc->ny + halo->up + halo->dn,
233 if (regionBuffer.at(x, y, f) > 0.0f) {
234 inputRegionBuffer[activityIndex] = regionBuffer.at(x, y, f);
245 pvAssert(getMPIBlock()->getRank() == 0);
247 int const xMargins = mUseInputBCflag ? loc->halo.lt + loc->halo.rt : 0;
248 int const yMargins = mUseInputBCflag ? loc->halo.dn + loc->halo.up : 0;
249 const int targetWidth = loc->nxGlobal + xMargins;
250 const int targetHeight = loc->nyGlobal + yMargins;
253 buffer.getFeatures() != loc->nf,
254 "ERROR: Input for layer %s has %d features, but layer has %d.\n",
256 buffer.getFeatures(),
259 if (mAutoResizeFlag) {
260 BufferUtils::rescale(
261 buffer, targetWidth, targetHeight, mRescaleMethod, mInterpolationMethod, mAnchor);
263 -mOffsetX + mRandomShiftX[blockBatchElement],
264 -mOffsetY + mRandomShiftY[blockBatchElement]);
267 buffer.grow(targetWidth, targetHeight, mAnchor);
269 -mOffsetX + mRandomShiftX[blockBatchElement],
270 -mOffsetY + mRandomShiftY[blockBatchElement]);
271 buffer.crop(targetWidth, targetHeight, mAnchor);
274 if (mMirrorFlipX[blockBatchElement] || mMirrorFlipY[blockBatchElement]) {
275 buffer.flip(mMirrorFlipX[blockBatchElement], mMirrorFlipY[blockBatchElement]);
281 Buffer<float> const ®ionBuffer = mInputRegion.at(batchElement);
282 int const totalElements = dataBuffer.getTotalElements();
283 pvAssert(totalElements == regionBuffer.getTotalElements());
284 int validRegionCount = 0;
285 for (
int k = 0; k < totalElements; k++) {
286 if (regionBuffer.at(k) > 0.0f) {
290 if (validRegionCount == 0) {
293 if (mNormalizeLuminanceFlag) {
294 if (mNormalizeStdDev) {
295 float imageSum = 0.0f;
296 float imageSumSq = 0.0f;
297 for (
int k = 0; k < totalElements; k++) {
298 if (regionBuffer.at(k) > 0.0f) {
299 float const v = dataBuffer.at(k);
306 float imageAverage = imageSum / validRegionCount;
307 for (
int k = 0; k < totalElements; k++) {
308 if (regionBuffer.at(k) > 0.0f) {
309 float const v = dataBuffer.at(k);
310 dataBuffer.set(k, v - imageAverage);
315 float imageVariance = imageSumSq / validRegionCount - imageAverage * imageAverage;
316 pvAssert(imageVariance >= 0);
317 if (imageVariance > 0) {
318 float imageStdDev = std::sqrt(imageVariance);
319 for (
int k = 0; k < totalElements; k++) {
320 if (regionBuffer.at(k) > 0.0f) {
321 float const v = dataBuffer.at(k) / imageStdDev;
322 dataBuffer.set(k, v);
330 for (
int k = 0; k < totalElements; k++) {
331 if (regionBuffer.at(k) > 0.0f) {
332 dataBuffer.set(k, 0.0f);
338 float imageMax = -std::numeric_limits<float>::max();
339 float imageMin = std::numeric_limits<float>::max();
340 for (
int k = 0; k < totalElements; k++) {
341 if (regionBuffer.at(k) > 0.0f) {
342 float const v = dataBuffer.at(k);
343 imageMax = v > imageMax ? v : imageMax;
344 imageMin = v < imageMin ? v : imageMin;
347 if (imageMax > imageMin) {
348 float imageStretch = 1.0f / (imageMax - imageMin);
349 for (
int k = 0; k < totalElements; k++) {
350 if (regionBuffer.at(k) > 0.0f) {
351 float const v = (dataBuffer.at(k) - imageMin) * imageStretch;
352 dataBuffer.set(k, v);
357 for (
int k = 0; k < totalElements; k++) {
358 if (regionBuffer.at(k) > 0.0f) {
359 dataBuffer.set(k, 0.0f);
366 if (mNormalizeLuminanceFlag) {
367 for (
int k = 0; k < totalElements; k++) {
368 if (regionBuffer.at(k) > 0.0f) {
369 float const v = -dataBuffer.at(k);
370 dataBuffer.set(k, v);
375 float imageMax = -std::numeric_limits<float>::max();
376 float imageMin = std::numeric_limits<float>::max();
377 for (
int k = 0; k < totalElements; k++) {
378 if (regionBuffer.at(k) > 0.0f) {
379 float const v = dataBuffer.at(k);
380 imageMax = v > imageMax ? v : imageMax;
381 imageMin = v < imageMin ? v : imageMin;
384 for (
int k = 0; k < totalElements; k++) {
385 if (regionBuffer.at(k) > 0.0f) {
386 float const v = imageMax + imageMin - dataBuffer.at(k);
387 dataBuffer.set(k, v);
397 int const startY = getMPIBlock()->
getStartRow() * loc->ny;
398 buffer.translate(-startX, -startY);
399 int const xMargins = mUseInputBCflag ? loc->halo.lt + loc->halo.rt : 0;
400 int const yMargins = mUseInputBCflag ? loc->halo.dn + loc->halo.up : 0;
401 int const blockWidth = getMPIBlock()->
getNumColumns() * loc->nx + xMargins;
402 int const blockHeight = getMPIBlock()->
getNumRows() * loc->ny + yMargins;
// Reports that this layer cannot be a post-synaptic layer.
//   channelNeeded     - not used by any statement visible here
//   numChannelsResult - out-parameter; set to 0
// The error message is printed only by the process whose communicator rank is 0.
// NOTE(review): the return statement is missing from this listing; presumably
// a failure code is returned - confirm against the full source.
408 int InputLayer::requireChannel(
int channelNeeded,
int *numChannelsResult) {
409 if (parent->getCommunicator()->commRank() == 0) {
410 ErrorLog().printf(
"%s cannot be a post-synaptic layer.\n", getDescription_c());
412 *numChannelsResult = 0;
416 void InputLayer::allocateV() { clayer->V =
nullptr; }
418 void InputLayer::initializeV() { pvAssert(getV() ==
nullptr); }
420 void InputLayer::initializeActivity() {
421 retrieveInput(parent->simulationTime(), parent->getDeltaTime());
426 ioParam_displayPeriod(ioFlag);
427 ioParam_inputPath(ioFlag);
428 ioParam_offsetAnchor(ioFlag);
429 ioParam_offsets(ioFlag);
430 ioParam_maxShifts(ioFlag);
431 ioParam_flipsEnabled(ioFlag);
432 ioParam_flipsToggle(ioFlag);
433 ioParam_jitterChangeInterval(ioFlag);
434 ioParam_autoResizeFlag(ioFlag);
435 ioParam_aspectRatioAdjustment(ioFlag);
436 ioParam_interpolationMethod(ioFlag);
437 ioParam_inverseFlag(ioFlag);
438 ioParam_normalizeLuminanceFlag(ioFlag);
439 ioParam_normalizeStdDev(ioFlag);
440 ioParam_useInputBCflag(ioFlag);
441 ioParam_padValue(ioFlag);
442 ioParam_batchMethod(ioFlag);
443 ioParam_randomSeed(ioFlag);
444 ioParam_start_frame_index(ioFlag);
445 ioParam_skip_frame_index(ioFlag);
446 ioParam_resetToStartOnLoop(ioFlag);
447 ioParam_writeFrameToTimestamp(ioFlag);
// Registers this layer's checkpoint data.
//   checkpointer - passed to the base class and to the BatchIndexer; also
//                  queried for its MPI block and checkpoint read directory
// After delegating to HyPerLayer::registerData(), the process with rank 0 in
// the checkpointer's MPI block seeds mRNG with mRandomSeed, sizes the
// shift/flip vectors and the input data/region vectors, builds the
// BatchIndexer, applies mResetToStartOnLoop to it and registers it with the
// checkpointer. When mWriteFrameToTimestamp is set, a timestamp file named
// "timestamps/<name>.txt" is prepared under the checkpoint label
// "<getName()>_TimestampState". The visible path returns Response::SUCCESS.
// NOTE(review): several original lines are missing from this listing: the
// handling of `status`, the definition of nBatch, and the call that receives
// the timestamp-file arguments.
451 Response::Status InputLayer::registerData(
Checkpointer *checkpointer) {
452 auto status = HyPerLayer::registerData(checkpointer);
456 if (checkpointer->getMPIBlock()->
getRank() == 0) {
457 mRNG.seed(mRandomSeed);
458 int numBatch = getLayerLoc()->nbatch;
// NOTE(review): the four vectors below are sized with nBatch while mInputData
// and mInputRegion use numBatch; nBatch is defined on a line not shown here.
460 mRandomShiftX.resize(nBatch);
461 mRandomShiftY.resize(nBatch);
462 mMirrorFlipX.resize(nBatch);
463 mMirrorFlipY.resize(nBatch);
464 mInputData.resize(numBatch);
465 mInputRegion.resize(numBatch);
466 initializeBatchIndexer();
467 mBatchIndexer->setWrapToStartIndex(mResetToStartOnLoop);
468 mBatchIndexer->registerData(checkpointer);
470 if (mWriteFrameToTimestamp) {
471 std::string timestampFilename = std::string(
"timestamps/");
472 timestampFilename += name + std::string(
".txt");
473 std::string cpFileStreamLabel(getName());
474 cpFileStreamLabel.append(
"_TimestampState");
// True when the checkpointer reports an empty checkpoint-read directory.
475 bool needToCreateFile = checkpointer->getCheckpointReadDirectory().empty();
477 timestampFilename, needToCreateFile, checkpointer, cpFileStreamLabel);
480 return Response::SUCCESS;
// Restores layer state from a checkpoint.
//   checkpointer - forwarded to HyPerLayer::readStateFromCheckpoint()
// `status` starts as Response::NO_ACTION and is replaced by the base-class
// result only when initializeFromCheckpointFlag is set.
// NOTE(review): the lines after the base-class call (including the context of
// the rank-0 assertion and the return statement) are missing from this
// listing, so the remaining behavior is not described here.
483 Response::Status InputLayer::readStateFromCheckpoint(
Checkpointer *checkpointer) {
484 auto status = Response::NO_ACTION;
485 if (initializeFromCheckpointFlag) {
486 status = HyPerLayer::readStateFromCheckpoint(checkpointer);
491 pvAssert(getMPIBlock()->getRank() == 0);
// Validates a two-character anchor string.
//   offsetAnchor - C string to check; may be NULL
// Three conditions are tested: the string is NULL or not exactly two
// characters long; the second character is not one of 'l', 'c', 'r'; the
// first character is not one of 't', 'c', 'b'.
// `status` starts as PV_SUCCESS; the caller in ioParam_offsetAnchor compares
// the result with PV_FAILURE.
// NOTE(review): the bodies of the three checks and the return statement are
// missing from this listing; presumably each failed check yields PV_FAILURE.
497 int InputLayer::checkValidAnchorString(
const char *offsetAnchor) {
498 int status = PV_SUCCESS;
499 if (offsetAnchor == NULL || strlen(offsetAnchor) != (
size_t)2) {
// Index 1 is the horizontal part of the anchor.
503 char xOffsetAnchor = offsetAnchor[1];
504 if (xOffsetAnchor !=
'l' && xOffsetAnchor !=
'c' && xOffsetAnchor !=
'r') {
// Index 0 is the vertical part of the anchor.
507 char yOffsetAnchor = offsetAnchor[0];
508 if (yOffsetAnchor !=
't' && yOffsetAnchor !=
'c' && yOffsetAnchor !=
'b') {
// Reads or writes the required string parameter "inputPath".
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// On write, a strdup() copy of mInputPath is handed to
// ioParamStringRequired(); on read, the C string it leaves in tempString is
// copied into mInputPath.
// NOTE(review): tempString owns heap memory on the write path (strdup) and
// presumably on the read path too; the matching free() is not visible in this
// listing - confirm it exists in the full source.
515 void InputLayer::ioParam_inputPath(
enum ParamsIOFlag ioFlag) {
516 char *tempString =
nullptr;
517 if (ioFlag == PARAMS_IO_WRITE) {
518 tempString = strdup(mInputPath.c_str());
520 parent->parameters()->ioParamStringRequired(ioFlag, name,
"inputPath", &tempString);
521 if (ioFlag == PARAMS_IO_READ) {
// NOTE(review): assumes ioParamStringRequired() always leaves a non-null
// string here on read; constructing std::string from nullptr is undefined.
522 mInputPath = std::string(tempString);
527 void InputLayer::ioParam_useInputBCflag(
enum ParamsIOFlag ioFlag) {
528 parent->parameters()->ioParamValue(
529 ioFlag, name,
"useInputBCflag", &mUseInputBCflag, mUseInputBCflag);
// Transfers the "offsetX" and "offsetY" parameters to or from mOffsetX and
// mOffsetY, in the direction given by ioFlag. Each member's current value is
// also passed as the final argument (presumably the default).
// NOTE(review): the function returns int, but its return statement is missing
// from this listing - confirm the returned value in the full source.
532 int InputLayer::ioParam_offsets(
enum ParamsIOFlag ioFlag) {
533 parent->parameters()->ioParamValue(ioFlag, name,
"offsetX", &mOffsetX, mOffsetX);
534 parent->parameters()->ioParamValue(ioFlag, name,
"offsetY", &mOffsetY, mOffsetY);
// Transfers the "maxShiftX" and "maxShiftY" parameters to or from mMaxShiftX
// and mMaxShiftY, in the direction given by ioFlag. Each member's current
// value is also passed as the final argument (presumably the default).
// NOTE(review): the function returns int, but its return statement is missing
// from this listing - confirm the returned value in the full source.
538 int InputLayer::ioParam_maxShifts(
enum ParamsIOFlag ioFlag) {
539 parent->parameters()->ioParamValue(ioFlag, name,
"maxShiftX", &mMaxShiftX, mMaxShiftX);
540 parent->parameters()->ioParamValue(ioFlag, name,
"maxShiftY", &mMaxShiftY, mMaxShiftY);
// Transfers the "xFlipEnabled" and "yFlipEnabled" parameters to or from
// mXFlipEnabled and mYFlipEnabled, in the direction given by ioFlag. Each
// member's current value is also passed as the final argument (presumably the
// default).
// NOTE(review): the function returns int, but its return statement is missing
// from this listing - confirm the returned value in the full source.
544 int InputLayer::ioParam_flipsEnabled(
enum ParamsIOFlag ioFlag) {
545 parent->parameters()->ioParamValue(ioFlag, name,
"xFlipEnabled", &mXFlipEnabled, mXFlipEnabled);
546 parent->parameters()->ioParamValue(ioFlag, name,
"yFlipEnabled", &mYFlipEnabled, mYFlipEnabled);
// Transfers the "xFlipToggle" and "yFlipToggle" parameters to or from
// mXFlipToggle and mYFlipToggle, in the direction given by ioFlag. Each
// member's current value is also passed as the final argument (presumably the
// default).
// NOTE(review): the function returns int, but its return statement is missing
// from this listing - confirm the returned value in the full source.
550 int InputLayer::ioParam_flipsToggle(
enum ParamsIOFlag ioFlag) {
551 parent->parameters()->ioParamValue(ioFlag, name,
"xFlipToggle", &mXFlipToggle, mXFlipToggle);
552 parent->parameters()->ioParamValue(ioFlag, name,
"yFlipToggle", &mYFlipToggle, mYFlipToggle);
// Transfers the "jitterChangeInterval" parameter to or from
// mJitterChangeInterval, in the direction given by ioFlag. The member's
// current value is also passed as the final argument (presumably the default).
// NOTE(review): the function returns int, but its return statement is missing
// from this listing - confirm the returned value in the full source.
556 int InputLayer::ioParam_jitterChangeInterval(
enum ParamsIOFlag ioFlag) {
557 parent->parameters()->ioParamValue(
558 ioFlag, name,
"jitterChangeInterval", &mJitterChangeInterval, mJitterChangeInterval);
// Reads or writes the "offsetAnchor" string parameter (fallback value "tl").
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Read path: the string is fetched, validated with checkValidAnchorString()
// (an invalid value triggers Fatal()), and compared against the nine
// two-letter combinations "tl", "tc", "tr", "cl", "cc", "cr", "bl", "bc",
// "br". A fallback branch prints an explanatory message on communicator
// rank 0 and then reaches an MPI_Barrier.
// Write path: a 3-byte buffer is allocated for the two-letter code and passed
// to ioParamString().
// NOTE(review): the body of every comparison branch (presumably assigning
// mAnchor), the code that fills the buffer on the write path, and any free()
// of offsetAnchor are missing from this listing.
562 void InputLayer::ioParam_offsetAnchor(
enum ParamsIOFlag ioFlag) {
563 if (ioFlag == PARAMS_IO_READ) {
564 char *offsetAnchor =
nullptr;
565 parent->parameters()->ioParamString(ioFlag, name,
"offsetAnchor", &offsetAnchor,
"tl");
566 if (checkValidAnchorString(offsetAnchor) == PV_FAILURE) {
567 Fatal() <<
"Invalid value for offsetAnchor\n";
569 if (strcmp(offsetAnchor,
"tl") == 0) {
572 else if (strcmp(offsetAnchor,
"tc") == 0) {
575 else if (strcmp(offsetAnchor,
"tr") == 0) {
578 else if (strcmp(offsetAnchor,
"cl") == 0) {
581 else if (strcmp(offsetAnchor,
"cc") == 0) {
584 else if (strcmp(offsetAnchor,
"cr") == 0) {
587 else if (strcmp(offsetAnchor,
"bl") == 0) {
590 else if (strcmp(offsetAnchor,
"bc") == 0) {
593 else if (strcmp(offsetAnchor,
"br") == 0) {
// Only communicator rank 0 prints the message.
597 if (parent->getCommunicator()->commRank() == 0) {
599 "%s: offsetAnchor must be a two-letter string. The first character must be " 600 "\"t\", \"c\", or \"b\" (for top, center or bottom); and the second character " 601 "must be \"l\", \"c\", or \"r\" (for left, center or right).\n",
604 MPI_Barrier(parent->getCommunicator()->communicator());
// NOTE(review): the calloc() result is dereferenced on the next statement with
// no visible null check; calloc() already zero-fills, so the explicit
// terminator is redundant but harmless.
611 char *offsetAnchor = (
char *)calloc(3,
sizeof(
char));
612 offsetAnchor[2] =
'\0';
635 parent->parameters()->ioParamString(ioFlag, name,
"offsetAnchor", &offsetAnchor,
"tl");
640 void InputLayer::ioParam_autoResizeFlag(
enum ParamsIOFlag ioFlag) {
641 parent->parameters()->ioParamValue(
642 ioFlag, name,
"autoResizeFlag", &mAutoResizeFlag, mAutoResizeFlag);
// Reads or writes the "aspectRatioAdjustment" string parameter (fallback value
// "crop"); only active when mAutoResizeFlag is set.
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Asserts that "autoResizeFlag" is not still waiting to be read.
// Write path: mRescaleMethod is turned into a strdup()'d "crop" or "pad".
// Read path: the string is walked character by character, then "crop" selects
// BufferUtils::CROP and "pad" selects BufferUtils::PAD. A fallback branch
// prints a message on communicator rank 0 and then reaches an MPI_Barrier.
// The string is released with free() at the end.
// NOTE(review): the body of the per-character loop (presumably lower-casing),
// the switch's closing lines and the tail of the error branch are missing from
// this listing.
645 void InputLayer::ioParam_aspectRatioAdjustment(
enum ParamsIOFlag ioFlag) {
646 assert(!parent->parameters()->presentAndNotBeenRead(name,
"autoResizeFlag"));
647 if (mAutoResizeFlag) {
648 char *aspectRatioAdjustment =
nullptr;
649 if (ioFlag == PARAMS_IO_WRITE) {
650 switch (mRescaleMethod) {
651 case BufferUtils::CROP: aspectRatioAdjustment = strdup(
"crop");
break;
652 case BufferUtils::PAD: aspectRatioAdjustment = strdup(
"pad");
break;
655 parent->parameters()->ioParamString(
656 ioFlag, name,
"aspectRatioAdjustment", &aspectRatioAdjustment,
"crop");
657 if (ioFlag == PARAMS_IO_READ) {
658 assert(aspectRatioAdjustment);
659 for (
char *c = aspectRatioAdjustment; *c; c++) {
663 if (strcmp(aspectRatioAdjustment,
"crop") == 0) {
664 mRescaleMethod = BufferUtils::CROP;
666 else if (strcmp(aspectRatioAdjustment,
"pad") == 0) {
667 mRescaleMethod = BufferUtils::PAD;
// Only communicator rank 0 prints the message.
670 if (parent->getCommunicator()->commRank() == 0) {
672 "%s: aspectRatioAdjustment must be either \"crop\" or \"pad\".\n",
675 MPI_Barrier(parent->getCommunicator()->communicator());
678 free(aspectRatioAdjustment);
// Reads or writes the "interpolationMethod" string parameter; only active when
// mAutoResizeFlag is set.
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Asserts that "autoResizeFlag" is not still waiting to be read.
// Read path: the string is fetched and walked character by character, then
// matched against "bicubic" (-> BufferUtils::BICUBIC) and "nearestneighbor"
// (-> BufferUtils::NEAREST). A fallback branch prints a message on
// communicator rank 0 and then reaches an MPI_Barrier.
// Write path: mInterpolationMethod is turned into a strdup()'d "bicubic" or
// "nearestNeighbor" and passed to ioParamString().
// The string is released with free() at the end.
// NOTE(review): the remaining ioParamString() arguments, the body of the
// per-character loop (presumably lower-casing, since the comparison uses
// "nearestneighbor" while the message says "nearestNeighbor") and the tail of
// the error branch are missing from this listing.
682 void InputLayer::ioParam_interpolationMethod(
enum ParamsIOFlag ioFlag) {
683 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"autoResizeFlag"));
684 if (mAutoResizeFlag) {
685 char *interpolationMethodString =
nullptr;
686 if (ioFlag == PARAMS_IO_READ) {
687 parent->parameters()->ioParamString(
690 "interpolationMethod",
691 &interpolationMethodString,
694 assert(interpolationMethodString);
695 for (
char *c = interpolationMethodString; *c; c++) {
// NOTE(review): strncmp() limited to the keyword's length is a prefix match,
// so any value that merely starts with "bicubic" or "nearestneighbor" is
// accepted - confirm this is intended.
698 if (!strncmp(interpolationMethodString,
"bicubic", strlen(
"bicubic"))) {
699 mInterpolationMethod = BufferUtils::BICUBIC;
702 !strncmp(interpolationMethodString,
"nearestneighbor", strlen(
"nearestneighbor"))) {
703 mInterpolationMethod = BufferUtils::NEAREST;
// Only communicator rank 0 prints the message.
706 if (parent->getCommunicator()->commRank() == 0) {
708 "%s: interpolationMethod must be either \"bicubic\" or \"nearestNeighbor\".\n",
711 MPI_Barrier(parent->getCommunicator()->communicator());
716 assert(ioFlag == PARAMS_IO_WRITE);
717 switch (mInterpolationMethod) {
718 case BufferUtils::BICUBIC: interpolationMethodString = strdup(
"bicubic");
break;
719 case BufferUtils::NEAREST: interpolationMethodString = strdup(
"nearestNeighbor");
break;
721 parent->parameters()->ioParamString(
724 "interpolationMethod",
725 &interpolationMethodString,
729 free(interpolationMethodString);
733 void InputLayer::ioParam_inverseFlag(
enum ParamsIOFlag ioFlag) {
734 parent->parameters()->ioParamValue(ioFlag, name,
"inverseFlag", &mInverseFlag, mInverseFlag);
737 void InputLayer::ioParam_normalizeLuminanceFlag(
enum ParamsIOFlag ioFlag) {
738 parent->parameters()->ioParamValue(
739 ioFlag, name,
"normalizeLuminanceFlag", &mNormalizeLuminanceFlag, mNormalizeLuminanceFlag);
742 void InputLayer::ioParam_normalizeStdDev(
enum ParamsIOFlag ioFlag) {
743 assert(!parent->parameters()->presentAndNotBeenRead(name,
"normalizeLuminanceFlag"));
744 if (mNormalizeLuminanceFlag) {
745 parent->parameters()->ioParamValue(
746 ioFlag, name,
"normalizeStdDev", &mNormalizeStdDev, mNormalizeStdDev);
749 void InputLayer::ioParam_padValue(
enum ParamsIOFlag ioFlag) {
750 parent->parameters()->ioParamValue(ioFlag, name,
"padValue", &mPadValue, mPadValue);
754 assert(mInitVObject == NULL);
759 if (ioFlag == PARAMS_IO_READ) {
760 triggerLayerName = NULL;
763 name,
"triggerLayerName", NULL );
767 void InputLayer::ioParam_displayPeriod(
enum ParamsIOFlag ioFlag) {
768 parent->parameters()->ioParamValue(
769 ioFlag, name,
"displayPeriod", &mDisplayPeriod, mDisplayPeriod);
// Reads or writes the "batchMethod" string parameter (fallback value "byFile").
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Write path: mBatchMethod is first turned into a strdup()'d keyword.
// After ioParamString() the string is mapped back onto mBatchMethod:
//   "byImage" or "byFile" -> BatchIndexer::BYFILE
//   "byMovie" or "byList" -> BatchIndexer::BYLIST
//   "bySpecified"         -> BatchIndexer::BYSPECIFIED
//   "random"              -> BatchIndexer::RANDOM
// Any other value reaches the Fatal() message.
// NOTE(review): the closing lines of the function are missing from this
// listing, so the free() of batchMethod cannot be confirmed here.
772 void InputLayer::ioParam_batchMethod(
enum ParamsIOFlag ioFlag) {
773 char *batchMethod =
nullptr;
774 if (ioFlag == PARAMS_IO_WRITE) {
775 switch (mBatchMethod) {
776 case BatchIndexer::BYFILE: batchMethod = strdup(
"byFile");
break;
777 case BatchIndexer::BYLIST: batchMethod = strdup(
"byList");
break;
778 case BatchIndexer::BYSPECIFIED: batchMethod = strdup(
"bySpecified");
break;
779 case BatchIndexer::RANDOM: batchMethod = strdup(
"random");
break;
782 parent->parameters()->ioParamString(ioFlag, name,
"batchMethod", &batchMethod,
"byFile");
// "byImage" and "byMovie" are accepted as alternative spellings of "byFile"
// and "byList".
783 if (strcmp(batchMethod,
"byImage") == 0 || strcmp(batchMethod,
"byFile") == 0) {
784 mBatchMethod = BatchIndexer::BYFILE;
786 else if (strcmp(batchMethod,
"byMovie") == 0 || strcmp(batchMethod,
"byList") == 0) {
787 mBatchMethod = BatchIndexer::BYLIST;
789 else if (strcmp(batchMethod,
"bySpecified") == 0) {
790 mBatchMethod = BatchIndexer::BYSPECIFIED;
792 else if (strcmp(batchMethod,
"random") == 0) {
793 mBatchMethod = BatchIndexer::RANDOM;
796 Fatal() << getName() <<
": Input layer " << name
797 <<
" batchMethod not recognized. Options " 798 "are \"byFile\", \"byList\", bySpecified, and random.\n";
803 void InputLayer::ioParam_randomSeed(
enum ParamsIOFlag ioFlag) {
804 parent->parameters()->ioParamValue(ioFlag, name,
"randomSeed", &mRandomSeed, mRandomSeed);
// Reads or writes the "start_frame_index" integer-array parameter.
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Write path: mStartFrameIndex is copied into a calloc()'d C array that is
// handed to ioParamArray(). After the call, a length other than 0 or
// parent->getNBatchGlobal() is rejected with the message below,
// mStartFrameIndex is rebuilt with getNBatchGlobal() entries, the first
// `length` entries are filled from the array, and the array is freed.
// NOTE(review): the declaration of `length`, the call that receives the
// length-check arguments and its last argument are missing from this listing.
807 void InputLayer::ioParam_start_frame_index(
enum ParamsIOFlag ioFlag) {
// NOTE(review): this pointer has no initializer, unlike its counterpart in
// ioParam_skip_frame_index (= nullptr). On the read path it reaches
// ioParamArray() and later free() with an indeterminate value unless
// ioParamArray() always assigns it - initialize it to nullptr.
808 int *paramsStartFrameIndex;
810 if (ioFlag == PARAMS_IO_WRITE) {
811 length = mStartFrameIndex.size();
812 paramsStartFrameIndex =
static_cast<int *
>(calloc(length,
sizeof(
int)));
813 for (
int i = 0; i < length; ++i) {
814 paramsStartFrameIndex[i] = mStartFrameIndex.at(i);
// NOTE(review): "¶msStartFrameIndex" below looks like a mis-decoded
// "&paramsStartFrameIndex" ("&para" rendered as the pilcrow sign); restore the
// address-of expression from the original source.
817 this->parent->parameters()->ioParamArray(
818 ioFlag, this->getName(),
"start_frame_index", ¶msStartFrameIndex, &length);
820 length != 0 && length != parent->getNBatchGlobal(),
821 "%s: start_frame_index requires either 0 or nbatch values.\n",
// The vector is rebuilt at full size, so entries not supplied by the
// parameter keep the value-initialized 0.
823 mStartFrameIndex.clear();
824 mStartFrameIndex.resize(parent->getNBatchGlobal());
826 for (
int i = 0; i < length; ++i) {
827 mStartFrameIndex.at(i) = paramsStartFrameIndex[i];
830 free(paramsStartFrameIndex);
// Reads or writes the "skip_frame_index" integer-array parameter.
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Asserts that "batchMethod" is not still waiting to be read. When
// mBatchMethod is not BatchIndexer::BYSPECIFIED, mSkipFrameIndex is resized to
// parent->getNBatchGlobal() entries, with new entries set to 0.
// The array path copies mSkipFrameIndex into a calloc()'d C array on write,
// calls ioParamArray(), rejects a length different from getNBatchGlobal() with
// the message below, rebuilds mSkipFrameIndex from the array and frees it.
// NOTE(review): the lines right after the non-BYSPECIFIED resize (presumably an
// early return), the declaration of `length`, and the call that receives the
// length-check arguments are missing from this listing.
833 void InputLayer::ioParam_skip_frame_index(
enum ParamsIOFlag ioFlag) {
834 pvAssert(!parent->parameters()->presentAndNotBeenRead(name,
"batchMethod"));
835 if (mBatchMethod != BatchIndexer::BYSPECIFIED) {
836 mSkipFrameIndex.resize(parent->getNBatchGlobal(), 0);
842 int *paramsSkipFrameIndex =
nullptr;
844 if (ioFlag == PARAMS_IO_WRITE) {
845 length = mSkipFrameIndex.size();
846 paramsSkipFrameIndex =
static_cast<int *
>(calloc(length,
sizeof(
int)));
847 for (
int i = 0; i < length; ++i) {
848 paramsSkipFrameIndex[i] = mSkipFrameIndex.at(i);
// NOTE(review): "¶msSkipFrameIndex" below looks like a mis-decoded
// "&paramsSkipFrameIndex" ("&para" rendered as the pilcrow sign); restore the
// address-of expression from the original source.
851 this->parent->parameters()->ioParamArray(
852 ioFlag, this->getName(),
"skip_frame_index", ¶msSkipFrameIndex, &length);
854 length != parent->getNBatchGlobal(),
855 "%s: skip_frame_index requires nbatch values.\n",
857 mSkipFrameIndex.clear();
858 mSkipFrameIndex.resize(length);
859 for (
int i = 0; i < length; ++i) {
860 mSkipFrameIndex.at(i) = paramsSkipFrameIndex[i];
862 free(paramsSkipFrameIndex);
// Handles the "resetToStartOnLoop" parameter.
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Asserts that "batchMethod" is not still waiting to be read. When
// mBatchMethod is BatchIndexer::BYSPECIFIED the parameter is transferred to or
// from mResetToStartOnLoop.
// NOTE(review): the lines between the ioParamValue() call and the assignment
// below are missing from this listing; presumably they form an `else` branch,
// so the flag is forced to false only for the other batch methods - confirm.
865 void InputLayer::ioParam_resetToStartOnLoop(
enum ParamsIOFlag ioFlag) {
866 assert(!parent->parameters()->presentAndNotBeenRead(name,
"batchMethod"));
867 if (mBatchMethod == BatchIndexer::BYSPECIFIED) {
868 parent->parameters()->ioParamValue(
869 ioFlag, name,
"resetToStartOnLoop", &mResetToStartOnLoop, mResetToStartOnLoop);
872 mResetToStartOnLoop =
false;
// Handles the "writeFrameToTimestamp" parameter.
//   ioFlag - PARAMS_IO_READ or PARAMS_IO_WRITE
// Asserts that "displayPeriod" is not still waiting to be read. When
// mDisplayPeriod is positive the parameter is transferred to or from
// mWriteFrameToTimestamp.
// NOTE(review): the lines between the ioParamValue() call and the assignment
// below are missing from this listing; presumably they form an `else` branch,
// so the flag is forced to false only when mDisplayPeriod <= 0 - confirm.
876 void InputLayer::ioParam_writeFrameToTimestamp(
enum ParamsIOFlag ioFlag) {
877 assert(!parent->parameters()->presentAndNotBeenRead(name,
"displayPeriod"));
878 if (mDisplayPeriod > 0) {
879 parent->parameters()->ioParamValue(
880 ioFlag, name,
"writeFrameToTimestamp", &mWriteFrameToTimestamp, mWriteFrameToTimestamp);
883 mWriteFrameToTimestamp =
false;
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int getNumColumns() const
static bool completed(Status &a)
int getStartColumn() const
int getGlobalBatchDimension() const
int initialize(const char *name, HyPerCol *hc)
int getBatchDimension() const
int getBatchIndex() const
int getStartBatch() const
void handleUnnecessaryStringParameter(const char *group_name, const char *param_name)