// NOTE(review): this text is an elided extraction of SegmentLayer.cpp; original line
// numbers are fused into the code and most closing braces/bodies are missing.
// Constructs a SegmentLayer named `name` belonging to HyPerCol `hc`. The body is not
// visible here; presumably it calls initialize_base() and initialize(name, hc) — TODO confirm.
1 #include "SegmentLayer.hpp" 5 SegmentLayer::SegmentLayer(
const char *name, HyPerCol *hc) {
// Default constructor; its body is not visible in this excerpt.
10 SegmentLayer::SegmentLayer() {
// Sets member defaults before initialize() runs (only part of the body is visible).
15 int SegmentLayer::initialize_base() {
// Start as NULL so the destructor's free(originalLayerName) is a no-op if the
// parameter is never read.
17 originalLayerName = NULL;
// Initializes the layer from its name and owning HyPerCol; the body is not visible here.
32 int SegmentLayer::initialize(
const char *name, HyPerCol *hc) {
// Fragment of the parameter read/write routine (presumably ioParamsFillGroup; its header
// is elided). It handles "segmentMethod" first, then "originalLayerName".
39 ioParam_segmentMethod(ioFlag);
40 ioParam_originalLayerName(ioFlag);
// Reads or writes the required string parameter "segmentMethod" and validates it.
// "none" is the only value the visible code accepts.
44 void SegmentLayer::ioParam_segmentMethod(
enum ParamsIOFlag ioFlag) {
45 parent->parameters()->ioParamStringRequired(ioFlag, name,
"segmentMethod", &segmentMethod);
// Guards against a NULL string; note that assert is compiled out under NDEBUG.
46 assert(segmentMethod);
// "none" is accepted. For any other value (the else branch is elided), rank 0 reports
// the error and every rank reaches the barrier; the exit that presumably follows is elided.
49 if (strcmp(segmentMethod,
"none") == 0) {
54 if (parent->columnId() == 0) {
56 "%s: segmentMethod %s not recognized. Current options are \"none\".\n",
60 MPI_Barrier(parent->getCommunicator()->communicator());
// Reads or writes the required string parameter "originalLayerName", i.e. the name of the
// layer whose activity this layer segments.
65 void SegmentLayer::ioParam_originalLayerName(
enum ParamsIOFlag ioFlag) {
66 parent->parameters()->ioParamStringRequired(
67 ioFlag, name,
"originalLayerName", &originalLayerName);
68 assert(originalLayerName);
// An empty name is rejected only when reading params. Rank 0 logs the error, then all
// ranks synchronize at the barrier; the exit that presumably follows is elided.
69 if (ioFlag == PARAMS_IO_READ && originalLayerName[0] ==
'\0') {
70 if (parent->columnId() == 0) {
71 ErrorLog().printf(
"%s: originalLayerName must be set.\n", getDescription_c());
73 MPI_Barrier(parent->getCommunicator()->communicator());
// Looks up the source layer named by originalLayerName, couples margin widths with it, and
// validates the geometry:
//   - nxGlobal and nyGlobal must match the source layer;
//   - this layer must have nf == 1;
//   - with segmentMethod "none", the source must also have nf == 1.
// NOTE(review): the return type (presumably Response::Status) sits on an elided line.
79 SegmentLayer::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
80 auto status = HyPerLayer::communicateInitInfo(message);
// Resolve the source layer by name through the message. If it is missing, rank 0 reports
// it and all ranks reach the barrier (the abort that presumably follows is elided).
82 originalLayer = message->lookup<
HyPerLayer>(std::string(originalLayerName));
83 if (originalLayer == NULL) {
84 if (parent->columnId() == 0) {
86 "%s: originalLayerName \"%s\" is not a layer in the HyPerCol.\n",
90 MPI_Barrier(parent->getCommunicator()->communicator());
// The guarding condition is elided. Presumably this defers until the source layer has
// finished its own communicateInitInfo — TODO confirm.
94 return Response::POSTPONE;
// Synchronize margin widths in both directions. The memcpy in updateState copies
// this layer's extended size from the source, so it relies on matching layouts.
98 originalLayer->synchronizeMarginWidth(
this);
99 this->synchronizeMarginWidth(originalLayer);
// Geometry checks; thisLoc is declared on an elided line.
102 const PVLayerLoc *srcLoc = originalLayer->getLayerLoc();
106 if (srcLoc->nxGlobal != thisLoc->nxGlobal || srcLoc->nyGlobal != thisLoc->nyGlobal) {
107 if (parent->columnId() == 0) {
// NOTE(review): this appears to create a named log object; the printf that consumes the
// format string below is on elided lines.
108 ErrorLog(errorMessage);
110 "%s: originalLayer \"%s\" does not have the same x and y dimensions as this " 115 " original (nx=%d, ny=%d) versus (nx=%d, ny=%d)\n",
121 MPI_Barrier(parent->getCommunicator()->communicator());
// updateState treats each activity value as a label, so this layer needs exactly one feature.
126 if (thisLoc->nf != 1) {
127 if (parent->columnId() == 0) {
128 ErrorLog().printf(
"%s: SegmentLayer must have 1 feature.\n", getDescription_c());
130 MPI_Barrier(parent->getCommunicator()->communicator());
// With "none" the source activity is copied verbatim, so the source needs one feature too.
135 if (strcmp(segmentMethod,
"none") == 0 && srcLoc->nf != 1) {
136 if (parent->columnId() == 0) {
138 "%s: Source layer must have 1 feature with segmentation method \"none\".\n",
141 MPI_Barrier(parent->getCommunicator()->communicator());
// Ensures the five parallel label buffers (labelBuf, maxXBuf, maxYBuf, minXBuf, minYBuf)
// can each hold at least newSize ints. They only ever grow.
148 int SegmentLayer::checkLabelBufSize(
int newSize) {
// Already large enough: nothing to do (the early return is on an elided line).
149 if (newSize <= labelBufSize) {
// NOTE(review): `p = realloc(p, ...)` with no null check visible here. If realloc fails,
// the old block leaks and the buffer becomes NULL while labelBufSize still grows.
// Consider a temporary pointer plus an explicit failure check.
156 labelBuf = (
int *)realloc(labelBuf, newSize *
sizeof(
int));
157 maxXBuf = (
int *)realloc(maxXBuf, newSize *
sizeof(
int));
158 maxYBuf = (
int *)realloc(maxYBuf, newSize *
sizeof(
int));
159 minXBuf = (
int *)realloc(minXBuf, newSize *
sizeof(
int));
160 minYBuf = (
int *)realloc(minYBuf, newSize *
sizeof(
int));
163 labelBufSize = newSize;
// Packs the per-label bounding boxes (maps maxX, maxY, minX, minY keyed by label) into the
// five parallel int buffers for MPI transfer. Entry i describes label labelBuf[i].
167 int SegmentLayer::loadLabelBuf() {
169 int numLabels = maxX.size();
171 checkLabelBufSize(numLabels);
// Labels and maxX values come from iterating maxX. idx is declared and incremented on
// elided lines.
174 for (
auto &m : maxX) {
175 labelBuf[idx] = m.first;
176 maxXBuf[idx] = m.second;
179 assert(idx == numLabels);
// Fill the remaining extents by label. .at() throws if a label present in maxX is missing
// from maxY, minX or minY.
182 for (
int i = 0; i < numLabels; i++) {
183 int label = labelBuf[i];
184 maxYBuf[i] = maxY.at(label);
185 minXBuf[i] = minX.at(label);
186 minYBuf[i] = minY.at(label);
// Unpacks the first numLabels (label, center index) pairs from allLabelsBuf/centerIdxBuf into
// centerIdx[batchIdx]. Entries are inserted or overwritten; the map is not cleared here.
// Assumes both buffers hold at least numLabels entries (see checkIdxBufSize).
191 int SegmentLayer::loadCenterIdxMap(
int batchIdx,
int numLabels) {
192 for (
int i = 0; i < numLabels; i++) {
193 int label = allLabelsBuf[i];
194 int idx = centerIdxBuf[i];
195 centerIdx[batchIdx][label] = idx;
// Ensures centerIdxBuf and allLabelsBuf can each hold at least newSize ints. They only grow.
200 int SegmentLayer::checkIdxBufSize(
int newSize) {
// Already large enough: nothing to do (the early return is on an elided line).
201 if (newSize <= centerIdxBufSize) {
// NOTE(review): same `p = realloc(p, ...)` pattern as checkLabelBufSize, with no null check
// visible. A failed realloc leaks the old block and leaves a NULL buffer.
206 centerIdxBuf = (
int *)realloc(centerIdxBuf, newSize *
sizeof(
int));
207 allLabelsBuf = (
int *)realloc(allLabelsBuf, newSize *
sizeof(
int));
209 centerIdxBufSize = newSize;
// Allocates base-class data, then creates one empty label -> center-index map per batch element.
213 Response::Status SegmentLayer::allocateDataStructures() {
214 auto status = HyPerLayer::allocateDataStructures();
// NOTE(review): the lines that inspect `status` are elided. Confirm a non-success base status
// is propagated, since the visible return below is an unconditional SUCCESS.
219 int nbatch = getLayerLoc()->nbatch;
227 for (
int b = 0; b < nbatch; b++) {
228 centerIdx.push_back(std::map<int, int>());
231 return Response::SUCCESS;
// Override of V allocation; the body is elided. Presumably it leaves V unallocated, which is
// consistent with initializeV asserting getV() == NULL — TODO confirm.
234 void SegmentLayer::allocateV() {
// There is no V buffer to initialize. This only asserts that none was allocated
// (the check is compiled out under NDEBUG).
239 void SegmentLayer::initializeV() { assert(getV() == NULL); }
// Empty override: activity is not initialized here. updateState writes it from the source layer.
241 void SegmentLayer::initializeActivity() {}
// Fills this layer's activity from the source layer. It then computes, for each integer label
// present in each batch element, the global restricted index of the label's bounding-box
// center, and stores it in centerIdx[batch][label] on every MPI rank.
// NOTE(review): timef and dt are not referenced in the visible lines. Many lines (closing
// braces, the rank-0 vs non-root branch, and the loc/icComm/idx declarations) are elided, so
// parts of the flow described below are inferred from line order — confirm against the full file.
243 Response::Status SegmentLayer::updateState(
double timef,
double dt) {
244 float *srcA = originalLayer->getActivity();
245 float *thisA = getActivity();
// segmentMethod "none": copy the source activity verbatim for all batch elements, as
// extended buffers.
// NOTE(review): the copy length is this layer's extended size, so it assumes the source has an
// identical extended layout (margins are synchronized in communicateInitInfo).
252 if (strcmp(segmentMethod,
"none") == 0) {
253 int numBatchExtended = getNumExtendedAllBatches();
256 memcpy(thisA, srcA, numBatchExtended *
sizeof(
float));
// Each activity value is used directly as a label, so indexing below ignores features.
263 assert(loc->nf == 1);
// Reset every batch element's label -> center-index map before recomputing it.
266 for (
int bi = 0; bi < loc->nbatch; bi++) {
267 centerIdx[bi].clear();
// Process each batch element independently.
270 for (
int bi = 0; bi < loc->nbatch; bi++) {
271 float *batchA = thisA + bi * getNumExtended();
// NOTE(review): the reset of maxX/minX/maxY/minY between batch elements is not visible.
// Confirm it happens; otherwise boxes from earlier batch elements leak into later ones.
// Scan only the restricted (non-halo) region. niLocalExt is the local extended index;
// globalResXi and globalResYi are global restricted coordinates.
280 for (
int yi = loc->halo.up; yi < loc->ny + loc->halo.up; yi++) {
281 for (
int xi = loc->halo.lt; xi < loc->nx + loc->halo.lt; xi++) {
283 int niLocalExt = yi * (loc->nx + loc->halo.lt + loc->halo.rt) + xi;
285 int globalResYi = yi - loc->halo.up + loc->ky0;
286 int globalResXi = xi - loc->halo.lt + loc->kx0;
// The label is the activity value rounded to the nearest integer.
291 int labelVal = round(batchA[niLocalExt]);
// For a known label, grow its bounding box to include this pixel.
295 if (maxX.count(labelVal)) {
299 if (globalResXi > maxX.at(labelVal)) {
300 maxX[labelVal] = globalResXi;
302 if (globalResXi < minX.at(labelVal)) {
303 minX[labelVal] = globalResXi;
305 if (globalResYi > maxY.at(labelVal)) {
306 maxY[labelVal] = globalResYi;
308 if (globalResYi < minY.at(labelVal)) {
309 minY[labelVal] = globalResYi;
// For a new label (the else is elided), start a box at this single pixel.
314 maxX[labelVal] = globalResXi;
315 minX[labelVal] = globalResXi;
316 maxY[labelVal] = globalResYi;
317 minY[labelVal] = globalResYi;
// Reduce the per-rank bounding boxes on rank 0 of icComm.
324 int numMpi = icComm->commSize();
325 int rank = icComm->commRank();
// Non-root ranks (presumably the branch rank != 0): send the label count, then the five
// parallel arrays, to rank 0, using the sender's rank as the tag. The buffers are presumably
// filled by a loadLabelBuf() call on an elided line — TODO confirm.
329 int numLabels = maxX.size();
335 MPI_Send(&numLabels, 1, MPI_INT, 0, rank, icComm->communicator());
337 MPI_Send(labelBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
338 MPI_Send(maxXBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
339 MPI_Send(maxYBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
340 MPI_Send(minXBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
341 MPI_Send(minYBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
// Then receive rank 0's broadcast: first the count, used to size the buffers, then the
// labels and center indices, which are loaded into centerIdx[bi].
344 int numCenterIdx = 0;
345 MPI_Bcast(&numCenterIdx, 1, MPI_INT, 0, icComm->communicator());
346 checkIdxBufSize(numCenterIdx);
348 MPI_Bcast(allLabelsBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
349 MPI_Bcast(centerIdxBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
352 loadCenterIdxMap(bi, numCenterIdx);
// Rank 0: receive each other rank's boxes in rank order (tag = sender rank) and merge them
// into the local maps.
357 for (
int recvRank = 1; recvRank < numMpi; recvRank++) {
358 int numRecvLabels = 0;
// NOTE(review): NULL is passed as the MPI_Status*. The MPI standard's way to ignore the
// status is MPI_STATUS_IGNORE, and NULL is not guaranteed to be accepted; consider
// replacing it (and in the elided MPI_Recv calls below, if they do the same).
359 MPI_Recv(&numRecvLabels, 1, MPI_INT, recvRank, recvRank, icComm->communicator(), NULL);
// Grow the shared buffers before the partly elided MPI_Recv calls, whose trailing
// communicator arguments appear below. Presumably they fill labelBuf, maxXBuf, maxYBuf,
// minXBuf and minYBuf in the senders' order — TODO confirm.
360 checkLabelBufSize(numRecvLabels);
368 icComm->communicator(),
376 icComm->communicator(),
384 icComm->communicator(),
392 icComm->communicator(),
400 icComm->communicator(),
// Merge: extend an existing box, or adopt the remote box for a label rank 0 has not seen.
403 for (
int i = 0; i < numRecvLabels; i++) {
404 int label = labelBuf[i];
407 if (maxX.count(label)) {
408 if (maxXBuf[i] > maxX.at(label)) {
409 maxX[label] = maxXBuf[i];
411 if (maxYBuf[i] > maxY.at(label)) {
412 maxY[label] = maxYBuf[i];
414 if (minXBuf[i] < minX.at(label)) {
415 minX[label] = minXBuf[i];
417 if (minYBuf[i] < minY.at(label)) {
418 minY[label] = minYBuf[i];
422 maxX[label] = maxXBuf[i];
423 maxY[label] = maxYBuf[i];
424 minX[label] = minXBuf[i];
425 minY[label] = minYBuf[i];
// Compute the center of each merged box: the integer midpoint, rounded toward the min corner.
// Convert it to a global restricted index y * nxGlobal + x.
// `label` is declared on an elided line, presumably from m.first — TODO confirm.
432 for (
auto &m : maxX) {
434 int centerX = minX.at(label) + (maxX.at(label) - minX.at(label)) / 2;
435 int centerY = minY.at(label) + (maxY.at(label) - minY.at(label)) / 2;
437 int centerIdxVal = centerY * (loc->nxGlobal) + centerX;
439 centerIdx[bi][label] = centerIdxVal;
// Pack centerIdx[bi] into the two parallel broadcast buffers (idx is declared and
// incremented on elided lines).
443 int numCenterIdx = centerIdx[bi].size();
444 checkIdxBufSize(numCenterIdx);
447 for (
auto &ctr : centerIdx[bi]) {
448 allLabelsBuf[idx] = ctr.first;
449 centerIdxBuf[idx] = ctr.second;
// Broadcast the count, then the labels, then the center indices. This order must match the
// non-root MPI_Bcast sequence above.
454 MPI_Bcast(&numCenterIdx, 1, MPI_INT, 0, icComm->communicator());
455 MPI_Bcast(allLabelsBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
456 MPI_Bcast(centerIdxBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
461 return Response::SUCCESS;
// Destructor. Only the release of originalLayerName is visible in this excerpt.
// NOTE(review): confirm that segmentMethod and the realloc'd buffers (labelBuf, maxXBuf,
// maxYBuf, minXBuf, minYBuf, centerIdxBuf, allLabelsBuf) are also freed on the elided lines.
464 SegmentLayer::~SegmentLayer() {
465 free(originalLayerName);
// NOTE(review): the following lines look like bare declaration signatures from headers
// (SegmentLayer / HyPerLayer / Response), appended by the extraction. They have no bodies
// or terminators here, so they are not definitions in this file.
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
static bool completed(Status &a)
int initialize(const char *name, HyPerCol *hc)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
bool getInitInfoCommunicatedFlag() const