PetaVision  Alpha
ImageLayer.cpp
#include "ImageLayer.hpp"
#include "../arch/mpi/mpi.h"
#include "structures/Image.hpp"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <unistd.h>
9 
10 namespace PV {
11 
12 ImageLayer::ImageLayer(const char *name, HyPerCol *hc) { initialize(name, hc); }
13 
15  // Check if the input path ends in ".txt" and enable the file list if so
16  std::string txt = ".txt";
17  if (getInputPath().size() > txt.size()
18  && getInputPath().compare(getInputPath().size() - txt.size(), txt.size(), txt) == 0) {
19  mUsingFileList = true;
20 
21  // Calculate file positions for beginning of each frame
22  populateFileList();
23  InfoLog() << "File " << getInputPath() << " contains " << mFileList.size() << " frames\n";
24  return mFileList.size();
25  }
26  else {
27  mUsingFileList = false;
28  return 1;
29  }
30 }
31 
32 std::string const &ImageLayer::getCurrentFilename(int localBatchElement, int mpiBatchIndex) const {
33  if (mUsingFileList) {
34  int blockBatchElement = localBatchElement + getLayerLoc()->nbatch * mpiBatchIndex;
35  int inputIndex = mBatchIndexer->getIndex(blockBatchElement);
36  return mFileList.at(inputIndex);
37  }
38  else {
39  return getInputPath();
40  }
41 }
42 
43 Response::Status ImageLayer::registerData(Checkpointer *checkpointer) {
44  auto status = InputLayer::registerData(checkpointer);
45  if (!Response::completed(status)) {
46  return status;
47  }
48  mURLDownloadTemplate = checkpointer->getOutputPath() + "/temp.XXXXXX";
49  return Response::SUCCESS;
50 }
51 
52 void ImageLayer::populateFileList() {
53  if (getMPIBlock()->getRank() == 0) {
54  std::string line;
55  mFileList.clear();
56  InfoLog() << "Reading list: " << getInputPath() << "\n";
57  std::ifstream infile(getInputPath(), std::ios_base::in);
58  FatalIf(
59  infile.fail(), "Unable to open \"%s\": %s\n", getInputPath().c_str(), strerror(errno));
60  while (getline(infile, line, '\n')) {
61  std::string noWhiteSpace = line;
62  noWhiteSpace.erase(
63  std::remove_if(noWhiteSpace.begin(), noWhiteSpace.end(), ::isspace),
64  noWhiteSpace.end());
65  if (!noWhiteSpace.empty()) {
66  mFileList.push_back(noWhiteSpace);
67  }
68  }
69  FatalIf(
70  mFileList.empty(),
71  "%s inputPath file list \"%s\" is empty.\n",
72  getDescription_c(),
73  getInputPath().c_str());
74  }
75 }
76 
78  std::string filename;
79  if (mUsingFileList) {
80  filename = mFileList.at(inputIndex);
81  }
82  else {
83  filename = getInputPath();
84  }
85  readImage(filename);
86 
87  if (mImage->getFeatures() != getLayerLoc()->nf) {
88  switch (getLayerLoc()->nf) {
89  case 1: // Grayscale
90  mImage->convertToGray(false);
91  break;
92  case 2: // Grayscale + Alpha
93  mImage->convertToGray(true);
94  break;
95  case 3: // RGB
96  mImage->convertToColor(false);
97  break;
98  case 4: // RGBA
99  mImage->convertToColor(true);
100  break;
101  default:
102  Fatal() << "Failed to read " << filename << ": Could not convert "
103  << mImage->getFeatures() << " channels to " << getLayerLoc()->nf << std::endl;
104  break;
105  }
106  }
107 
108  Buffer<float> result(
109  mImage->asVector(), mImage->getWidth(), mImage->getHeight(), getLayerLoc()->nf);
110  return result;
111 }
112 
113 void ImageLayer::readImage(std::string filename) {
114  const PVLayerLoc *loc = getLayerLoc();
115  bool usingTempFile = false;
116 
117  // Attempt to download our input file if we've been passed a URL or AWS path
118  if (filename.find("://") != std::string::npos) {
119  usingTempFile = true;
120  std::string extension = filename.substr(filename.find_last_of("."));
121  std::string pathstring = mURLDownloadTemplate + extension;
122  char tempStr[256];
123  strcpy(tempStr, pathstring.c_str());
124  int tempFileID = mkstemps(tempStr, extension.size());
125  pathstring = std::string(tempStr);
126  FatalIf(tempFileID < 0, "Cannot create temp image file.\n");
127  std::string systemstring;
128 
129  if (filename.find("s3://") != std::string::npos) {
130  systemstring = std::string("aws s3 cp \'") + filename + std::string("\' ") + pathstring;
131  }
132  else { // URLs other than s3://
133  systemstring = std::string("wget -O ") + pathstring + std::string(" \'") + filename
134  + std::string("\'");
135  }
136 
137  filename = pathstring;
138  const int numAttempts = 5;
139  for (int attemptNum = 0; attemptNum < numAttempts; attemptNum++) {
140  if (system(systemstring.c_str()) == 0) {
141  break;
142  }
143  sleep(1);
144  FatalIf(
145  attemptNum == numAttempts - 1,
146  "download command \"%s\" failed: %s. Exiting\n",
147  systemstring.c_str(),
148  strerror(errno));
149  }
150  }
151 
152  mImage = std::unique_ptr<Image>(new Image(std::string(filename)));
153 
154  FatalIf(
155  usingTempFile && remove(filename.c_str()),
156  "remove(\"%s\") failed. Exiting.\n",
157  filename.c_str());
158 }
159 
160 std::string ImageLayer::describeInput(int index) {
161  std::string description("");
162  if (mUsingFileList) {
163  description = mFileList.at(index);
164  }
165  return description;
166 }
167 }
virtual int countInputImages() override
Definition: ImageLayer.cpp:14
static bool completed(Status &a)
Definition: Response.hpp:49
virtual std::string describeInput(int index) override
Definition: ImageLayer.cpp:160
virtual std::string const & getCurrentFilename(int localBatchElement, int mpiBatchIndex) const override
Definition: ImageLayer.cpp:32
virtual Buffer< float > retrieveData(int inputIndex) override
Definition: ImageLayer.cpp:77