1 #include "Segmentify.hpp" 5 Segmentify::Segmentify(
const char *name, HyPerCol *hc) {
10 Segmentify::Segmentify() {
15 int Segmentify::initialize_base() {
17 originalLayerName = NULL;
19 segmentLayerName = NULL;
29 int Segmentify::initialize(
const char *name, HyPerCol *hc) {
36 ioParam_originalLayerName(ioFlag);
37 ioParam_segmentLayerName(ioFlag);
38 ioParam_inputMethod(ioFlag);
39 ioParam_outputMethod(ioFlag);
43 void Segmentify::ioParam_inputMethod(
enum ParamsIOFlag ioFlag) {
44 parent->parameters()->ioParamStringRequired(ioFlag, name,
"inputMethod", &inputMethod);
45 if (strcmp(inputMethod,
"average") == 0) {
47 else if (strcmp(inputMethod,
"sum") == 0) {
49 else if (strcmp(inputMethod,
"max") == 0) {
52 if (parent->columnId() == 0) {
54 "%s: inputMethod must be \"average\", \"sum\", or \"max\".\n", getDescription_c());
56 MPI_Barrier(parent->getCommunicator()->communicator());
61 void Segmentify::ioParam_outputMethod(
enum ParamsIOFlag ioFlag) {
62 parent->parameters()->ioParamStringRequired(ioFlag, name,
"outputMethod", &outputMethod);
63 if (strcmp(outputMethod,
"centroid") == 0) {
65 else if (strcmp(outputMethod,
"fill") == 0) {
68 if (parent->columnId() == 0) {
70 "%s: outputMethod must be \"centriod\" or \"fill\".\n", getDescription_c());
72 MPI_Barrier(parent->getCommunicator()->communicator());
77 void Segmentify::ioParam_originalLayerName(
enum ParamsIOFlag ioFlag) {
78 parent->parameters()->ioParamStringRequired(
79 ioFlag, name,
"originalLayerName", &originalLayerName);
80 assert(originalLayerName);
81 if (ioFlag == PARAMS_IO_READ && originalLayerName[0] ==
'\0') {
82 if (parent->columnId() == 0) {
83 ErrorLog().printf(
"%s: originalLayerName must be set.\n", getDescription_c());
85 MPI_Barrier(parent->getCommunicator()->communicator());
90 void Segmentify::ioParam_segmentLayerName(
enum ParamsIOFlag ioFlag) {
91 parent->parameters()->ioParamStringRequired(ioFlag, name,
"segmentLayerName", &segmentLayerName);
92 assert(segmentLayerName);
93 if (ioFlag == PARAMS_IO_READ && segmentLayerName[0] ==
'\0') {
94 if (parent->columnId() == 0) {
95 ErrorLog().printf(
"%s: segmentLayerName must be set.\n", getDescription_c());
97 MPI_Barrier(parent->getCommunicator()->communicator());
103 Segmentify::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
104 auto status = HyPerLayer::communicateInitInfo(message);
110 originalLayer = message->lookup<
HyPerLayer>(std::string(originalLayerName));
111 if (originalLayer == NULL) {
112 if (parent->columnId() == 0) {
114 "%s: originalLayerName \"%s\" is not a layer in the HyPerCol.\n",
118 MPI_Barrier(parent->getCommunicator()->communicator());
122 return Response::POSTPONE;
126 segmentLayer = message->lookup<
SegmentLayer>(std::string(segmentLayerName));
127 if (segmentLayer == NULL) {
128 if (parent->columnId() == 0) {
130 "%s: segmentLayerName \"%s\" is not a SegmentLayer.\n",
134 MPI_Barrier(parent->getCommunicator()->communicator());
139 return Response::POSTPONE;
143 originalLayer->synchronizeMarginWidth(
this);
144 this->synchronizeMarginWidth(originalLayer);
147 const PVLayerLoc *srcLoc = originalLayer->getLayerLoc();
148 const PVLayerLoc *segLoc = segmentLayer->getLayerLoc();
150 assert(srcLoc != NULL && segLoc != NULL);
153 if (srcLoc->nf != thisLoc->nf) {
154 if (parent->columnId() == 0) {
155 ErrorLog(errorMessage);
157 "%s: originalLayer \"%s\" does not have the same feature dimension as this layer.\n",
160 errorMessage.printf(
" original (nf=%d) versus (nf=%d)\n", srcLoc->nf, thisLoc->nf);
162 MPI_Barrier(parent->getCommunicator()->communicator());
167 if (segLoc->nf != 1) {
168 if (parent->columnId() == 0) {
170 "%s: segmentLayer \"%s\" can only have 1 feature.\n",
174 MPI_Barrier(parent->getCommunicator()->communicator());
178 return Response::SUCCESS;
181 Response::Status Segmentify::allocateDataStructures() {
182 auto status = HyPerLayer::allocateDataStructures();
188 labelVals = (
float **)calloc(getLayerLoc()->nf,
sizeof(
float *));
189 labelCount = (
int **)calloc(getLayerLoc()->nf,
sizeof(
int *));
194 return Response::SUCCESS;
197 int Segmentify::checkLabelValBuf(
int newSize) {
198 if (newSize <= numLabelVals) {
203 for (
int i = 0; i < getLayerLoc()->nf; i++) {
204 labelVals[i] = (
float *)realloc(labelVals[i], newSize *
sizeof(
float));
205 labelCount[i] = (
int *)realloc(labelCount[i], newSize *
sizeof(
int));
207 labelIdxBuf = (
int *)realloc(labelIdxBuf, newSize *
sizeof(
int));
209 numLabelVals = newSize;
214 void Segmentify::allocateV() {
219 void Segmentify::initializeV() { assert(getV() == NULL); }
221 void Segmentify::initializeActivity() {}
223 int Segmentify::buildLabelToIdx(
int batchIdx) {
225 int numMpi = icComm->commSize();
226 int rank = icComm->commRank();
233 std::map<int, int> segMap = segmentLayer->getCenterIdxBuf(batchIdx);
236 numLabels = segMap.size();
238 checkLabelValBuf(numLabels);
241 for (
auto &seg : segMap) {
242 labelIdxBuf[l] = seg.first;
248 MPI_Bcast(&numLabels, 1, MPI_INT, 0, icComm->communicator());
249 checkLabelValBuf(numLabels);
250 MPI_Bcast(labelIdxBuf, numLabels, MPI_INT, 0, icComm->communicator());
252 for (
int l = 0; l < numLabels; l++) {
254 labelToIdx[labelIdxBuf[l]] = l;
257 for (
int fi = 0; fi < getLayerLoc()->nf; fi++) {
259 labelCount[fi][l] = 0;
260 if (strcmp(inputMethod,
"max") == 0) {
261 labelVals[fi][l] = -INFINITY;
264 else if (strcmp(inputMethod,
"average") == 0 || strcmp(inputMethod,
"sum") == 0) {
265 labelVals[fi][l] = 0;
275 int Segmentify::calculateLabelVals(
int batchIdx) {
278 const PVLayerLoc *srcLoc = originalLayer->getLayerLoc();
279 const PVLayerLoc *segLoc = segmentLayer->getLayerLoc();
281 assert(segLoc->nf == 1);
283 float *srcA = originalLayer->getActivity();
284 float *segA = segmentLayer->getActivity();
289 float *srcBatchA = srcA + batchIdx * originalLayer->getNumExtended();
290 float *segBatchA = segA + batchIdx * segmentLayer->getNumExtended();
294 for (
int yi = 0; yi < srcLoc->ny; yi++) {
297 float segToSrcScaleY = (float)segLoc->ny / (
float)srcLoc->ny;
298 int segmentYi = round(yi * segToSrcScaleY);
299 for (
int xi = 0; xi < srcLoc->nx; xi++) {
300 float segToSrcScaleX = (float)segLoc->nx / (
float)srcLoc->nx;
301 int segmentXi = round(xi * segToSrcScaleX);
304 (segmentYi + segLoc->halo.up) * (segLoc->nx + segLoc->halo.lt + segLoc->halo.rt)
305 + (segmentXi + segLoc->halo.lt);
308 int labelVal = round(segBatchA[extSegIdx]);
312 int labelIdx = labelToIdx.at(labelVal);
314 for (
int fi = 0; fi < srcLoc->nf; fi++) {
317 int extSrcIdx = (yi + srcLoc->halo.up)
318 * (srcLoc->nx + srcLoc->halo.lt + srcLoc->halo.rt) * srcLoc->nf
319 + (xi + srcLoc->halo.lt) * srcLoc->nf + fi;
320 float srcVal = srcBatchA[extSrcIdx];
321 labelCount[fi][labelIdx]++;
323 if (strcmp(inputMethod,
"max") == 0) {
324 if (labelVals[fi][labelIdx] < srcVal) {
325 labelVals[fi][labelIdx] = srcVal;
328 else if (strcmp(inputMethod,
"average") == 0 || strcmp(inputMethod,
"sum") == 0) {
329 labelVals[fi][labelIdx] += srcVal;
335 int numLabels = labelToIdx.size();
337 int rank = icComm->commRank();
340 for (
int fi = 0; fi < srcLoc->nf; fi++) {
342 MPI_IN_PLACE, labelCount[fi], numLabels, MPI_INT, MPI_SUM, icComm->communicator());
343 if (strcmp(inputMethod,
"max") == 0) {
345 MPI_IN_PLACE, labelVals[fi], numLabels, MPI_FLOAT, MPI_MAX, icComm->communicator());
347 else if (strcmp(inputMethod,
"sum") == 0 || strcmp(inputMethod,
"average") == 0) {
349 MPI_IN_PLACE, labelVals[fi], numLabels, MPI_FLOAT, MPI_SUM, icComm->communicator());
352 if (strcmp(inputMethod,
"average") == 0) {
353 for (
int l = 0; l < numLabels; l++) {
354 labelVals[fi][l] = labelVals[fi][l] / labelCount[fi][l];
362 int Segmentify::setOutputVals(
int batchIdx) {
364 const PVLayerLoc *segLoc = segmentLayer->getLayerLoc();
367 assert(segLoc->nf == 1);
369 float *segA = segmentLayer->getActivity();
370 float *thisA = getActivity();
375 float *segBatchA = segA + batchIdx * segmentLayer->getNumExtended();
376 float *thisBatchA = thisA + batchIdx * getNumExtended();
379 for (
int ni = 0; ni < getNumExtended(); ni++) {
384 float thisToSegScaleX = (float)thisLoc->nx / (
float)segLoc->nx;
385 float thisToSegScaleY = (float)thisLoc->ny / (
float)segLoc->ny;
388 if (strcmp(outputMethod,
"centroid") == 0) {
389 std::map<int, int> segMap = segmentLayer->getCenterIdxBuf(batchIdx);
391 for (
auto &seg : segMap) {
392 int label = seg.first;
393 int segGlobalResIdx = seg.second;
395 int segGlobalResX = segGlobalResIdx % (segLoc->nxGlobal);
396 int segGlobalResY = segGlobalResIdx / (segLoc->nyGlobal);
398 int thisGlobalResX = round(segGlobalResX * thisToSegScaleX);
399 int thisGlobalResY = round(segGlobalResY * thisToSegScaleY);
401 if (thisGlobalResX >= thisLoc->kx0 && thisGlobalResX < thisLoc->kx0 + thisLoc->nx
402 && thisGlobalResY >= thisLoc->ky0
403 && thisGlobalResY < thisLoc->ky0 + thisLoc->ny) {
405 int thisLocalExtX = thisGlobalResX - thisLoc->kx0 + thisLoc->halo.lt;
406 int thisLocalExtY = thisGlobalResY - thisLoc->ky0 + thisLoc->halo.up;
407 for (
int fi = 0; fi < thisLoc->nf; fi++) {
408 int thisLocalExtIdx = thisLocalExtY
409 * (thisLoc->nx + thisLoc->halo.lt + thisLoc->halo.rt)
411 + thisLocalExtX * thisLoc->nf + fi;
413 thisBatchA[thisLocalExtIdx] = labelVals[fi][labelToIdx.at(label)];
418 else if (strcmp(outputMethod,
"fill") == 0) {
421 for (
int yi = 0; yi < thisLoc->ny; yi++) {
423 int segResY = round((
float)yi / (
float)thisToSegScaleY);
424 for (
int xi = 0; xi < thisLoc->nx; xi++) {
425 int segResX = round((
float)xi / (
float)thisToSegScaleX);
428 (segResY + segLoc->halo.up) * (segLoc->nx + segLoc->halo.lt + segLoc->halo.rt)
429 + (segResX + segLoc->halo.lt);
431 int label = round(segBatchA[segExtIdx]);
433 for (
int fi = 0; fi < thisLoc->nf; fi++) {
435 int thisExtIdx = (yi + thisLoc->halo.up)
436 * (thisLoc->nx + thisLoc->halo.lt + thisLoc->halo.rt)
438 + (xi + thisLoc->halo.lt) * thisLoc->nf + fi;
439 thisBatchA[thisExtIdx] = labelVals[fi][labelToIdx.at(label)];
447 Response::Status Segmentify::updateState(
double timef,
double dt) {
450 for (
int bi = 0; bi < getLayerLoc()->nbatch; bi++) {
452 calculateLabelVals(bi);
456 return Response::SUCCESS;
459 Segmentify::~Segmentify() {
460 free(originalLayerName);
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
static bool completed(Status &a)
int initialize(const char *name, HyPerCol *hc)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
bool getInitInfoCommunicatedFlag() const