8 #include "FilenameParsingGroundTruthLayer.hpp" 19 FilenameParsingGroundTruthLayer::FilenameParsingGroundTruthLayer(
const char *name, HyPerCol *hc) {
23 FilenameParsingGroundTruthLayer::~FilenameParsingGroundTruthLayer() {
24 free(mInputLayerName);
25 free(mClassListFileName);
38 parent->parameters()->ioParamValue(
39 ioFlag, name,
"gtClassTrueValue", &mGtClassTrueValue, 1.0f,
false);
43 parent->parameters()->ioParamValue(
44 ioFlag, name,
"gtClassFalseValue", &mGtClassFalseValue, -1.0f,
false);
48 parent->parameters()->ioParamStringRequired(ioFlag, name,
"inputLayerName", &mInputLayerName);
52 parent->parameters()->ioParamString(
53 ioFlag, name,
"classList", &mClassListFileName, mClassListFileName,
false);
54 if (mClassListFileName ==
nullptr) {
55 WarnLog() << getName()
56 <<
": No classList specified. Looking for classes.txt in output directory.\n";
60 Response::Status FilenameParsingGroundTruthLayer::registerData(
Checkpointer *checkpointer) {
65 auto status = HyPerLayer::registerData(checkpointer);
70 std::ifstream inputFile;
71 std::string outPath(
"");
73 if (mClassListFileName !=
nullptr) {
74 outPath += std::string(mClassListFileName);
77 outPath += checkpointer->getOutputPath();
78 outPath +=
"/classes.txt";
81 inputFile.open(outPath.c_str(), std::ifstream::in);
82 FatalIf(!inputFile.is_open(),
"%s: Unable to open file %s\n", getName(), outPath.c_str());
86 while (getline(inputFile, line)) {
87 mClasses.push_back(line);
91 std::size_t numFeatures = (std::size_t)getLayerLoc()->nf;
93 numFeatures != mClasses.size(),
94 "%s has %d features but classList \"%s\" has %zu categories.\n",
99 return Response::SUCCESS;
102 Response::Status FilenameParsingGroundTruthLayer::communicateInitInfo(
103 std::shared_ptr<CommunicateInitInfoMessage const> message) {
104 mInputLayer = message->lookup<
InputLayer>(std::string(mInputLayerName));
106 mInputLayer ==
nullptr && parent->columnId() == 0,
107 "%s: inputLayerName \"%s\" is not a layer in the HyPerCol.\n",
111 mInputLayer->getPhase() <= getPhase(),
112 "%s: The phase of layer %s (%d) must be greater than the phase of the " 113 "FilenameParsingGroundTruthLayer (%d)\n",
116 mInputLayer->getPhase(),
118 return Response::SUCCESS;
122 return mInputLayer->
needUpdate(parent->simulationTime(), parent->getDeltaTime());
125 Response::Status FilenameParsingGroundTruthLayer::updateState(
double time,
double dt) {
126 update_timer->start();
127 float *A = getCLayer()->activity->data;
129 int numNeurons = getNumNeurons();
130 int const localBatchWidth = getLayerLoc()->nbatch;
131 int const blockBatchWidth = getMPIBlock()->
getBatchDimension() * localBatchWidth;
132 for (
int b = 0; b < blockBatchWidth; b++) {
133 int const mpiBlockBatchIndex = b / localBatchWidth;
134 int const localBatchIndex = b % localBatchWidth;
136 std::vector<float> fileMatches(mClasses.size());
137 if (getMPIBlock()->
getRank() == 0) {
138 std::string currentFilename =
140 for (
auto ci = (std::size_t)0; ci < mClasses.size(); ci++) {
141 std::size_t match = currentFilename.find(mClasses.at(ci));
142 fileMatches[ci] = match != std::string::npos ? mGtClassTrueValue : mGtClassFalseValue;
148 MPI_Bcast(fileMatches.data(), fileMatches.size(), MPI_FLOAT, 0, getMPIBlock()->
getComm());
154 float *ABatch = A + localBatchIndex * getNumExtended();
155 for (
int i = 0; i < numNeurons; i++) {
156 int nExt = kIndexExtended(
165 int fi = featureIndex(
167 loc->nx + loc->halo.rt + loc->halo.lt,
168 loc->ny + loc->halo.dn + loc->halo.up,
170 ABatch[nExt] = fileMatches[fi];
173 update_timer->stop();
174 return Response::SUCCESS;
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
virtual void ioParam_gtClassFalseValue(enum ParamsIOFlag ioFlag)
gtClassFalseValue: defines the value to be set for the neurons whose class name (from classes.txt) does not appear in the current filename
static bool completed(Status &a)
virtual bool needUpdate(double time, double dt) override
int initialize(const char *name, HyPerCol *hc)
int getBatchDimension() const
virtual void ioParam_gtClassTrueValue(enum ParamsIOFlag ioFlag)
gtClassTrueValue: defines the value to be set for the neuron whose class name (from classes.txt) appears in the current filename ...
virtual bool needUpdate(double simTime, double dt)
int getBatchIndex() const
virtual void ioParam_inputLayerName(enum ParamsIOFlag ioFlag)
inputLayerName: name of the input layer from which the imageListPath is used to parse the class...
virtual void ioParam_classList(enum ParamsIOFlag ioFlag)
classList: path to the .txt file that holds the list of imageListPath features that will parse to diff...