PetaVision  Alpha
WeightsFileIO.cpp
1 #include "WeightsFileIO.hpp"
2 #include <cstdint>
3 
4 namespace PV {
5 
6 WeightsFileIO::WeightsFileIO(FileStream *fileStream, MPIBlock const *mpiBlock, Weights *weights)
7  : mFileStream(fileStream), mMPIBlock(mpiBlock), mWeights(weights) {
8  if (mMPIBlock == nullptr) {
9  throw std::invalid_argument("WeightsFileIO instantiated with a null MPIBlock");
10  }
11  if (mMPIBlock->getRank() == mRootProcess and mFileStream == nullptr) {
12  throw std::invalid_argument(
13  "WeightsFileIO instantiated with a null file stream on the null process");
14  }
15  if (mWeights == nullptr) {
16  throw std::invalid_argument("WeightsFileIO instantiated with a null Weights object");
17  }
18 }
19 
20 // function members for reading
21 double WeightsFileIO::readWeights(int frameNumber) {
22  if (mFileStream != nullptr and !mFileStream->readable()) {
23  throw std::invalid_argument(
24  "WeightsFileIO::readWeights called with a nonreadable file stream");
25  }
26  BufferUtils::WeightHeader header = readHeader(frameNumber);
27  checkHeader(header);
28 
29  double timestamp;
30  if (mWeights->getSharedFlag()) {
31  timestamp = readSharedWeights(frameNumber, header);
32  }
33  else {
34  timestamp = readNonsharedWeights(frameNumber, header);
35  }
36  mWeights->setTimestamp(timestamp);
37  return timestamp;
38 }
39 
40 BufferUtils::WeightHeader WeightsFileIO::readHeader(int frameNumber) {
41  BufferUtils::WeightHeader header;
42  int const rank = mMPIBlock->getRank();
43  if (rank == mRootProcess) {
44  moveToFrame(header, *mFileStream, frameNumber);
45  }
46 
47  MPI_Bcast(&header, (int)sizeof(header), MPI_BYTE, mRootProcess, mMPIBlock->getComm());
48  return header;
49 }
50 
51 void WeightsFileIO::checkHeader(BufferUtils::WeightHeader const &header) {
52  if (mWeights->getSharedFlag()) {
53  FatalIf(
54  header.baseHeader.fileType != PVP_KERNEL_FILE_TYPE,
55  "Connection \"%s\" has sharedWeights true, ",
56  "but \"%s\" is not a shared-weights file\n",
57  mWeights->getName().c_str(),
58  mFileStream->getFileName().c_str());
59  FatalIf(
60  header.numPatches != mWeights->getNumDataPatches(),
61  "Shared-weights connection \"%s\" has a unit cell (%d-by-%d-by-%d), "
62  "but \"%s\" has %d patches.\n",
63  mWeights->getName().c_str(),
64  mWeights->getNumDataPatchesX(),
65  mWeights->getNumDataPatchesY(),
66  mWeights->getNumDataPatchesF(),
67  mFileStream->getFileName().c_str(),
68  header.numPatches);
69  }
70  else {
71  // TODO: It should be allowed to read a kernel file into a non-shared-weights atlas
72  FatalIf(
73  header.baseHeader.fileType != PVP_WGT_FILE_TYPE,
74  "Connection \"%s\" has sharedWeights false.\n",
75  "but \"%s\" is not a non-shared-weights file. ",
76  mWeights->getName().c_str(),
77  mFileStream->getFileName().c_str());
78  }
79  FatalIf(
80  header.baseHeader.nBands < mWeights->getNumArbors(),
81  "Connection \"%s\" has %d arbors, but file \"%s\" has only %d arbors.\n",
82  mWeights->getName().c_str(),
83  mWeights->getNumArbors(),
84  mFileStream->getFileName().c_str(),
85  header.baseHeader.nBands);
86 
87  FatalIf(
88  header.nxp != mWeights->getPatchSizeX(),
89  "Connection \"%s\" has nxp=%d, but file \"%s\" has nxp=%d.\n",
90  mWeights->getName().c_str(),
91  mWeights->getPatchSizeX(),
92  mFileStream->getFileName().c_str(),
93  header.nxp);
94  FatalIf(
95  header.nyp != mWeights->getPatchSizeY(),
96  "Connection \"%s\" has nyp=%d, but file \"%s\" has nyp=%d.\n",
97  mWeights->getName().c_str(),
98  mWeights->getPatchSizeY(),
99  mFileStream->getFileName().c_str(),
100  header.nyp);
101  FatalIf(
102  header.nfp != mWeights->getPatchSizeF(),
103  "Connection \"%s\" has nfp=%d, but file \"%s\" has nfp=%d.\n",
104  mWeights->getName().c_str(),
105  mWeights->getPatchSizeF(),
106  mFileStream->getFileName().c_str(),
107  header.nfp);
108 }
109 
110 bool WeightsFileIO::isCompressedHeader(BufferUtils::WeightHeader const &header) {
111  bool isCompressed;
112  switch (header.baseHeader.dataType) {
113  case BufferUtils::BYTE:
114  FatalIf(
115  header.baseHeader.dataSize != (int)sizeof(unsigned char),
116  "File \"%s\" has dataSize=%d, inconsistent with dataType BYTE (%d)\n",
117  mFileStream->getFileName().c_str(),
118  header.baseHeader.dataSize,
119  header.baseHeader.dataType);
120  isCompressed = true;
121  break;
122  case BufferUtils::FLOAT:
123  FatalIf(
124  header.baseHeader.dataSize != (int)sizeof(float),
125  "File \"%s\" has dataSize=%d, inconsistent with dataType FLOAT (%d)\n",
126  mFileStream->getFileName().c_str(),
127  header.baseHeader.dataSize,
128  header.baseHeader.dataType);
129  isCompressed = false;
130  break;
131  case BufferUtils::INT:
132  Fatal().printf(
133  "File \"%s\" has dataType INT. Only FLOAT and BYTE are supported.\n",
134  mFileStream->getFileName().c_str());
135  break;
136  default:
137  Fatal().printf(
138  "File \"%s\" has unrecognized datatype.\n", mFileStream->getFileName().c_str());
139  break;
140  }
141  return isCompressed;
142 }
143 
144 double WeightsFileIO::readSharedWeights(int frameNumber, BufferUtils::WeightHeader const &header) {
145  bool compressed = isCompressedHeader(header);
146  double timestamp = header.baseHeader.timestamp;
147  long arborSizeInPvpFile = calcArborSizeLocal(compressed);
148  long arborSizeInPvpLocal = arborSizeInPvpFile;
149  std::vector<unsigned char> readBuffer(arborSizeInPvpLocal);
150 
151  int const numArbors = mWeights->getNumArbors();
152  for (int arbor = 0; arbor < numArbors; arbor++) {
153  if (mMPIBlock->getRank() == mRootProcess) {
154  mFileStream->read(readBuffer.data(), arborSizeInPvpFile);
155  }
156  MPI_Bcast(
157  readBuffer.data(), arborSizeInPvpFile, MPI_BYTE, mRootProcess, mMPIBlock->getComm());
158  loadWeightsFromBuffer(readBuffer, arbor, header.minVal, header.maxVal, compressed);
159  }
160  return timestamp;
161 }
162 
// Reads the non-shared weights of one frame into mWeights, one arbor at a time.
//
// The root process does all file access. For every arbor and every rank of the MPI block it
// seeks to the rows of patches belonging to that rank, reads them into readBuffer, and then
// either loads the buffer into its own Weights object (when that rank is the root itself) or
// sends the whole buffer to that rank with tag tagbase + arbor. Every other process receives
// one buffer per arbor from the root and loads it.
//
// frameNumber is not referenced in the body; reading starts at the stream's current input
// position, which readWeights() leaves just past the frame header by calling readHeader().
// Returns header.baseHeader.timestamp.
163 double
164 WeightsFileIO::readNonsharedWeights(int frameNumber, BufferUtils::WeightHeader const &header) {
165  bool compressed = isCompressedHeader(header);
166  long arborSizeInPvpFile = calcArborSizeFile(compressed);
167  long arborSizeInPvpLocal = calcArborSizeLocal(compressed);
// The buffer holds one arbor in the layout of the local patch atlas, not of the file.
168  std::vector<unsigned char> readBuffer(arborSizeInPvpLocal);
169 
170  int const nxp = mWeights->getPatchSizeX();
171  int const nyp = mWeights->getPatchSizeY();
172  int const nfp = mWeights->getPatchSizeF();
173  long patchSizePvpFormat = (long)BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
174 
175  int const numArbors = mWeights->getNumArbors();
176  if (mMPIBlock->getRank() == mRootProcess) {
177  long const frameStartFile = mFileStream->getInPos();
178  for (int arbor = 0; arbor < numArbors; arbor++) {
179  long const arborStartInFile = frameStartFile + (long)(arbor * arborSizeInPvpFile);
180  mFileStream->setInPos(arborStartInFile, true /*from beginning of file*/);
181 
182  // For each process, need to determine patches to load from the PVP file.
183  // The patch atlas may have a bigger border than the PVP file.
// NOTE(review): the patch box is computed from this (root) process's own geometry and is
// reused for every destRank -- presumably all ranks have the same local layout; confirm.
184  int startPatchX, endPatchX, startPatchY, endPatchY;
185  calcPatchBox(startPatchX, endPatchX, startPatchY, endPatchY);
// Number of patches in one row of the patch box (x extent times features).
186  int lineCount = (endPatchX - startPatchX) * mWeights->getNumDataPatchesF();
187 
188  PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
189  PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
190 
// Dimensions, in patches, of the extended region stored in the file; this is the same
// formula that calcArborSizeFile() uses for the non-shared case.
191  int marginX = calcNeededBorder(preLoc.nx, postLoc.nx, nxp);
192  int nxExtended = preLoc.nx * mMPIBlock->getGlobalNumColumns() + marginX + marginX;
193 
194  int marginY = calcNeededBorder(preLoc.ny, postLoc.ny, nyp);
195  int nyExtended = preLoc.ny * mMPIBlock->getGlobalNumRows() + marginY + marginY;
196 
197  for (int destRank = 0; destRank < mMPIBlock->getSize(); destRank++) {
198  int rowIndex, columnIndex, batchElemIndex;
199  mMPIBlock->calcRowColBatchFromRank(destRank, rowIndex, columnIndex, batchElemIndex);
200 
// Read the patch box one row at a time: each row is contiguous in the file, and it is
// placed in readBuffer at the position of the corresponding row of the local atlas.
201  for (int y = 0; y < endPatchY - startPatchY; y++) {
202  int const startFileX = columnIndex * preLoc.nx;
203  int const startFileY = y + rowIndex * preLoc.ny;
204  int const startFile = kIndex(
205  startFileX, startFileY, 0, nxExtended, nyExtended, header.baseHeader.nf);
206  long lineStartInFile = arborStartInFile + (long)startFile * patchSizePvpFormat;
207  mFileStream->setInPos(lineStartInFile, true /*from beginning of file*/);
208 
209  int const startPatchLocal = kIndex(
210  startPatchX,
211  y + startPatchY,
212  0,
213  mWeights->getNumDataPatchesX(),
214  mWeights->getNumDataPatchesY(),
215  mWeights->getNumDataPatchesF());
216 
217  unsigned char *lineLocInBuffer =
218  &readBuffer[(long)startPatchLocal * patchSizePvpFormat];
219  std::size_t bufferSize = (std::size_t)lineCount * (std::size_t)patchSizePvpFormat;
220  mFileStream->read(lineLocInBuffer, bufferSize);
221  }
222  if (destRank == mRootProcess) {
223  loadWeightsFromBuffer(readBuffer, arbor, header.minVal, header.maxVal, compressed);
224  }
225  else {
// The entire local-sized buffer is sent, not only the rows read above.
226  int tag = tagbase + arbor;
227  MPI_Comm comm = mMPIBlock->getComm();
228  MPI_Send(readBuffer.data(), (int)readBuffer.size(), MPI_BYTE, destRank, tag, comm);
229  }
230  }
231  }
232  }
233  else {
// Non-root processes: one receive per arbor, matched to the root's sends by tag.
234  for (int arbor = 0; arbor < numArbors; arbor++) {
235  int tag = tagbase + arbor;
236  MPI_Comm comm = mMPIBlock->getComm();
237  MPI_Recv(
238  readBuffer.data(),
239  (int)readBuffer.size(),
240  MPI_BYTE,
241  mRootProcess,
242  tag,
243  comm,
244  MPI_STATUS_IGNORE);
245  loadWeightsFromBuffer(readBuffer, arbor, header.minVal, header.maxVal, compressed);
246  }
247  }
248  return header.baseHeader.timestamp;
249 }
250 
251 // function members for writing
252 void WeightsFileIO::writeWeights(double timestamp, bool compress) {
253  if (mFileStream != nullptr and !mFileStream->writeable()) {
254  throw std::invalid_argument(
255  "WeightsFileIO::writeWeights called with a nonwriteable file stream");
256  }
257  if (mWeights->getSharedFlag()) {
258  writeSharedWeights(timestamp, compress);
259  }
260  else {
261  writeNonsharedWeights(timestamp, compress);
262  }
263 }
264 
265 void WeightsFileIO::writeSharedWeights(double timestamp, bool compress) {
266  if (mMPIBlock->getRank() != mRootProcess) {
267  return;
268  }
269  float minWeight = mWeights->calcMinWeight();
270  float maxWeight = mWeights->calcMaxWeight();
271  BufferUtils::WeightHeader header = BufferUtils::buildSharedWeightHeader(
272  mWeights->getPatchSizeX(),
273  mWeights->getPatchSizeY(),
274  mWeights->getPatchSizeF(),
275  mWeights->getNumArbors(),
276  mWeights->getNumDataPatchesX(),
277  mWeights->getNumDataPatchesY(),
278  mWeights->getNumDataPatchesF(),
279  timestamp,
280  compress,
281  minWeight,
282  maxWeight);
283 
284  mFileStream->write(&header, sizeof(header));
285 
286  long arborSizeInPvpFile = calcArborSizeLocal(compress);
287  long arborSizeInPvpLocal = arborSizeInPvpFile;
288  std::vector<unsigned char> writeBuffer(arborSizeInPvpLocal);
289 
290  int const numArbors = mWeights->getNumArbors();
291  for (int arbor = 0; arbor < numArbors; arbor++) {
292  storeSharedPatches(writeBuffer, arbor, minWeight, maxWeight, compress);
293  mFileStream->write(writeBuffer.data(), arborSizeInPvpFile);
294  }
295 }
296 
// Writes one frame of non-shared weights with the given timestamp; compress selects one byte
// per weight instead of float.
//
// Every process takes part: the minimum and maximum weight over the MPI block are found with
// a collective reduction. The root process then writes the frame header and, for each arbor,
// obtains the packed patches of every rank (its own from storeNonsharedPatches, the others by
// MPI_Recv with tag tagbase + arbor) and writes the patches inside the patch box to their
// positions in the file with writePatch(). Every other process packs each arbor with
// storeNonsharedPatches and sends it to the root.
297 void WeightsFileIO::writeNonsharedWeights(double timestamp, bool compress) {
// The maximum is negated so that a single MPI_MIN reduction yields both extrema.
298  float extrema[2];
299  extrema[0] = mWeights->calcMinWeight();
300  extrema[1] = -mWeights->calcMaxWeight();
301  MPI_Allreduce(MPI_IN_PLACE, extrema, 2, MPI_FLOAT, MPI_MIN, mMPIBlock->getComm());
302  extrema[1] = -extrema[1];
303 
304  long arborSizeInPvpFile = calcArborSizeFile(compress);
305  long arborSizeInPvpLocal = calcArborSizeLocal(compress);
306  std::vector<unsigned char> writeBuffer(arborSizeInPvpLocal);
307 
308  int const numArbors = mWeights->getNumArbors();
309  if (mMPIBlock->getRank() == mRootProcess) {
310 
311  BufferUtils::WeightHeader header = BufferUtils::buildNonsharedWeightHeader(
312  mWeights->getPatchSizeX(),
313  mWeights->getPatchSizeY(),
314  mWeights->getPatchSizeF(),
315  mWeights->getNumArbors(),
316  true /*extended*/,
317  timestamp,
318  &mWeights->getGeometry()->getPreLoc(),
319  &mWeights->getGeometry()->getPostLoc(),
320  mMPIBlock->getNumColumns(),
321  mMPIBlock->getNumRows(),
322  extrema[0] /*min weight*/,
323  extrema[1] /*max weight*/,
324  compress);
325  mFileStream->write(&header, sizeof(header));
326 
// Position of the first data byte of this frame, immediately after the header.
327  long const frameStartFile = mFileStream->getOutPos();
328  for (int arbor = 0; arbor < numArbors; arbor++) {
329  long const arborStartFile = frameStartFile + (long)(arbor * arborSizeInPvpFile);
330  mFileStream->setOutPos(arborStartFile, true /*from beginning of file*/);
331 
332  // For each process, need to determine patches to write to the PVP file.
333  // The patch atlas may have a bigger border than the PVP file has.
334  int startPatchX, endPatchX, startPatchY, endPatchY;
335  calcPatchBox(startPatchX, endPatchX, startPatchY, endPatchY);
// The x and feature indices are merged into one index k, so that a row of the patch box
// is the range [startPatchK, endPatchK) out of numDataPatchesK patches per row.
336  int const numDataPatchesF = mWeights->getNumDataPatchesF();
337  int startPatchK = startPatchX * numDataPatchesF;
338  int endPatchK = endPatchX * numDataPatchesF;
339  int const numDataPatchesK = mWeights->getNumDataPatchesX() * numDataPatchesF;
340 
341  int const nxp = mWeights->getPatchSizeX();
342  int const nyp = mWeights->getPatchSizeY();
343  int const nfp = mWeights->getPatchSizeF();
344  auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compress);
345 
346  for (int sourceRank = 0; sourceRank < mMPIBlock->getSize(); sourceRank++) {
347  int rowIndex, columnIndex, batchElemIndex;
348  mMPIBlock->calcRowColBatchFromRank(sourceRank, rowIndex, columnIndex, batchElemIndex);
349 
// Fill writeBuffer with the packed patches of sourceRank.
350  if (sourceRank == mRootProcess) {
351  storeNonsharedPatches(writeBuffer, arbor, extrema[0], extrema[1], compress);
352  }
353  else {
354  int tag = tagbase + arbor;
355  MPI_Comm comm = mMPIBlock->getComm();
356  MPI_Recv(
357  writeBuffer.data(),
358  (int)writeBuffer.size(),
359  MPI_BYTE,
360  sourceRank,
361  tag,
362  comm,
363  MPI_STATUS_IGNORE);
364  }
365 
// Write the patch box row by row: seek to where the row starts in the file, then write
// each patch of the row with writePatch().
366  for (int y = 0; y < endPatchY - startPatchY; y++) {
367  PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
368  int const startFileX = columnIndex * preLoc.nx;
369  int const startFileY = y + rowIndex * preLoc.ny;
370  int const startFile = kIndex(
371  startFileX,
372  startFileY,
373  0,
374  header.baseHeader.nxExtended,
375  header.baseHeader.nyExtended,
376  header.baseHeader.nf);
377  long lineStartFile = arborStartFile + (long)startFile * (long)patchSizePvpFormat;
378  mFileStream->setOutPos(lineStartFile, true /*from beginning of file*/);
379 
// NOTE(review): startPatchLocal is computed here but is not used anywhere below.
380  int const startPatchLocal = kIndex(
381  startPatchX,
382  y + startPatchY,
383  0,
384  mWeights->getNumDataPatchesX(),
385  mWeights->getNumDataPatchesY(),
386  mWeights->getNumDataPatchesF());
387 
388  for (int k = startPatchK; k < endPatchK; k++) {
389  int patchIndexLocal = kIndex(
390  k, y + startPatchY, 0, numDataPatchesK, mWeights->getNumDataPatchesY(), 1);
391  unsigned char *patchLocInBuffer =
392  &writeBuffer[patchIndexLocal * patchSizePvpFormat];
393  writePatch(patchLocInBuffer, compress);
394  }
395  }
396  }
397  }
398  // If file length is shorter than it should be, the last patch is shrunken at the end.
399  // In this case, we need to pad out the file length so that file reading does not hit
400  // end-of-file too early.
401  // If file length is longer than required by this frame, we don't need to do anything. This
402  // situation can arise, for example, for the outputPath file from a connection if we restart
403  // from a checkpoint when several frames were written after that checkpoint.
404  long const frameEndFile = frameStartFile + (long)(numArbors * arborSizeInPvpFile);
// NOTE(review): this call relies on a setOutPos overload that accepts std::ios_base::seekdir.
// If only the (long, bool) overload exists, std::ios_base::end would convert to bool and
// the call would not seek to the end of the file -- confirm against FileStream.
405  mFileStream->setOutPos(0L, std::ios_base::end);
406  long const endOfFile = mFileStream->getOutPos();
407  if (endOfFile < frameEndFile) {
// Writing a single zero byte at the last position of the frame extends the file.
408  mFileStream->setOutPos(frameEndFile - 1L, true /*from beginning*/);
409  mFileStream->write("\0", 1L);
410  }
// Leave the output position at the end of this frame.
411  mFileStream->setOutPos(frameEndFile, true /*from beginning*/);
412  }
413  else {
// Non-root processes: pack each arbor and send it to the root, one message per arbor.
414  for (int arbor = 0; arbor < numArbors; arbor++) {
415  storeNonsharedPatches(writeBuffer, arbor, extrema[0], extrema[1], compress);
416  int tag = tagbase + arbor;
417  MPI_Comm comm = mMPIBlock->getComm();
418  MPI_Send(writeBuffer.data(), (int)writeBuffer.size(), MPI_BYTE, mRootProcess, tag, comm);
419  }
420  }
421 }
422 
// Writes one patch to the file, starting at the stream's current output position.
//
// patchBuffer points to a patch in PVP format: a header (nx, ny, offset) followed by the
// patch data. compressed selects one byte per weight instead of float.
//
// The header written to the file always describes the full patch (nxp, nyp, offset 0). The
// header in patchBuffer describes the active region of the patch, and only the bytes of that
// region are written, at the matching offset inside the patch's data area; the remainder of
// the data area is skipped over with seeks and is not written by this call. On return the
// output position is at the end of the patch's data area.
423 void WeightsFileIO::writePatch(unsigned char const *patchBuffer, bool compressed) {
424  int const nxp = mWeights->getPatchSizeX();
425  int const nyp = mWeights->getPatchSizeY();
426  int const nfp = mWeights->getPatchSizeF();
427 
428  Patch patch;
429 
430  // In the file, patch header is always unshrunken. Otherwise, we would have to
431  // handle patches in overlap regions by reading in, forming the union of active regions,
432  // and then writing back. The appearance of the patch headers in the pvp file is a legacy
433  // from olden times when each process wrote its own pvp file.
434  patch.nx = (std::uint16_t)nxp;
435  patch.ny = (std::uint16_t)nyp;
436  patch.offset = (std::uint32_t)0;
437  mFileStream->write(&patch.nx, sizeof(patch.nx));
438  mFileStream->write(&patch.ny, sizeof(patch.ny));
439  mFileStream->write(&patch.offset, sizeof(patch.offset));
440 
441  // Now load the patch header
// From here on, patch holds the active region stored in patchBuffer, not the values
// that were just written to the file.
442  memcpy(&patch.nx, patchBuffer, sizeof(patch.nx));
443  memcpy(&patch.ny, &patchBuffer[sizeof(patch.nx)], sizeof(patch.ny));
444  memcpy(&patch.offset, &patchBuffer[sizeof(patch.nx) + sizeof(patch.ny)], sizeof(patch.offset));
445 
446  std::size_t patchHeaderSize = sizeof(patch.nx) + sizeof(patch.ny) + sizeof(patch.offset);
447  std::size_t dataSize = compressed ? sizeof(unsigned char) : sizeof(float);
448  std::size_t patchDataStartOffset = patchHeaderSize + (std::size_t)patch.offset * dataSize;
449  unsigned char const *patchDataStart = &patchBuffer[patchDataStartOffset];
// The output position is now just past the header, i.e. at the start of the data area.
450  long patchStartInFile = mFileStream->getOutPos();
451  long patchEndInFile = patchStartInFile + (long)(nxp * nyp * nfp * (int)dataSize);
// Skip forward to where the active region begins inside the data area.
452  mFileStream->setOutPos((long)patch.offset * (long)dataSize, false /*from current position*/);
453  if ((int)patch.nx == nxp) {
454  // active region is contiguous in memory; write all lines at once
455  long dataLength = (long)patch.ny * (long)(nxp * nfp) * (long)dataSize;
456  mFileStream->write(patchDataStart, dataLength);
457  }
458  else {
459  // active region is not contiguous. Write each line, then skip to the start of the next line
460  std::size_t stride = (std::size_t)(nfp * nxp) * dataSize;
461  std::size_t lineLength = (std::size_t)nfp * (std::size_t)patch.nx * dataSize;
462  std::size_t skipLength = stride - lineLength;
463  for (std::uint16_t y = (std::uint16_t)0; y < patch.ny - (std::uint16_t)1; y++) {
464  unsigned char const *lineStartInBuffer = &patchDataStart[y * stride];
465  mFileStream->write(lineStartInBuffer, lineLength);
466  mFileStream->setOutPos(skipLength, false /*from current position*/);
467  }
// The last line is written separately, so that no skip follows it.
468  if (patch.ny > (std::uint16_t)0) {
469  std::size_t lastLineOffset = (std::size_t)(patch.ny - 1) * stride;
470  unsigned char const *lineStartInBuffer = &patchDataStart[lastLineOffset];
471  mFileStream->write(lineStartInBuffer, lineLength);
472  }
473  }
// Regardless of how much was written, continue at the end of this patch's data area.
474  mFileStream->setOutPos(patchEndInFile, true /*from start of file*/);
475 }
476 
477 // utility function members
478 
479 void WeightsFileIO::moveToFrame(
481  FileStream &fileStream,
482  int frameNumber) {
483  fileStream.setInPos(0L, true /*from beginning*/);
484  for (int f = 0; f < frameNumber; f++) {
485  fileStream.read(&header, sizeof(header));
486  long recordSize = (long)(header.baseHeader.recordSize * header.baseHeader.numRecords);
487  fileStream.setInPos(recordSize, false /*relative to current point*/);
488  }
489  fileStream.read(&header, sizeof(header));
490 }
491 
492 long WeightsFileIO::calcArborSizeFile(bool compressed) {
493  int const nxp = mWeights->getPatchSizeX();
494  int const nyp = mWeights->getPatchSizeY();
495  int const nfp = mWeights->getPatchSizeF();
496  int const patchSize = (int)BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
497 
498  int numPatches;
499  if (mWeights->getSharedFlag()) {
500  numPatches = mWeights->getNumDataPatches();
501  }
502  else {
503  PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
504  PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
505 
506  int marginX = calcNeededBorder(preLoc.nx, postLoc.nx, mWeights->getPatchSizeX());
507  int numPatchesX = preLoc.nx * mMPIBlock->getGlobalNumColumns() + marginX + marginX;
508 
509  int marginY = calcNeededBorder(preLoc.ny, postLoc.ny, mWeights->getPatchSizeY());
510  int numPatchesY = preLoc.ny * mMPIBlock->getGlobalNumRows() + marginY + marginY;
511 
512  numPatches = numPatchesX * numPatchesY * preLoc.nf;
513  }
514 
515  int const arborSize = numPatches * patchSize;
516  return arborSize;
517 }
518 
519 long WeightsFileIO::calcArborSizeLocal(bool compressed) {
520  int const nxp = mWeights->getPatchSizeX();
521  int const nyp = mWeights->getPatchSizeY();
522  int const nfp = mWeights->getPatchSizeF();
523  int const patchSize = (int)BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
524 
525  int numPatches = mWeights->getNumDataPatches();
526 
527  int const arborSize = numPatches * patchSize;
528 
529  return arborSize;
530 }
531 
532 void WeightsFileIO::calcPatchBox(
533  int &startPatchX,
534  int &endPatchX,
535  int &startPatchY,
536  int &endPatchY) {
537  PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
538  PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
539  PVHalo const &preHalo = preLoc.halo;
540 
541  int nxp = mWeights->getPatchSizeX();
542  calcPatchRange(preLoc.nx, postLoc.nx, preHalo.lt, preHalo.rt, nxp, startPatchX, endPatchX);
543 
544  int nyp = mWeights->getPatchSizeY();
545  calcPatchRange(preLoc.ny, postLoc.ny, preHalo.up, preHalo.dn, nyp, startPatchY, endPatchY);
546 }
547 
548 void WeightsFileIO::calcPatchRange(
549  int nPre,
550  int nPost,
551  int preStartBorder,
552  int preEndBorder,
553  int patchSize,
554  int &startPatch,
555  int &endPatch) {
556  int const neededBorder = calcNeededBorder(nPre, nPost, patchSize);
557 
558  startPatch = (preStartBorder >= neededBorder) ? preStartBorder - neededBorder : 0;
559  endPatch = preStartBorder + nPre;
560  endPatch += (preEndBorder >= neededBorder) ? neededBorder : preEndBorder;
561 }
562 
563 int WeightsFileIO::calcNeededBorder(int nPre, int nPost, int patchSize) {
564  int neededBorder;
565  if (nPre > nPost) {
566  pvAssert(nPre % nPost == 0);
567  int stride = nPre / nPost;
568  pvAssert(stride % 2 == 0);
569  int halfstride = stride / 2;
570  neededBorder = (patchSize - 1) * halfstride;
571  }
572  else if (nPre < nPost) {
573  pvAssert(nPost % nPre == 0);
574  int tstride = nPost / nPre;
575  pvAssert(patchSize % tstride == 0);
576  neededBorder = patchSize / (2 * tstride); // integer division
577  }
578  else {
579  pvAssert(nPre == nPost);
580  pvAssert(patchSize % 2 == 1);
581  neededBorder = (patchSize - 1) / 2;
582  }
583  return neededBorder;
584 }
585 
586 void WeightsFileIO::loadWeightsFromBuffer(
587  std::vector<unsigned char> const &dataFromFile,
588  int arbor,
589  float minValue,
590  float maxValue,
591  bool compressed) {
592  int const nxp = mWeights->getPatchSizeX();
593  int const nyp = mWeights->getPatchSizeY();
594  int const nfp = mWeights->getPatchSizeF();
595  int const numPatches = mWeights->getNumDataPatches();
596 
597  auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
598  std::size_t const patchHeaderSize = sizeof(unsigned int) + 2UL * sizeof(unsigned short);
599  if (compressed) {
600  for (int k = 0; k < numPatches; k++) {
601  std::size_t const offsetInFile = patchSizePvpFormat * (std::size_t)k;
602  unsigned char const *patchFromFile = &dataFromFile[offsetInFile + patchHeaderSize];
603  float *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
604  decompressPatch(patchFromFile, weightsInPatch, nxp * nyp * nfp, minValue, maxValue);
605  }
606  }
607  else {
608  for (int k = 0; k < numPatches; k++) {
609  std::size_t const offsetInFile = patchSizePvpFormat * (std::size_t)k;
610  unsigned char const *patchFromFile = &dataFromFile[offsetInFile + patchHeaderSize];
611  float *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
612  memcpy(weightsInPatch, patchFromFile, (std::size_t)(nxp * nyp * nfp) * sizeof(float));
613  }
614  }
615 }
616 
617 void WeightsFileIO::decompressPatch(
618  unsigned char const *dataFromFile,
619  float *destWeights,
620  int count,
621  float minValue,
622  float maxValue) {
623  for (int k = 0; k < count; k++) {
624  float compressedWeight = (float)dataFromFile[k] / 255.0f;
625  destWeights[k] = (compressedWeight) * (maxValue - minValue) + minValue;
626  }
627 }
628 
629 // TODO: templating to reduce code duplication between and within store{Nonshared,Shared}Patches
630 void WeightsFileIO::storeSharedPatches(
631  std::vector<unsigned char> &dataFromFile,
632  int arbor,
633  float minValue,
634  float maxValue,
635  bool compressed) {
636  int const nxp = mWeights->getPatchSizeX();
637  int const nyp = mWeights->getPatchSizeY();
638  int const nfp = mWeights->getPatchSizeF();
639 
640  int const numDataPatches = mWeights->getNumDataPatches();
641  auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
642  std::size_t const patchHeaderSize = sizeof(unsigned int) + 2UL * sizeof(unsigned short);
643  if (compressed) {
644  for (int k = 0; k < numDataPatches; k++) {
645  std::size_t const offsetInFile = patchSizePvpFormat * (std::size_t)k;
646  unsigned char *patchFromFile = &dataFromFile[offsetInFile];
647  unsigned short shortDim;
648  shortDim = (unsigned short)nxp;
649  memcpy(patchFromFile, &shortDim, sizeof(shortDim));
650  shortDim = (unsigned short)nyp;
651  memcpy(&patchFromFile[sizeof(shortDim)], &shortDim, sizeof(shortDim));
652 
653  // always zero offset for shared
654  memset(&patchFromFile[2UL * sizeof(shortDim)], 0, sizeof(unsigned int));
655  patchFromFile += patchHeaderSize;
656  float const *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
657  compressPatch(patchFromFile, weightsInPatch, nxp * nyp * nfp, minValue, maxValue);
658  }
659  }
660  else {
661  for (int k = 0; k < numDataPatches; k++) {
662  std::size_t const offsetInFile = patchSizePvpFormat * (std::size_t)k;
663  unsigned char *patchFromFile = &dataFromFile[offsetInFile];
664  unsigned short shortDim;
665  shortDim = (unsigned short)nxp;
666  memcpy(patchFromFile, &shortDim, sizeof(shortDim));
667  shortDim = (unsigned short)nyp;
668  memcpy(&patchFromFile[sizeof(shortDim)], &shortDim, sizeof(shortDim));
669 
670  // always zero offset for shared
671  memset(&patchFromFile[2UL * sizeof(shortDim)], 0, sizeof(unsigned int));
672  patchFromFile += patchHeaderSize;
673  float *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
674  memcpy(patchFromFile, weightsInPatch, (std::size_t)(nxp * nyp * nfp) * sizeof(float));
675  }
676  }
677 }
678 
679 void WeightsFileIO::storeNonsharedPatches(
680  std::vector<unsigned char> &dataFromFile,
681  int arbor,
682  float minValue,
683  float maxValue,
684  bool compressed) {
685  int const nxp = mWeights->getPatchSizeX();
686  int const nyp = mWeights->getPatchSizeY();
687  int const nfp = mWeights->getPatchSizeF();
688  int const numDataPatches = mWeights->getNumDataPatches();
689 
690  auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
691  std::size_t const patchHeaderSize = sizeof(std::uint32_t) + 2UL * sizeof(std::uint16_t);
692  if (compressed) {
693  for (int k = 0; k < numDataPatches; k++) {
694  std::size_t const offsetInFile = patchSizePvpFormat * (std::size_t)k;
695  unsigned char *patchFromFile = &dataFromFile[offsetInFile];
696 
697  Patch const &patch = mWeights->getPatch(k);
698  std::uint16_t shortDim;
699  shortDim = (std::uint16_t)patch.nx;
700  memcpy(patchFromFile, &shortDim, sizeof(shortDim));
701  shortDim = (std::uint16_t)patch.ny;
702  memcpy(&patchFromFile[sizeof(shortDim)], &shortDim, sizeof(shortDim));
703  std::uint32_t offset = (std::uint32_t)patch.offset;
704  memcpy(&patchFromFile[2UL * sizeof(shortDim)], &offset, sizeof(offset));
705  patchFromFile += patchHeaderSize;
706  float const *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
707  compressPatch(patchFromFile, weightsInPatch, nxp * nyp * nfp, minValue, maxValue);
708  }
709  }
710  else {
711  for (int k = 0; k < numDataPatches; k++) {
712  std::size_t const offsetInFile = patchSizePvpFormat * (std::size_t)k;
713  unsigned char *patchFromFile = &dataFromFile[offsetInFile];
714 
715  Patch const &patch = mWeights->getPatch(k);
716  std::uint16_t shortDim;
717  shortDim = (std::uint16_t)patch.nx;
718  memcpy(patchFromFile, &shortDim, sizeof(shortDim));
719  shortDim = (std::uint16_t)patch.ny;
720  memcpy(&patchFromFile[sizeof(shortDim)], &shortDim, sizeof(shortDim));
721  std::uint32_t offset = (std::uint32_t)patch.offset;
722  memcpy(&patchFromFile[2UL * sizeof(shortDim)], &offset, sizeof(offset));
723  patchFromFile += patchHeaderSize;
724  float const *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
725  memcpy(patchFromFile, weightsInPatch, (std::size_t)(nxp * nyp * nfp) * sizeof(float));
726  }
727  }
728 }
729 
730 void WeightsFileIO::compressPatch(
731  unsigned char *dataForFile,
732  float const *sourceWeights,
733  int count,
734  float minValue,
735  float maxValue) {
736  for (int k = 0; k < count; k++) {
737  float compressedWeight = (sourceWeights[k] - minValue) / (maxValue - minValue);
738  dataForFile[k] = (unsigned char)std::floor(255.0f * compressedWeight);
739  }
740 }
741 
742 } // namespace PV