PetaVision  Alpha
SegmentLayer.cpp
1 #include "SegmentLayer.hpp"
2 
3 namespace PV {
4 
5 SegmentLayer::SegmentLayer(const char *name, HyPerCol *hc) {
6  initialize_base();
7  initialize(name, hc);
8 }
9 
10 SegmentLayer::SegmentLayer() {
11  initialize_base();
12  // initialize() gets called by subclass's initialize method
13 }
14 
15 int SegmentLayer::initialize_base() {
16  segmentMethod = NULL;
17  originalLayerName = NULL;
18  numChannels = 0;
19  labelBufSize = 0;
20  labelBuf = NULL;
21  maxXBuf = NULL;
22  maxYBuf = NULL;
23  minXBuf = NULL;
24  minYBuf = NULL;
25  centerIdxBufSize = 0;
26  centerIdxBuf = NULL;
27  allLabelsBuf = NULL;
28 
29  return PV_SUCCESS;
30 }
31 
32 int SegmentLayer::initialize(const char *name, HyPerCol *hc) {
33  int status = HyPerLayer::initialize(name, hc);
34  return status;
35 }
36 
37 int SegmentLayer::ioParamsFillGroup(enum ParamsIOFlag ioFlag) {
38  int status = HyPerLayer::ioParamsFillGroup(ioFlag);
39  ioParam_segmentMethod(ioFlag);
40  ioParam_originalLayerName(ioFlag);
41  return status;
42 }
43 
44 void SegmentLayer::ioParam_segmentMethod(enum ParamsIOFlag ioFlag) {
45  parent->parameters()->ioParamStringRequired(ioFlag, name, "segmentMethod", &segmentMethod);
46  assert(segmentMethod);
47  // Check valid segment methods
48  // none means the gsyn is already a segmentation. Helpful if reading segmentation from pvp
49  if (strcmp(segmentMethod, "none") == 0) {
50  }
51  // TODO add in other segmentation methods
52  // How do we segment across MPI margins?
53  else {
54  if (parent->columnId() == 0) {
55  ErrorLog().printf(
56  "%s: segmentMethod %s not recognized. Current options are \"none\".\n",
57  getDescription_c(),
58  segmentMethod);
59  }
60  MPI_Barrier(parent->getCommunicator()->communicator());
61  exit(EXIT_FAILURE);
62  }
63 }
64 
65 void SegmentLayer::ioParam_originalLayerName(enum ParamsIOFlag ioFlag) {
66  parent->parameters()->ioParamStringRequired(
67  ioFlag, name, "originalLayerName", &originalLayerName);
68  assert(originalLayerName);
69  if (ioFlag == PARAMS_IO_READ && originalLayerName[0] == '\0') {
70  if (parent->columnId() == 0) {
71  ErrorLog().printf("%s: originalLayerName must be set.\n", getDescription_c());
72  }
73  MPI_Barrier(parent->getCommunicator()->communicator());
74  exit(EXIT_FAILURE);
75  }
76 }
77 
78 Response::Status
79 SegmentLayer::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
80  auto status = HyPerLayer::communicateInitInfo(message);
81  // Get original layer
82  originalLayer = message->lookup<HyPerLayer>(std::string(originalLayerName));
83  if (originalLayer == NULL) {
84  if (parent->columnId() == 0) {
85  ErrorLog().printf(
86  "%s: originalLayerName \"%s\" is not a layer in the HyPerCol.\n",
87  getDescription_c(),
88  originalLayerName);
89  }
90  MPI_Barrier(parent->getCommunicator()->communicator());
91  exit(EXIT_FAILURE);
92  }
93  if (originalLayer->getInitInfoCommunicatedFlag() == false) {
94  return Response::POSTPONE;
95  }
96 
97  // Sync margins
98  originalLayer->synchronizeMarginWidth(this);
99  this->synchronizeMarginWidth(originalLayer);
100 
101  // Check size
102  const PVLayerLoc *srcLoc = originalLayer->getLayerLoc();
103  const PVLayerLoc *thisLoc = getLayerLoc();
104 
105  // Original layer must be the same x/y size as this layer
106  if (srcLoc->nxGlobal != thisLoc->nxGlobal || srcLoc->nyGlobal != thisLoc->nyGlobal) {
107  if (parent->columnId() == 0) {
108  ErrorLog(errorMessage);
109  errorMessage.printf(
110  "%s: originalLayer \"%s\" does not have the same x and y dimensions as this "
111  "layer.\n",
112  getDescription_c(),
113  originalLayerName);
114  errorMessage.printf(
115  " original (nx=%d, ny=%d) versus (nx=%d, ny=%d)\n",
116  srcLoc->nxGlobal,
117  srcLoc->nyGlobal,
118  thisLoc->nxGlobal,
119  thisLoc->nyGlobal);
120  }
121  MPI_Barrier(parent->getCommunicator()->communicator());
122  exit(EXIT_FAILURE);
123  }
124 
125  // This layer must have only 1 feature
126  if (thisLoc->nf != 1) {
127  if (parent->columnId() == 0) {
128  ErrorLog().printf("%s: SegmentLayer must have 1 feature.\n", getDescription_c());
129  }
130  MPI_Barrier(parent->getCommunicator()->communicator());
131  exit(EXIT_FAILURE);
132  }
133 
134  // If segmentMethod is none, we also need to make sure the srcLayer also has nf == 1
135  if (strcmp(segmentMethod, "none") == 0 && srcLoc->nf != 1) {
136  if (parent->columnId() == 0) {
137  ErrorLog().printf(
138  "%s: Source layer must have 1 feature with segmentation method \"none\".\n",
139  getDescription_c());
140  }
141  MPI_Barrier(parent->getCommunicator()->communicator());
142  exit(EXIT_FAILURE);
143  }
144 
145  return status;
146 }
147 
148 int SegmentLayer::checkLabelBufSize(int newSize) {
149  if (newSize <= labelBufSize) {
150  return PV_SUCCESS;
151  }
152 
153  const PVLayerLoc *loc = getLayerLoc();
154 
155  // Grow buffer
156  labelBuf = (int *)realloc(labelBuf, newSize * sizeof(int));
157  maxXBuf = (int *)realloc(maxXBuf, newSize * sizeof(int));
158  maxYBuf = (int *)realloc(maxYBuf, newSize * sizeof(int));
159  minXBuf = (int *)realloc(minXBuf, newSize * sizeof(int));
160  minYBuf = (int *)realloc(minYBuf, newSize * sizeof(int));
161 
162  // Set new size
163  labelBufSize = newSize;
164  return PV_SUCCESS;
165 }
166 
167 int SegmentLayer::loadLabelBuf() {
168  // Load in maxX and label buf from maxX map
169  int numLabels = maxX.size();
170  // Allocate send buffer to the right size
171  checkLabelBufSize(numLabels);
172 
173  int idx = 0;
174  for (auto &m : maxX) {
175  labelBuf[idx] = m.first; // Store key in label
176  maxXBuf[idx] = m.second; // Store vale in maxXBuf
177  idx++;
178  }
179  assert(idx == numLabels);
180 
181  // Load rest of buffers based on label
182  for (int i = 0; i < numLabels; i++) {
183  int label = labelBuf[i];
184  maxYBuf[i] = maxY.at(label);
185  minXBuf[i] = minX.at(label);
186  minYBuf[i] = minY.at(label);
187  }
188  return PV_SUCCESS;
189 }
190 
191 int SegmentLayer::loadCenterIdxMap(int batchIdx, int numLabels) {
192  for (int i = 0; i < numLabels; i++) {
193  int label = allLabelsBuf[i];
194  int idx = centerIdxBuf[i];
195  centerIdx[batchIdx][label] = idx;
196  }
197  return PV_SUCCESS;
198 }
199 
200 int SegmentLayer::checkIdxBufSize(int newSize) {
201  if (newSize <= centerIdxBufSize) {
202  return PV_SUCCESS;
203  }
204 
205  // Grow buffer
206  centerIdxBuf = (int *)realloc(centerIdxBuf, newSize * sizeof(int));
207  allLabelsBuf = (int *)realloc(allLabelsBuf, newSize * sizeof(int));
208  // Set new size
209  centerIdxBufSize = newSize;
210  return PV_SUCCESS;
211 }
212 
213 Response::Status SegmentLayer::allocateDataStructures() {
214  auto status = HyPerLayer::allocateDataStructures();
215  if (!Response::completed(status)) {
216  return status;
217  }
218 
219  int nbatch = getLayerLoc()->nbatch;
220  maxX.clear();
221  maxY.clear();
222  minX.clear();
223  minY.clear();
224  centerIdx.clear();
225 
226  // Initialize vector of maps
227  for (int b = 0; b < nbatch; b++) {
228  centerIdx.push_back(std::map<int, int>());
229  }
230 
231  return Response::SUCCESS;
232 }
233 
234 void SegmentLayer::allocateV() {
235  // Allocate V does nothing since binning does not need a V layer
236  clayer->V = NULL;
237 }
238 
239 void SegmentLayer::initializeV() { assert(getV() == NULL); }
240 
241 void SegmentLayer::initializeActivity() {}
242 
243 Response::Status SegmentLayer::updateState(double timef, double dt) {
244  float *srcA = originalLayer->getActivity();
245  float *thisA = getActivity();
246  assert(srcA);
247  assert(thisA);
248 
249  const PVLayerLoc *loc = getLayerLoc();
250 
251  // Segment input layer based on segmentMethod
252  if (strcmp(segmentMethod, "none") == 0) {
253  int numBatchExtended = getNumExtendedAllBatches();
254  // Copy activity over
255  // Since both buffers should be identical size, we can do a memcpy here
256  memcpy(thisA, srcA, numBatchExtended * sizeof(float));
257  }
258  else {
259  // This case should never happen
260  assert(0);
261  }
262 
263  assert(loc->nf == 1);
264 
265  // Clear centerIdxs
266  for (int bi = 0; bi < loc->nbatch; bi++) {
267  centerIdx[bi].clear();
268  }
269 
270  for (int bi = 0; bi < loc->nbatch; bi++) {
271  float *batchA = thisA + bi * getNumExtended();
272  // Reset max/min buffers
273  maxX.clear();
274  maxY.clear();
275  minX.clear();
276  minY.clear();
277 
278  // Loop through this buffer to fill labelVec and idxVec
279  // Looping through restricted, but indices are extended
280  for (int yi = loc->halo.up; yi < loc->ny + loc->halo.up; yi++) {
281  for (int xi = loc->halo.lt; xi < loc->nx + loc->halo.lt; xi++) {
282  // Convert to local extended linear index
283  int niLocalExt = yi * (loc->nx + loc->halo.lt + loc->halo.rt) + xi;
284  // Convert yi and xi to global res index
285  int globalResYi = yi - loc->halo.up + loc->ky0;
286  int globalResXi = xi - loc->halo.lt + loc->kx0;
287 
288  // Get label value
289  // Note that we're assuming that the activity here are integers,
290  // even though the buffer is floats
291  int labelVal = round(batchA[niLocalExt]);
292 
293  // Calculate max/min x and y for a single batch
294  // If labelVal exists in map
295  if (maxX.count(labelVal)) {
296  // Here, we're assuming the 4 maps are in sync, so we use the
297  //.at method, as it will throw an exception as opposed to the
298  //[] operator, which will simply add the key into the map
299  if (globalResXi > maxX.at(labelVal)) {
300  maxX[labelVal] = globalResXi;
301  }
302  if (globalResXi < minX.at(labelVal)) {
303  minX[labelVal] = globalResXi;
304  }
305  if (globalResYi > maxY.at(labelVal)) {
306  maxY[labelVal] = globalResYi;
307  }
308  if (globalResYi < minY.at(labelVal)) {
309  minY[labelVal] = globalResYi;
310  }
311  }
312  // If doesn't exist, add into map with current vals
313  else {
314  maxX[labelVal] = globalResXi;
315  minX[labelVal] = globalResXi;
316  maxY[labelVal] = globalResYi;
317  minY[labelVal] = globalResYi;
318  }
319  }
320  }
321 
322  // We need to mpi across processors in case a segment crosses an mpi boundary
323  Communicator *icComm = parent->getCommunicator();
324  int numMpi = icComm->commSize();
325  int rank = icComm->commRank();
326 
327  // Local comm rank
328  // Non root processes simply send buffer size and then buffers
329  int numLabels = maxX.size();
330 
331  if (rank != 0) {
332  // Load buffers
333  loadLabelBuf();
334  // Send number of labels first
335  MPI_Send(&numLabels, 1, MPI_INT, 0, rank, icComm->communicator());
336  // Send labels, then max/min buffers
337  MPI_Send(labelBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
338  MPI_Send(maxXBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
339  MPI_Send(maxYBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
340  MPI_Send(minXBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
341  MPI_Send(minYBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
342 
343  // Receive the full centerIdxBuf from root process
344  int numCenterIdx = 0;
345  MPI_Bcast(&numCenterIdx, 1, MPI_INT, 0, icComm->communicator());
346  checkIdxBufSize(numCenterIdx);
347 
348  MPI_Bcast(allLabelsBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
349  MPI_Bcast(centerIdxBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
350 
351  // Load buffer into centerIdx map
352  loadCenterIdxMap(bi, numCenterIdx);
353  }
354  // Root process stores everything
355  else {
356  // One recv per buffer
357  for (int recvRank = 1; recvRank < numMpi; recvRank++) {
358  int numRecvLabels = 0;
359  MPI_Recv(&numRecvLabels, 1, MPI_INT, recvRank, recvRank, icComm->communicator(), NULL);
360  checkLabelBufSize(numRecvLabels);
361 
362  MPI_Recv(
363  labelBuf,
364  numRecvLabels,
365  MPI_INT,
366  recvRank,
367  recvRank,
368  icComm->communicator(),
369  NULL);
370  MPI_Recv(
371  maxXBuf,
372  numRecvLabels,
373  MPI_INT,
374  recvRank,
375  recvRank,
376  icComm->communicator(),
377  NULL);
378  MPI_Recv(
379  maxYBuf,
380  numRecvLabels,
381  MPI_INT,
382  recvRank,
383  recvRank,
384  icComm->communicator(),
385  NULL);
386  MPI_Recv(
387  minXBuf,
388  numRecvLabels,
389  MPI_INT,
390  recvRank,
391  recvRank,
392  icComm->communicator(),
393  NULL);
394  MPI_Recv(
395  minYBuf,
396  numRecvLabels,
397  MPI_INT,
398  recvRank,
399  recvRank,
400  icComm->communicator(),
401  NULL);
402 
403  for (int i = 0; i < numRecvLabels; i++) {
404  int label = labelBuf[i];
405  // Add on to maps
406  // If the label already exists, fill with proper max/min
407  if (maxX.count(label)) {
408  if (maxXBuf[i] > maxX.at(label)) {
409  maxX[label] = maxXBuf[i];
410  }
411  if (maxYBuf[i] > maxY.at(label)) {
412  maxY[label] = maxYBuf[i];
413  }
414  if (minXBuf[i] < minX.at(label)) {
415  minX[label] = minXBuf[i];
416  }
417  if (minYBuf[i] < minY.at(label)) {
418  minY[label] = minYBuf[i];
419  }
420  }
421  else {
422  maxX[label] = maxXBuf[i];
423  maxY[label] = maxYBuf[i];
424  minX[label] = minXBuf[i];
425  minY[label] = minYBuf[i];
426  }
427  }
428  }
429 
430  // Maps are now filled with all segments from the image
431  // Fill centerIdx based on max/min
432  for (auto &m : maxX) {
433  int label = m.first;
434  int centerX = minX.at(label) + (maxX.at(label) - minX.at(label)) / 2;
435  int centerY = minY.at(label) + (maxY.at(label) - minY.at(label)) / 2;
436  // Convert centerpoints (in global res idx) to linear idx (in global res space)
437  int centerIdxVal = centerY * (loc->nxGlobal) + centerX;
438  // Add to centerIdxMap
439  centerIdx[bi][label] = centerIdxVal;
440  }
441 
442  // Fill centerpoint buffer
443  int numCenterIdx = centerIdx[bi].size();
444  checkIdxBufSize(numCenterIdx);
445 
446  int idx = 0;
447  for (auto &ctr : centerIdx[bi]) {
448  allLabelsBuf[idx] = ctr.first;
449  centerIdxBuf[idx] = ctr.second;
450  idx++;
451  }
452 
453  // Broadcast buffers
454  MPI_Bcast(&numCenterIdx, 1, MPI_INT, 0, icComm->communicator());
455  MPI_Bcast(allLabelsBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
456  MPI_Bcast(centerIdxBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
457  }
458  } // End batch loop
459 
460  // centerIdx now stores each center coordinate of each segment
461  return Response::SUCCESS;
462 }
463 
464 SegmentLayer::~SegmentLayer() {
465  free(originalLayerName);
466  clayer->V = NULL;
467  maxX.clear();
468  maxY.clear();
469  minX.clear();
470  minY.clear();
471  // This should call destructors of all maps within the vector
472  centerIdx.clear();
473 }
474 
475 } /* namespace PV */
virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
Definition: HyPerLayer.cpp:571
static bool completed(Status &a)
Definition: Response.hpp:49
int initialize(const char *name, HyPerCol *hc)
Definition: HyPerLayer.cpp:129
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
bool getInitInfoCommunicatedFlag() const
Definition: BaseObject.hpp:95