// NOTE(review): the leading "1", "2", "3" and "12" tokens on the next line look like
// original source line numbers fused into the text by extraction (lines 4-11 are absent
// here); confirm against the real ImageLayer.cpp before trying to compile this.
1 #include "ImageLayer.hpp" 2 #include "../arch/mpi/mpi.h" 3 #include "structures/Image.hpp" 12 ImageLayer::ImageLayer(
const char *name, HyPerCol *hc) { initialize(name, hc); } // Constructor: body only forwards the layer name and the HyPerCol pointer to initialize(name, hc).
16 std::string txt =
".txt";
17 if (getInputPath().size() > txt.size()
18 && getInputPath().compare(getInputPath().size() - txt.size(), txt.size(), txt) == 0) {
19 mUsingFileList =
true;
23 InfoLog() <<
"File " << getInputPath() <<
" contains " << mFileList.size() <<
" frames\n";
24 return mFileList.size();
27 mUsingFileList =
false;
34 int blockBatchElement = localBatchElement + getLayerLoc()->nbatch * mpiBatchIndex;
35 int inputIndex = mBatchIndexer->getIndex(blockBatchElement);
36 return mFileList.at(inputIndex);
39 return getInputPath();
43 Response::Status ImageLayer::registerData(
Checkpointer *checkpointer) {
44 auto status = InputLayer::registerData(checkpointer);
48 mURLDownloadTemplate = checkpointer->getOutputPath() +
"/temp.XXXXXX";
49 return Response::SUCCESS;
52 void ImageLayer::populateFileList() {
53 if (getMPIBlock()->getRank() == 0) {
56 InfoLog() <<
"Reading list: " << getInputPath() <<
"\n";
57 std::ifstream infile(getInputPath(), std::ios_base::in);
59 infile.fail(),
"Unable to open \"%s\": %s\n", getInputPath().c_str(), strerror(errno));
60 while (getline(infile, line,
'\n')) {
61 std::string noWhiteSpace = line;
63 std::remove_if(noWhiteSpace.begin(), noWhiteSpace.end(), ::isspace),
65 if (!noWhiteSpace.empty()) {
66 mFileList.push_back(noWhiteSpace);
71 "%s inputPath file list \"%s\" is empty.\n",
73 getInputPath().c_str());
80 filename = mFileList.at(inputIndex);
83 filename = getInputPath();
87 if (mImage->getFeatures() != getLayerLoc()->nf) {
88 switch (getLayerLoc()->nf) {
90 mImage->convertToGray(
false);
93 mImage->convertToGray(
true);
96 mImage->convertToColor(
false);
99 mImage->convertToColor(
true);
102 Fatal() <<
"Failed to read " << filename <<
": Could not convert " 103 << mImage->getFeatures() <<
" channels to " << getLayerLoc()->nf << std::endl;
109 mImage->asVector(), mImage->getWidth(), mImage->getHeight(), getLayerLoc()->nf);
113 void ImageLayer::readImage(std::string filename) {
115 bool usingTempFile =
false;
118 if (filename.find(
"://") != std::string::npos) {
119 usingTempFile =
true;
120 std::string extension = filename.substr(filename.find_last_of(
"."));
121 std::string pathstring = mURLDownloadTemplate + extension;
123 strcpy(tempStr, pathstring.c_str());
124 int tempFileID = mkstemps(tempStr, extension.size());
125 pathstring = std::string(tempStr);
126 FatalIf(tempFileID < 0,
"Cannot create temp image file.\n");
127 std::string systemstring;
129 if (filename.find(
"s3://") != std::string::npos) {
130 systemstring = std::string(
"aws s3 cp \'") + filename + std::string(
"\' ") + pathstring;
133 systemstring = std::string(
"wget -O ") + pathstring + std::string(
" \'") + filename
137 filename = pathstring;
138 const int numAttempts = 5;
139 for (
int attemptNum = 0; attemptNum < numAttempts; attemptNum++) {
140 if (system(systemstring.c_str()) == 0) {
145 attemptNum == numAttempts - 1,
146 "download command \"%s\" failed: %s. Exiting\n",
147 systemstring.c_str(),
152 mImage = std::unique_ptr<Image>(
new Image(std::string(filename)));
155 usingTempFile &&
remove(filename.c_str()),
156 "remove(\"%s\") failed. Exiting.\n",
161 std::string description(
"");
162 if (mUsingFileList) {
163 description = mFileList.at(index);
virtual int countInputImages() override
static bool completed(Status &a)
virtual std::string describeInput(int index) override
virtual std::string const & getCurrentFilename(int localBatchElement, int mpiBatchIndex) const override
virtual Buffer< float > retrieveData(int inputIndex) override