// Excerpt of the SegmentLayer implementation. The numbers at the start of each
// line appear to be the original file's line numbers; many lines are omitted.
// Below: the header include and the SegmentLayer(name, hc) constructor, whose
// body is not visible in this excerpt.
1 #include "SegmentLayer.hpp"     5 SegmentLayer::SegmentLayer(
const char *name, HyPerCol *hc) {
    // Default constructor; its body is not visible in this excerpt.
    10 SegmentLayer::SegmentLayer() {
    // Puts members into their pre-initialization state (only the
    // originalLayerName assignment is visible here). Starting it at NULL keeps
    // the destructor's free(originalLayerName) a no-op if the parameter is never read.
    15 int SegmentLayer::initialize_base() {
    17    originalLayerName = NULL;
    // Initializes the layer called `name` within HyPerCol `hc`.
    // The body is not visible in this excerpt.
    32 int SegmentLayer::initialize(
const char *name, HyPerCol *hc) {
    // Body fragment of the params-handling method. Its signature is not shown;
    // it is presumably ioParamsFillGroup(ioFlag), so confirm. It processes this
    // layer's two string parameters, segmentMethod first and then originalLayerName.
    39    ioParam_segmentMethod(ioFlag);
    40    ioParam_originalLayerName(ioFlag);
    // Handles the required string parameter "segmentMethod" via
    // ioParamStringRequired (presumably a read or a write, depending on ioFlag).
    // "none" is the only method named in the error message. For an
    // unrecognized value, rank 0 logs an error and every rank reaches an MPI
    // barrier; the exit call that follows is not visible in this excerpt.
    44 void SegmentLayer::ioParam_segmentMethod(
enum ParamsIOFlag ioFlag) {
    45    parent->parameters()->ioParamStringRequired(ioFlag, name, 
"segmentMethod", &segmentMethod);
    46    assert(segmentMethod);
    49    if (strcmp(segmentMethod, 
"none") == 0) {
    54       if (parent->columnId() == 0) {
    56                "%s: segmentMethod %s not recognized. Current options are \"none\".\n",
    // Barrier so non-root ranks wait for rank 0's error message before exiting.
    60       MPI_Barrier(parent->getCommunicator()->communicator());
    // Handles the required string parameter "originalLayerName". This is the
    // name of the layer whose activity updateState segments; it is resolved to
    // a layer in communicateInitInfo. When reading params, an empty name is an
    // error: rank 0 logs it and every rank reaches an MPI barrier. The exit
    // call that follows is not visible in this excerpt.
    65 void SegmentLayer::ioParam_originalLayerName(
enum ParamsIOFlag ioFlag) {
    66    parent->parameters()->ioParamStringRequired(
    67          ioFlag, name, 
"originalLayerName", &originalLayerName);
    68    assert(originalLayerName);
    69    if (ioFlag == PARAMS_IO_READ && originalLayerName[0] == 
'\0') {
    70       if (parent->columnId() == 0) {
    71          ErrorLog().printf(
"%s: originalLayerName must be set.\n", getDescription_c());
    73       MPI_Barrier(parent->getCommunicator()->communicator());
    // Resolves and validates the source layer named by originalLayerName.
    //  - Looks the name up through the message. A missing layer is an error.
    //  - Returns Response::POSTPONE under a condition not visible here. It is
    //    presumably "the original layer has not finished its own
    //    communicateInitInfo yet"; confirm.
    //  - Synchronizes margin widths in both directions.
    //  - Requires equal nxGlobal/nyGlobal, nf == 1 for this layer, and nf == 1
    //    for the source when segmentMethod is "none".
    // On each error, rank 0 logs and then every rank reaches an MPI barrier.
    // NOTE(review): the return type (presumably Response::Status) is on a line
    // not shown, and the use of `status` is not visible in this excerpt.
    79 SegmentLayer::communicateInitInfo(std::shared_ptr<CommunicateInitInfoMessage const> message) {
    80    auto status = HyPerLayer::communicateInitInfo(message);
    82    originalLayer = message->lookup<
HyPerLayer>(std::string(originalLayerName));
    83    if (originalLayer == NULL) {
    84       if (parent->columnId() == 0) {
    86                "%s: originalLayerName \"%s\" is not a layer in the HyPerCol.\n",
    90       MPI_Barrier(parent->getCommunicator()->communicator());
    94       return Response::POSTPONE;
    // updateState memcpy's the whole extended activity buffer from the original
    // layer. The halos must therefore match, so margins are synchronized both ways.
    98    originalLayer->synchronizeMarginWidth(
this);
    99    this->synchronizeMarginWidth(originalLayer);
   102    const PVLayerLoc *srcLoc  = originalLayer->getLayerLoc();
   // The same restricted x/y size is required so per-pixel labels line up
   // one-to-one with the source.
   106    if (srcLoc->nxGlobal != thisLoc->nxGlobal || srcLoc->nyGlobal != thisLoc->nyGlobal) {
   107       if (parent->columnId() == 0) {
   108          ErrorLog(errorMessage);
   110                "%s: originalLayer \"%s\" does not have the same x and y dimensions as this "   115                "    original (nx=%d, ny=%d) versus (nx=%d, ny=%d)\n",
   121       MPI_Barrier(parent->getCommunicator()->communicator());
   // updateState treats each pixel's single activity value as its label.
   126    if (thisLoc->nf != 1) {
   127       if (parent->columnId() == 0) {
   128          ErrorLog().printf(
"%s: SegmentLayer must have 1 feature.\n", getDescription_c());
   130       MPI_Barrier(parent->getCommunicator()->communicator());
   // With "none", the source activity is copied verbatim, so it must also be single-feature.
   // NOTE(review): nbatch equality between the two layers is not checked in the
   // visible lines, but the memcpy in updateState assumes it.
   135    if (strcmp(segmentMethod, 
"none") == 0 && srcLoc->nf != 1) {
   136       if (parent->columnId() == 0) {
   138                "%s: Source layer must have 1 feature with segmentation method \"none\".\n",
   141       MPI_Barrier(parent->getCommunicator()->communicator());
   // Ensures the five parallel label buffers (labelBuf, maxXBuf, maxYBuf,
   // minXBuf, minYBuf) can hold at least newSize ints. Nothing is reallocated
   // when newSize <= labelBufSize, so the buffers only ever grow.
   // NOTE(review): each realloc result is assigned straight back to its
   // buffer. If realloc fails, the old block is leaked and the pointer becomes
   // NULL. No NULL check appears in the visible lines; assign to a temporary
   // and check it before overwriting.
   148 int SegmentLayer::checkLabelBufSize(
int newSize) {
   149    if (newSize <= labelBufSize) {
   156    labelBuf = (
int *)realloc(labelBuf, newSize * 
sizeof(
int));
   157    maxXBuf  = (
int *)realloc(maxXBuf, newSize * 
sizeof(
int));
   158    maxYBuf  = (
int *)realloc(maxYBuf, newSize * 
sizeof(
int));
   159    minXBuf  = (
int *)realloc(minXBuf, newSize * 
sizeof(
int));
   160    minYBuf  = (
int *)realloc(minYBuf, newSize * 
sizeof(
int));
   163    labelBufSize = newSize;
   // Packs the per-label bounding boxes held in the maxX/maxY/minX/minY maps
   // into the parallel int buffers. Entry i is labelBuf[i] together with
   // maxXBuf[i], maxYBuf[i], minXBuf[i] and minYBuf[i]. This flat form is what
   // updateState sends to rank 0 with MPI_Send. Label order follows iteration
   // over maxX.
   167 int SegmentLayer::loadLabelBuf() {
   169    int numLabels = maxX.size();
   171    checkLabelBufSize(numLabels);
   // First pass: label ids and maxX, taken straight from the maxX map.
   174    for (
auto &m : maxX) {
   175       labelBuf[idx] = m.first; 
   176       maxXBuf[idx]  = m.second; 
   179    assert(idx == numLabels);
   // Second pass: look up the remaining bounds by label. .at() throws
   // std::out_of_range if a label in maxX is missing from another map.
   182    for (
int i = 0; i < numLabels; i++) {
   183       int label  = labelBuf[i];
   184       maxYBuf[i] = maxY.at(label);
   185       minXBuf[i] = minX.at(label);
   186       minYBuf[i] = minY.at(label);
   // Copies numLabels (label, center index) pairs from allLabelsBuf and
   // centerIdxBuf into centerIdx[batchIdx], overwriting any existing entry for
   // a label. updateState calls this after the broadcasts of those two
   // buffers, presumably on non-root ranks (the rank test is not visible).
   191 int SegmentLayer::loadCenterIdxMap(
int batchIdx, 
int numLabels) {
   192    for (
int i = 0; i < numLabels; i++) {
   193       int label                  = allLabelsBuf[i];
   194       int idx                    = centerIdxBuf[i];
   195       centerIdx[batchIdx][label] = idx;
   // Ensures centerIdxBuf and allLabelsBuf can hold at least newSize ints.
   // Nothing is reallocated when newSize <= centerIdxBufSize, so the buffers
   // only ever grow.
   // NOTE(review): same unchecked `p = realloc(p, ...)` pattern as
   // checkLabelBufSize. On failure the old block leaks and the pointer becomes NULL.
   200 int SegmentLayer::checkIdxBufSize(
int newSize) {
   201    if (newSize <= centerIdxBufSize) {
   206    centerIdxBuf = (
int *)realloc(centerIdxBuf, newSize * 
sizeof(
int));
   207    allLabelsBuf = (
int *)realloc(allLabelsBuf, newSize * 
sizeof(
int));
   209    centerIdxBufSize = newSize;
   // Allocates the base-class data structures, then gives centerIdx one empty
   // std::map<int, int> (label -> center index) per batch element.
   // NOTE(review): in the visible lines, `status` from the base call is not
   // used and the function ends with an unconditional Response::SUCCESS.
   // Confirm that the hidden lines propagate a non-completed base status.
   213 Response::Status SegmentLayer::allocateDataStructures() {
   214    auto status = HyPerLayer::allocateDataStructures();
   219    int nbatch = getLayerLoc()->nbatch;
   227    for (
int b = 0; b < nbatch; b++) {
   228       centerIdx.push_back(std::map<int, int>());
   231    return Response::SUCCESS;
   // Overrides V allocation; the body is not visible in this excerpt.
   // initializeV() asserts getV() == NULL, which suggests this override leaves
   // V unallocated. Confirm.
   234 void SegmentLayer::allocateV() {
   // Nothing to initialize: this layer expects to have no V buffer. This is
   // checked only by assert, so it is compiled out under NDEBUG.
   239 void SegmentLayer::initializeV() { assert(getV() == NULL); }
   // Intentionally empty: no activity initialization is done here. Activity
   // is filled in updateState.
   241 void SegmentLayer::initializeActivity() {}
   // Computes this layer's activity and per-label centers for the current step.
   //
   // 1. With segmentMethod "none", the original layer's activity (all batches,
   //    extended buffers) is copied verbatim into this layer.
   // 2. Each batch element's restricted region is scanned. Every pixel value is
   //    rounded to an int label, and that label's bounding box is tracked in
   //    global restricted coordinates.
   // 3. Bounding boxes are gathered on rank 0 of icComm and merged. Each box is
   //    reduced to a center index (centerY * nxGlobal + centerX).
   // 4. The label/center table is broadcast, so every rank ends with the same
   //    centerIdx[bi] map.
   //
   // timef and dt are not used in the visible lines.
   243 Response::Status SegmentLayer::updateState(
double timef, 
double dt) {
   244    float *srcA  = originalLayer->getActivity();
   245    float *thisA = getActivity();
   252    if (strcmp(segmentMethod, 
"none") == 0) {
   // Relies on communicateInitInfo having synchronized margins and checked
   // nx/ny and nf == 1 for both layers, so the extended sizes agree.
   // NOTE(review): nbatch equality is not checked in the visible lines. If the
   // source has fewer batch elements, this reads past srcA.
   253       int numBatchExtended = getNumExtendedAllBatches();
   256       memcpy(thisA, srcA, numBatchExtended * 
sizeof(
float));
   263    assert(loc->nf == 1);
   // Start every batch element with an empty label -> center map.
   266    for (
int bi = 0; bi < loc->nbatch; bi++) {
   267       centerIdx[bi].clear();
   270    for (
int bi = 0; bi < loc->nbatch; bi++) {
   271       float *batchA = thisA + bi * getNumExtended();
   // NOTE(review): the visible lines do not show maxX/maxY/minX/minY being
   // cleared for each bi. Confirm they are reset in the lines not shown;
   // otherwise labels from earlier batch elements leak into later ones.
   // Scan only the restricted (non-halo) part of the extended buffer.
   280       for (
int yi = loc->halo.up; yi < loc->ny + loc->halo.up; yi++) {
   281          for (
int xi = loc->halo.lt; xi < loc->nx + loc->halo.lt; xi++) {
   // Extended row stride is nx + halo.lt + halo.rt (single feature).
   283             int niLocalExt = yi * (loc->nx + loc->halo.lt + loc->halo.rt) + xi;
   // Shift to global restricted coordinates, so boxes from different MPI
   // processes can be merged directly.
   285             int globalResYi = yi - loc->halo.up + loc->ky0;
   286             int globalResXi = xi - loc->halo.lt + loc->kx0;
   // Each activity value is treated as a label id (rounded to nearest int).
   291             int labelVal = round(batchA[niLocalExt]);
   // Known label: grow its bounding box.
   295             if (maxX.count(labelVal)) {
   299                if (globalResXi > maxX.at(labelVal)) {
   300                   maxX[labelVal] = globalResXi;
   302                if (globalResXi < minX.at(labelVal)) {
   303                   minX[labelVal] = globalResXi;
   305                if (globalResYi > maxY.at(labelVal)) {
   306                   maxY[labelVal] = globalResYi;
   308                if (globalResYi < minY.at(labelVal)) {
   309                   minY[labelVal] = globalResYi;
   // First pixel seen for this label: start a one-pixel box.
   314                maxX[labelVal] = globalResXi;
   315                minX[labelVal] = globalResXi;
   316                maxY[labelVal] = globalResYi;
   317                minY[labelVal] = globalResYi;
   // Gather the boxes on rank 0 of icComm (icComm's declaration is not visible).
   324       int numMpi           = icComm->commSize();
   325       int rank             = icComm->commRank();
   329       int numLabels = maxX.size();
   // Non-root branch (the rank test is not visible). Send the label count,
   // then the five label buffers, to rank 0, tagged with this rank. The
   // buffers are presumably filled by loadLabelBuf() on a hidden line.
   335          MPI_Send(&numLabels, 1, MPI_INT, 0, rank, icComm->communicator());
   337          MPI_Send(labelBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
   338          MPI_Send(maxXBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
   339          MPI_Send(maxYBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
   340          MPI_Send(minXBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
   341          MPI_Send(minYBuf, numLabels, MPI_INT, 0, rank, icComm->communicator());
   // Receive the merged label -> center table broadcast by rank 0 (size first).
   344          int numCenterIdx = 0;
   345          MPI_Bcast(&numCenterIdx, 1, MPI_INT, 0, icComm->communicator());
   346          checkIdxBufSize(numCenterIdx);
   348          MPI_Bcast(allLabelsBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
   349          MPI_Bcast(centerIdxBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
   352          loadCenterIdxMap(bi, numCenterIdx);
   // Root branch: merge every other rank's boxes into the local maps.
   357          for (
int recvRank = 1; recvRank < numMpi; recvRank++) {
   358             int numRecvLabels = 0;
   // NOTE(review): the MPI standard does not allow NULL as the status
   // argument. Pass MPI_STATUS_IGNORE here, and in the hidden MPI_Recv calls
   // below if they do the same; some MPI implementations crash on NULL.
   359             MPI_Recv(&numRecvLabels, 1, MPI_INT, recvRank, recvRank, icComm->communicator(), NULL);
   // The label buffers are reused as receive buffers for this sender's data.
   360             checkLabelBufSize(numRecvLabels);
   368                   icComm->communicator(),
   376                   icComm->communicator(),
   384                   icComm->communicator(),
   392                   icComm->communicator(),
   400                   icComm->communicator(),
   403             for (
int i = 0; i < numRecvLabels; i++) {
   404                int label = labelBuf[i];
   // Label already known on root: take the union of the two boxes.
   407                if (maxX.count(label)) {
   408                   if (maxXBuf[i] > maxX.at(label)) {
   409                      maxX[label] = maxXBuf[i];
   411                   if (maxYBuf[i] > maxY.at(label)) {
   412                      maxY[label] = maxYBuf[i];
   414                   if (minXBuf[i] < minX.at(label)) {
   415                      minX[label] = minXBuf[i];
   417                   if (minYBuf[i] < minY.at(label)) {
   418                      minY[label] = minYBuf[i];
   // New label: adopt the received box as-is.
   422                   maxX[label] = maxXBuf[i];
   423                   maxY[label] = maxYBuf[i];
   424                   minX[label] = minXBuf[i];
   425                   minY[label] = minYBuf[i];
   // Center = midpoint of the merged box. Since max >= min, the integer
   // division rounds down. The center is flattened to a global restricted
   // index with x fastest.
   432          for (
auto &m : maxX) {
   434             int centerX = minX.at(label) + (maxX.at(label) - minX.at(label)) / 2;
   435             int centerY = minY.at(label) + (maxY.at(label) - minY.at(label)) / 2;
   437             int centerIdxVal = centerY * (loc->nxGlobal) + centerX;
   439             centerIdx[bi][label] = centerIdxVal;
   // Flatten centerIdx[bi] into the two index buffers and broadcast them.
   // These broadcasts pair with the three MPI_Bcast calls in the non-root branch.
   443          int numCenterIdx = centerIdx[bi].size();
   444          checkIdxBufSize(numCenterIdx);
   447          for (
auto &ctr : centerIdx[bi]) {
   448             allLabelsBuf[idx] = ctr.first;
   449             centerIdxBuf[idx] = ctr.second;
   454          MPI_Bcast(&numCenterIdx, 1, MPI_INT, 0, icComm->communicator());
   455          MPI_Bcast(allLabelsBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
   456          MPI_Bcast(centerIdxBuf, numCenterIdx, MPI_INT, 0, icComm->communicator());
   461    return Response::SUCCESS;
   // Destructor. Frees the originalLayerName parameter string (free(NULL) is a no-op).
   // NOTE(review): the freeing of segmentMethod and of the realloc'd buffers
   // (labelBuf, maxXBuf, maxYBuf, minXBuf, minYBuf, centerIdxBuf, allLabelsBuf)
   // is not visible in this excerpt. Confirm it happens in the lines not
   // shown; otherwise they leak.
   464 SegmentLayer::~SegmentLayer() {
   465    free(originalLayerName);
 // Out-of-context signatures with no bodies shown. They appear to be collected
 // from related declarations: ioParamsFillGroup, a static completed(Status &)
 // helper, initialize, and getInitInfoCommunicatedFlag.
 virtual int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
static bool completed(Status &a)
int initialize(const char *name, HyPerCol *hc)
int ioParamsFillGroup(enum ParamsIOFlag ioFlag) override
bool getInitInfoCommunicatedFlag() const