1 #include "WeightsFileIO.hpp" 6 WeightsFileIO::WeightsFileIO(FileStream *fileStream, MPIBlock
const *mpiBlock, Weights *weights)
7 : mFileStream(fileStream), mMPIBlock(mpiBlock), mWeights(weights) {
8 if (mMPIBlock ==
nullptr) {
9 throw std::invalid_argument(
"WeightsFileIO instantiated with a null MPIBlock");
11 if (mMPIBlock->getRank() == mRootProcess and mFileStream ==
nullptr) {
12 throw std::invalid_argument(
13 "WeightsFileIO instantiated with a null file stream on the null process");
15 if (mWeights ==
nullptr) {
16 throw std::invalid_argument(
"WeightsFileIO instantiated with a null Weights object");
// Reads frame `frameNumber` of the weight file into mWeights and stores that frame's
// header timestamp on mWeights via setTimestamp().
// Throws std::invalid_argument if a file stream is attached but is not readable.
21 double WeightsFileIO::readWeights(
int frameNumber) {
22 if (mFileStream !=
nullptr and !mFileStream->readable()) {
23 throw std::invalid_argument(
24 "WeightsFileIO::readWeights called with a nonreadable file stream");
// readHeader() reads on the root process and broadcasts, so every rank gets the header.
26 BufferUtils::WeightHeader header = readHeader(frameNumber);
// Shared and non-shared weights are stored in different file layouts; both readers
// return header.baseHeader.timestamp.
30 if (mWeights->getSharedFlag()) {
31 timestamp = readSharedWeights(frameNumber, header);
34 timestamp = readNonsharedWeights(frameNumber, header);
36 mWeights->setTimestamp(timestamp);
// Obtains the weight header of frame `frameNumber`.
// Only the root process touches the file: moveToFrame() leaves the stream just past
// that frame's header. The header is then broadcast to every process in the MPI block.
40 BufferUtils::WeightHeader WeightsFileIO::readHeader(
int frameNumber) {
41 BufferUtils::WeightHeader header;
42 int const rank = mMPIBlock->getRank();
43 if (rank == mRootProcess) {
44 moveToFrame(header, *mFileStream, frameNumber);
// NOTE(review): the struct is sent as raw MPI_BYTEs, which assumes every rank uses an
// identical WeightHeader layout (same compiler/ABI).
47 MPI_Bcast(&header, (
int)
sizeof(header), MPI_BYTE, mRootProcess, mMPIBlock->getComm());
51 void WeightsFileIO::checkHeader(BufferUtils::WeightHeader
const &header) {
52 if (mWeights->getSharedFlag()) {
54 header.baseHeader.fileType != PVP_KERNEL_FILE_TYPE,
55 "Connection \"%s\" has sharedWeights true, ",
56 "but \"%s\" is not a shared-weights file\n",
57 mWeights->getName().c_str(),
58 mFileStream->getFileName().c_str());
60 header.numPatches != mWeights->getNumDataPatches(),
61 "Shared-weights connection \"%s\" has a unit cell (%d-by-%d-by-%d), " 62 "but \"%s\" has %d patches.\n",
63 mWeights->getName().c_str(),
64 mWeights->getNumDataPatchesX(),
65 mWeights->getNumDataPatchesY(),
66 mWeights->getNumDataPatchesF(),
67 mFileStream->getFileName().c_str(),
73 header.baseHeader.fileType != PVP_WGT_FILE_TYPE,
74 "Connection \"%s\" has sharedWeights false.\n",
75 "but \"%s\" is not a non-shared-weights file. ",
76 mWeights->getName().c_str(),
77 mFileStream->getFileName().c_str());
80 header.baseHeader.nBands < mWeights->getNumArbors(),
81 "Connection \"%s\" has %d arbors, but file \"%s\" has only %d arbors.\n",
82 mWeights->getName().c_str(),
83 mWeights->getNumArbors(),
84 mFileStream->getFileName().c_str(),
85 header.baseHeader.nBands);
88 header.nxp != mWeights->getPatchSizeX(),
89 "Connection \"%s\" has nxp=%d, but file \"%s\" has nxp=%d.\n",
90 mWeights->getName().c_str(),
91 mWeights->getPatchSizeX(),
92 mFileStream->getFileName().c_str(),
95 header.nyp != mWeights->getPatchSizeY(),
96 "Connection \"%s\" has nyp=%d, but file \"%s\" has nyp=%d.\n",
97 mWeights->getName().c_str(),
98 mWeights->getPatchSizeY(),
99 mFileStream->getFileName().c_str(),
102 header.nfp != mWeights->getPatchSizeF(),
103 "Connection \"%s\" has nfp=%d, but file \"%s\" has nfp=%d.\n",
104 mWeights->getName().c_str(),
105 mWeights->getPatchSizeF(),
106 mFileStream->getFileName().c_str(),
// Validates the header's dataType/dataSize pair and reports whether the weight data
// is compressed.
//   FLOAT: dataSize must equal sizeof(float); data is uncompressed (isCompressed = false).
//   BYTE:  dataSize must equal sizeof(unsigned char). Presumably this is the compressed
//          format, matching writePatch's use of sizeof(unsigned char) when compressed.
//   INT and unrecognized types produce error messages.
// NOTE(review): these messages dereference mFileStream, which the constructor allows to
// be null on non-root processes.
110 bool WeightsFileIO::isCompressedHeader(BufferUtils::WeightHeader
const &header) {
112 switch (header.baseHeader.dataType) {
113 case BufferUtils::BYTE:
115 header.baseHeader.dataSize != (
int)
sizeof(
unsigned char),
116 "File \"%s\" has dataSize=%d, inconsistent with dataType BYTE (%d)\n",
117 mFileStream->getFileName().c_str(),
118 header.baseHeader.dataSize,
119 header.baseHeader.dataType);
122 case BufferUtils::FLOAT:
124 header.baseHeader.dataSize != (
int)
sizeof(
float),
125 "File \"%s\" has dataSize=%d, inconsistent with dataType FLOAT (%d)\n",
126 mFileStream->getFileName().c_str(),
127 header.baseHeader.dataSize,
128 header.baseHeader.dataType);
129 isCompressed =
false;
131 case BufferUtils::INT:
133 "File \"%s\" has dataType INT. Only FLOAT and BYTE are supported.\n",
134 mFileStream->getFileName().c_str());
138 "File \"%s\" has unrecognized datatype.\n", mFileStream->getFileName().c_str());
// Reads every arbor of a shared-weights frame into mWeights and returns the frame's
// header timestamp.
// For shared weights, each arbor occupies the same number of bytes in the file and in
// local memory, so calcArborSizeLocal() serves for both.
144 double WeightsFileIO::readSharedWeights(
int frameNumber, BufferUtils::WeightHeader
const &header) {
145 bool compressed = isCompressedHeader(header);
146 double timestamp = header.baseHeader.timestamp;
147 long arborSizeInPvpFile = calcArborSizeLocal(compressed);
148 long arborSizeInPvpLocal = arborSizeInPvpFile;
149 std::vector<unsigned char> readBuffer(arborSizeInPvpLocal);
151 int const numArbors = mWeights->getNumArbors();
152 for (
int arbor = 0; arbor < numArbors; arbor++) {
// The root reads the arbor from the stream, which is already positioned after the
// frame header by readHeader().
153 if (mMPIBlock->getRank() == mRootProcess) {
154 mFileStream->read(readBuffer.data(), arborSizeInPvpFile);
// The arbor bytes are then distributed from the root to all ranks (presumably an
// MPI_Bcast). NOTE(review): the long byte count narrows to MPI's int count parameter.
157 readBuffer.data(), arborSizeInPvpFile, MPI_BYTE, mRootProcess, mMPIBlock->getComm());
// Every rank unpacks the PVP-format buffer into mWeights for this arbor.
158 loadWeightsFromBuffer(readBuffer, arbor, header.minVal, header.maxVal, compressed);
// Reads every arbor of a non-shared-weights frame and returns the frame's header
// timestamp.
// The file stores patches for the whole global, margin-extended presynaptic layer.
// The root process visits each destination rank in turn. For each row of that rank's
// local patch box it seeks to the row's position in the file and reads it into
// readBuffer. The root loads its own block directly and MPI_Sends every other rank its
// block. Non-root ranks receive one buffer per arbor (tag tagbase + arbor) and load it.
164 WeightsFileIO::readNonsharedWeights(
int frameNumber, BufferUtils::WeightHeader
const &header) {
165 bool compressed = isCompressedHeader(header);
166 long arborSizeInPvpFile = calcArborSizeFile(compressed);
167 long arborSizeInPvpLocal = calcArborSizeLocal(compressed);
168 std::vector<unsigned char> readBuffer(arborSizeInPvpLocal);
170 int const nxp = mWeights->getPatchSizeX();
171 int const nyp = mWeights->getPatchSizeY();
172 int const nfp = mWeights->getPatchSizeF();
// Bytes per patch in PVP format (patch header plus nxp*nyp*nfp values).
173 long patchSizePvpFormat = (long)BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
175 int const numArbors = mWeights->getNumArbors();
176 if (mMPIBlock->getRank() == mRootProcess) {
177 long const frameStartFile = mFileStream->getInPos();
178 for (
int arbor = 0; arbor < numArbors; arbor++) {
179 long const arborStartInFile = frameStartFile + (long)(arbor * arborSizeInPvpFile);
180 mFileStream->setInPos(arborStartInFile,
true );
184 int startPatchX, endPatchX, startPatchY, endPatchY;
185 calcPatchBox(startPatchX, endPatchX, startPatchY, endPatchY);
// Number of patches in one row of the local patch box (x extent times features).
186 int lineCount = (endPatchX - startPatchX) * mWeights->getNumDataPatchesF();
188 PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
189 PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
// Dimensions of the file's patch grid: the global presynaptic extent plus a margin on
// each side (same formula as calcArborSizeFile).
191 int marginX = calcNeededBorder(preLoc.nx, postLoc.nx, nxp);
192 int nxExtended = preLoc.nx * mMPIBlock->getGlobalNumColumns() + marginX + marginX;
194 int marginY = calcNeededBorder(preLoc.ny, postLoc.ny, nyp);
195 int nyExtended = preLoc.ny * mMPIBlock->getGlobalNumRows() + marginY + marginY;
197 for (
int destRank = 0; destRank < mMPIBlock->getSize(); destRank++) {
198 int rowIndex, columnIndex, batchElemIndex;
199 mMPIBlock->calcRowColBatchFromRank(destRank, rowIndex, columnIndex, batchElemIndex);
// Copy the destination rank's patch box, one row at a time, from the file into the
// corresponding place in readBuffer.
201 for (
int y = 0; y < endPatchY - startPatchY; y++) {
202 int const startFileX = columnIndex * preLoc.nx;
203 int const startFileY = y + rowIndex * preLoc.ny;
204 int const startFile = kIndex(
205 startFileX, startFileY, 0, nxExtended, nyExtended, header.baseHeader.nf);
206 long lineStartInFile = arborStartInFile + (long)startFile * patchSizePvpFormat;
207 mFileStream->setInPos(lineStartInFile,
true );
209 int const startPatchLocal = kIndex(
213 mWeights->getNumDataPatchesX(),
214 mWeights->getNumDataPatchesY(),
215 mWeights->getNumDataPatchesF());
217 unsigned char *lineLocInBuffer =
218 &readBuffer[(long)startPatchLocal * patchSizePvpFormat];
219 std::size_t bufferSize = (std::size_t)lineCount * (std::size_t)patchSizePvpFormat;
220 mFileStream->read(lineLocInBuffer, bufferSize);
// The root keeps its own block; every other rank is sent its block.
222 if (destRank == mRootProcess) {
223 loadWeightsFromBuffer(readBuffer, arbor, header.minVal, header.maxVal, compressed);
226 int tag = tagbase + arbor;
227 MPI_Comm comm = mMPIBlock->getComm();
// NOTE(review): (int)readBuffer.size() truncates if one local arbor exceeds INT_MAX bytes.
228 MPI_Send(readBuffer.data(), (int)readBuffer.size(), MPI_BYTE, destRank, tag, comm);
// Non-root ranks: receive one buffer per arbor from the root (same tag scheme) and load it.
234 for (
int arbor = 0; arbor < numArbors; arbor++) {
235 int tag = tagbase + arbor;
236 MPI_Comm comm = mMPIBlock->getComm();
239 (int)readBuffer.size(),
245 loadWeightsFromBuffer(readBuffer, arbor, header.minVal, header.maxVal, compressed);
248 return header.baseHeader.timestamp;
252 void WeightsFileIO::writeWeights(
double timestamp,
bool compress) {
253 if (mFileStream !=
nullptr and !mFileStream->writeable()) {
254 throw std::invalid_argument(
255 "WeightsFileIO::writeWeights called with a nonwriteable file stream");
257 if (mWeights->getSharedFlag()) {
258 writeSharedWeights(timestamp, compress);
261 writeNonsharedWeights(timestamp, compress);
// Writes a shared-weights frame: a shared-weight header followed by every arbor's
// patches. Only the root process writes; other ranks take the branch below (presumably
// an early return).
265 void WeightsFileIO::writeSharedWeights(
double timestamp,
bool compress) {
266 if (mMPIBlock->getRank() != mRootProcess) {
// Extrema of the weights; passed to storeSharedPatches as the compression range.
269 float minWeight = mWeights->calcMinWeight();
270 float maxWeight = mWeights->calcMaxWeight();
271 BufferUtils::WeightHeader header = BufferUtils::buildSharedWeightHeader(
272 mWeights->getPatchSizeX(),
273 mWeights->getPatchSizeY(),
274 mWeights->getPatchSizeF(),
275 mWeights->getNumArbors(),
276 mWeights->getNumDataPatchesX(),
277 mWeights->getNumDataPatchesY(),
278 mWeights->getNumDataPatchesF(),
284 mFileStream->write(&header,
sizeof(header));
// For shared weights, an arbor's size in the file equals its local size.
286 long arborSizeInPvpFile = calcArborSizeLocal(compress);
287 long arborSizeInPvpLocal = arborSizeInPvpFile;
288 std::vector<unsigned char> writeBuffer(arborSizeInPvpLocal);
290 int const numArbors = mWeights->getNumArbors();
// Pack each arbor into PVP patch format, then write it contiguously after the header.
291 for (
int arbor = 0; arbor < numArbors; arbor++) {
292 storeSharedPatches(writeBuffer, arbor, minWeight, maxWeight, compress);
293 mFileStream->write(writeBuffer.data(), arborSizeInPvpFile);
// Writes a non-shared-weights frame assembled from every rank in the MPI block.
// Collective: every rank must call this, because of the MPI_Allreduce below.
// The root writes the header. Then, for each arbor and each source rank, it obtains that
// rank's packed patches (its own via storeNonsharedPatches, the others via a message
// tagged tagbase + arbor) and writes each row of the patch box at its position in the
// global, margin-extended file layout. Non-root ranks pack their patches per arbor and
// MPI_Send them to the root.
297 void WeightsFileIO::writeNonsharedWeights(
double timestamp,
bool compress) {
// One MPI_MIN reduction yields both the global minimum and, via negation, the global
// maximum weight.
299 extrema[0] = mWeights->calcMinWeight();
300 extrema[1] = -mWeights->calcMaxWeight();
301 MPI_Allreduce(MPI_IN_PLACE, extrema, 2, MPI_FLOAT, MPI_MIN, mMPIBlock->getComm());
302 extrema[1] = -extrema[1];
304 long arborSizeInPvpFile = calcArborSizeFile(compress);
305 long arborSizeInPvpLocal = calcArborSizeLocal(compress);
306 std::vector<unsigned char> writeBuffer(arborSizeInPvpLocal);
308 int const numArbors = mWeights->getNumArbors();
309 if (mMPIBlock->getRank() == mRootProcess) {
311 BufferUtils::WeightHeader header = BufferUtils::buildNonsharedWeightHeader(
312 mWeights->getPatchSizeX(),
313 mWeights->getPatchSizeY(),
314 mWeights->getPatchSizeF(),
315 mWeights->getNumArbors(),
318 &mWeights->getGeometry()->getPreLoc(),
319 &mWeights->getGeometry()->getPostLoc(),
320 mMPIBlock->getNumColumns(),
321 mMPIBlock->getNumRows(),
325 mFileStream->write(&header,
sizeof(header));
327 long const frameStartFile = mFileStream->getOutPos();
328 for (
int arbor = 0; arbor < numArbors; arbor++) {
329 long const arborStartFile = frameStartFile + (long)(arbor * arborSizeInPvpFile);
330 mFileStream->setOutPos(arborStartFile,
true );
// Local patch box, expressed in x-times-feature ("K") units for the row loop below.
334 int startPatchX, endPatchX, startPatchY, endPatchY;
335 calcPatchBox(startPatchX, endPatchX, startPatchY, endPatchY);
336 int const numDataPatchesF = mWeights->getNumDataPatchesF();
337 int startPatchK = startPatchX * numDataPatchesF;
338 int endPatchK = endPatchX * numDataPatchesF;
339 int const numDataPatchesK = mWeights->getNumDataPatchesX() * numDataPatchesF;
341 int const nxp = mWeights->getPatchSizeX();
342 int const nyp = mWeights->getPatchSizeY();
343 int const nfp = mWeights->getPatchSizeF();
344 auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compress);
346 for (
int sourceRank = 0; sourceRank < mMPIBlock->getSize(); sourceRank++) {
347 int rowIndex, columnIndex, batchElemIndex;
348 mMPIBlock->calcRowColBatchFromRank(sourceRank, rowIndex, columnIndex, batchElemIndex);
350 if (sourceRank == mRootProcess) {
351 storeNonsharedPatches(writeBuffer, arbor, extrema[0], extrema[1], compress);
354 int tag = tagbase + arbor;
355 MPI_Comm comm = mMPIBlock->getComm();
358 (int)writeBuffer.size(),
// Write the source rank's patch box row by row at its offset in the global layout.
366 for (
int y = 0; y < endPatchY - startPatchY; y++) {
367 PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
368 int const startFileX = columnIndex * preLoc.nx;
369 int const startFileY = y + rowIndex * preLoc.ny;
370 int const startFile = kIndex(
374 header.baseHeader.nxExtended,
375 header.baseHeader.nyExtended,
376 header.baseHeader.nf);
377 long lineStartFile = arborStartFile + (long)startFile * (
long)patchSizePvpFormat;
378 mFileStream->setOutPos(lineStartFile, true );
380 int const startPatchLocal = kIndex(
384 mWeights->getNumDataPatchesX(),
385 mWeights->getNumDataPatchesY(),
386 mWeights->getNumDataPatchesF());
388 for (
int k = startPatchK; k < endPatchK; k++) {
389 int patchIndexLocal = kIndex(
390 k, y + startPatchY, 0, numDataPatchesK, mWeights->getNumDataPatchesY(), 1);
391 unsigned char *patchLocInBuffer =
392 &writeBuffer[patchIndexLocal * patchSizePvpFormat];
393 writePatch(patchLocInBuffer, compress);
// Ensure the file extends through the end of this frame: if it is shorter, write a
// single zero byte at the last position of the frame. Then leave the stream at the
// frame end.
404 long const frameEndFile = frameStartFile + (long)(numArbors * arborSizeInPvpFile);
405 mFileStream->setOutPos(0L, std::ios_base::end);
406 long const endOfFile = mFileStream->getOutPos();
407 if (endOfFile < frameEndFile) {
408 mFileStream->setOutPos(frameEndFile - 1L,
true );
409 mFileStream->write(
"\0", 1L);
411 mFileStream->setOutPos(frameEndFile,
true );
// Non-root ranks: pack each arbor and send it to the root.
// NOTE(review): (int)writeBuffer.size() truncates if one local arbor exceeds INT_MAX bytes.
414 for (
int arbor = 0; arbor < numArbors; arbor++) {
415 storeNonsharedPatches(writeBuffer, arbor, extrema[0], extrema[1], compress);
416 int tag = tagbase + arbor;
417 MPI_Comm comm = mMPIBlock->getComm();
418 MPI_Send(writeBuffer.data(), (int)writeBuffer.size(), MPI_BYTE, mRootProcess, tag, comm);
// Writes one patch at the current output position.
// The patch header written to the file always declares the full patch
// (nx = nxp, ny = nyp, offset = 0). The actual nx/ny/offset window is read from the
// header at the front of patchBuffer, and only that window's values are written. The
// bytes outside the window are skipped with relative seeks rather than written.
// On return the stream is positioned just past the full nxp*nyp*nfp data region.
423 void WeightsFileIO::writePatch(
unsigned char const *patchBuffer,
bool compressed) {
424 int const nxp = mWeights->getPatchSizeX();
425 int const nyp = mWeights->getPatchSizeY();
426 int const nfp = mWeights->getPatchSizeF();
434 patch.nx = (std::uint16_t)nxp;
435 patch.ny = (std::uint16_t)nyp;
436 patch.offset = (std::uint32_t)0;
437 mFileStream->write(&patch.nx,
sizeof(patch.nx));
438 mFileStream->write(&patch.ny,
sizeof(patch.ny));
439 mFileStream->write(&patch.offset,
sizeof(patch.offset));
// Reuse `patch` to hold the in-memory patch's actual window, taken from the buffer header.
442 memcpy(&patch.nx, patchBuffer,
sizeof(patch.nx));
443 memcpy(&patch.ny, &patchBuffer[
sizeof(patch.nx)],
sizeof(patch.ny));
444 memcpy(&patch.offset, &patchBuffer[
sizeof(patch.nx) +
sizeof(patch.ny)],
sizeof(patch.offset));
446 std::size_t patchHeaderSize =
sizeof(patch.nx) +
sizeof(patch.ny) +
sizeof(patch.offset);
447 std::size_t dataSize = compressed ?
sizeof(
unsigned char) :
sizeof(
float);
// The buffer's data region mirrors the full patch, so the window starts `offset`
// values past the header, both in the buffer and in the file.
448 std::size_t patchDataStartOffset = patchHeaderSize + (std::size_t)patch.offset * dataSize;
449 unsigned char const *patchDataStart = &patchBuffer[patchDataStartOffset];
450 long patchStartInFile = mFileStream->getOutPos();
451 long patchEndInFile = patchStartInFile + (long)(nxp * nyp * nfp * (
int)dataSize);
452 mFileStream->setOutPos((
long)patch.offset * (
long)dataSize,
false );
// Full-width window: its rows are contiguous, so write them in one call.
453 if ((
int)patch.nx == nxp) {
455 long dataLength = (long)patch.ny * (
long)(nxp * nfp) * (
long)dataSize;
456 mFileStream->write(patchDataStart, dataLength);
// Narrower window: write each row, then skip the rest of the full-width stride.
460 std::size_t stride = (std::size_t)(nfp * nxp) * dataSize;
461 std::size_t lineLength = (std::size_t)nfp * (std::size_t)patch.nx * dataSize;
462 std::size_t skipLength = stride - lineLength;
463 for (std::uint16_t y = (std::uint16_t)0; y < patch.ny - (std::uint16_t)1; y++) {
464 unsigned char const *lineStartInBuffer = &patchDataStart[y * stride];
465 mFileStream->write(lineStartInBuffer, lineLength);
466 mFileStream->setOutPos(skipLength,
false );
// The last row is written separately so no trailing skip follows it.
468 if (patch.ny > (std::uint16_t)0) {
469 std::size_t lastLineOffset = (std::size_t)(patch.ny - 1) * stride;
470 unsigned char const *lineStartInBuffer = &patchDataStart[lastLineOffset];
471 mFileStream->write(lineStartInBuffer, lineLength);
474 mFileStream->setOutPos(patchEndInFile,
true );
479 void WeightsFileIO::moveToFrame(
483 fileStream.setInPos(0L,
true );
484 for (
int f = 0; f < frameNumber; f++) {
485 fileStream.read(&header,
sizeof(header));
486 long recordSize = (long)(header.baseHeader.recordSize * header.baseHeader.numRecords);
487 fileStream.setInPos(recordSize,
false );
489 fileStream.read(&header,
sizeof(header));
492 long WeightsFileIO::calcArborSizeFile(
bool compressed) {
493 int const nxp = mWeights->getPatchSizeX();
494 int const nyp = mWeights->getPatchSizeY();
495 int const nfp = mWeights->getPatchSizeF();
496 int const patchSize = (int)BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
499 if (mWeights->getSharedFlag()) {
500 numPatches = mWeights->getNumDataPatches();
503 PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
504 PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
506 int marginX = calcNeededBorder(preLoc.nx, postLoc.nx, mWeights->getPatchSizeX());
507 int numPatchesX = preLoc.nx * mMPIBlock->getGlobalNumColumns() + marginX + marginX;
509 int marginY = calcNeededBorder(preLoc.ny, postLoc.ny, mWeights->getPatchSizeY());
510 int numPatchesY = preLoc.ny * mMPIBlock->getGlobalNumRows() + marginY + marginY;
512 numPatches = numPatchesX * numPatchesY * preLoc.nf;
515 int const arborSize = numPatches * patchSize;
519 long WeightsFileIO::calcArborSizeLocal(
bool compressed) {
520 int const nxp = mWeights->getPatchSizeX();
521 int const nyp = mWeights->getPatchSizeY();
522 int const nfp = mWeights->getPatchSizeF();
523 int const patchSize = (int)BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
525 int numPatches = mWeights->getNumDataPatches();
527 int const arborSize = numPatches * patchSize;
// Computes the half-open x range [startPatchX, endPatchX) and y range
// [startPatchY, endPatchY) of local data patches exchanged with the file.
// Each axis is delegated to calcPatchRange, using the presynaptic layer's left/right
// halo for x and up/down halo for y.
532 void WeightsFileIO::calcPatchBox(
537 PVLayerLoc const &preLoc = mWeights->getGeometry()->getPreLoc();
538 PVLayerLoc const &postLoc = mWeights->getGeometry()->getPostLoc();
539 PVHalo const &preHalo = preLoc.halo;
541 int nxp = mWeights->getPatchSizeX();
542 calcPatchRange(preLoc.nx, postLoc.nx, preHalo.lt, preHalo.rt, nxp, startPatchX, endPatchX);
544 int nyp = mWeights->getPatchSizeY();
545 calcPatchRange(preLoc.ny, postLoc.ny, preHalo.up, preHalo.dn, nyp, startPatchY, endPatchY);
// One-dimensional helper for calcPatchBox: computes the half-open patch range
// [startPatch, endPatch) in local extended coordinates.
// The range covers the nPre restricted positions plus up to neededBorder positions on
// each side, clipped to the available halo (preStartBorder / preEndBorder).
548 void WeightsFileIO::calcPatchRange(
556 int const neededBorder = calcNeededBorder(nPre, nPost, patchSize);
// Start neededBorder before the restricted region, or at 0 if the halo is too small.
558 startPatch = (preStartBorder >= neededBorder) ? preStartBorder - neededBorder : 0;
559 endPatch = preStartBorder + nPre;
// Extend past the restricted region by neededBorder, clipped to the trailing halo.
560 endPatch += (preEndBorder >= neededBorder) ? neededBorder : preEndBorder;
563 int WeightsFileIO::calcNeededBorder(
int nPre,
int nPost,
int patchSize) {
566 pvAssert(nPre % nPost == 0);
567 int stride = nPre / nPost;
568 pvAssert(stride % 2 == 0);
569 int halfstride = stride / 2;
570 neededBorder = (patchSize - 1) * halfstride;
572 else if (nPre < nPost) {
573 pvAssert(nPost % nPre == 0);
574 int tstride = nPost / nPre;
575 pvAssert(patchSize % tstride == 0);
576 neededBorder = patchSize / (2 * tstride);
579 pvAssert(nPre == nPost);
580 pvAssert(patchSize % 2 == 1);
581 neededBorder = (patchSize - 1) / 2;
// Unpacks one arbor of PVP-format patches from dataFromFile into mWeights.
// Each patch record is a header (two unsigned shorts and one unsigned int) followed by
// nxp*nyp*nfp values. The header's nx/ny/offset fields are skipped, and all
// nxp*nyp*nfp values are copied into the patch's data.
// The first loop decompresses byte values into [minValue, maxValue]; the second
// copies raw floats. Presumably the hidden branch selects between them on `compressed`.
// NOTE(review): the header size is computed with unsigned int/short here but with
// std::uint32_t/uint16_t in storeNonsharedPatches; these agree only on platforms where
// the sizes match.
586 void WeightsFileIO::loadWeightsFromBuffer(
587 std::vector<unsigned char>
const &dataFromFile,
592 int const nxp = mWeights->getPatchSizeX();
593 int const nyp = mWeights->getPatchSizeY();
594 int const nfp = mWeights->getPatchSizeF();
595 int const numPatches = mWeights->getNumDataPatches();
597 auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
598 std::size_t
const patchHeaderSize =
sizeof(
unsigned int) + 2UL *
sizeof(
unsigned short);
600 for (
int k = 0; k < numPatches; k++) {
601 std::size_t
const offsetInFile = patchSizePvpFormat * (std::size_t)k;
602 unsigned char const *patchFromFile = &dataFromFile[offsetInFile + patchHeaderSize];
603 float *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
604 decompressPatch(patchFromFile, weightsInPatch, nxp * nyp * nfp, minValue, maxValue);
608 for (
int k = 0; k < numPatches; k++) {
609 std::size_t
const offsetInFile = patchSizePvpFormat * (std::size_t)k;
610 unsigned char const *patchFromFile = &dataFromFile[offsetInFile + patchHeaderSize];
611 float *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
612 memcpy(weightsInPatch, patchFromFile, (std::size_t)(nxp * nyp * nfp) *
sizeof(
float));
// Expands `count` byte-encoded weights into floats. Each byte b maps linearly to
// minValue + (b / 255) * (maxValue - minValue). This inverts compressPatch up to its
// floor quantization.
617 void WeightsFileIO::decompressPatch(
618 unsigned char const *dataFromFile,
623 for (
int k = 0; k < count; k++) {
624 float compressedWeight = (float)dataFromFile[k] / 255.0f;
625 destWeights[k] = (compressedWeight) * (maxValue - minValue) + minValue;
// Packs one arbor of shared weights into PVP patch format in the buffer. Despite its
// name, dataFromFile is the output here.
// Every patch header declares the full patch: nx = nxp, ny = nyp, offset = 0.
// The first loop writes byte-compressed values (via compressPatch with
// [minValue, maxValue]); the second writes raw floats. Presumably the hidden branch
// selects between them on `compressed`.
630 void WeightsFileIO::storeSharedPatches(
631 std::vector<unsigned char> &dataFromFile,
636 int const nxp = mWeights->getPatchSizeX();
637 int const nyp = mWeights->getPatchSizeY();
638 int const nfp = mWeights->getPatchSizeF();
640 int const numDataPatches = mWeights->getNumDataPatches();
641 auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
642 std::size_t
const patchHeaderSize =
sizeof(
unsigned int) + 2UL *
sizeof(
unsigned short);
644 for (
int k = 0; k < numDataPatches; k++) {
645 std::size_t
const offsetInFile = patchSizePvpFormat * (std::size_t)k;
646 unsigned char *patchFromFile = &dataFromFile[offsetInFile];
647 unsigned short shortDim;
648 shortDim = (
unsigned short)nxp;
649 memcpy(patchFromFile, &shortDim,
sizeof(shortDim));
650 shortDim = (
unsigned short)nyp;
651 memcpy(&patchFromFile[
sizeof(shortDim)], &shortDim,
sizeof(shortDim));
// Zero offset: shared patches are always stored at full size.
654 memset(&patchFromFile[2UL *
sizeof(shortDim)], 0,
sizeof(
unsigned int));
655 patchFromFile += patchHeaderSize;
656 float const *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
657 compressPatch(patchFromFile, weightsInPatch, nxp * nyp * nfp, minValue, maxValue);
661 for (
int k = 0; k < numDataPatches; k++) {
662 std::size_t
const offsetInFile = patchSizePvpFormat * (std::size_t)k;
663 unsigned char *patchFromFile = &dataFromFile[offsetInFile];
664 unsigned short shortDim;
665 shortDim = (
unsigned short)nxp;
666 memcpy(patchFromFile, &shortDim,
sizeof(shortDim));
667 shortDim = (
unsigned short)nyp;
668 memcpy(&patchFromFile[
sizeof(shortDim)], &shortDim,
sizeof(shortDim));
671 memset(&patchFromFile[2UL *
sizeof(shortDim)], 0,
sizeof(
unsigned int));
672 patchFromFile += patchHeaderSize;
673 float *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
674 memcpy(patchFromFile, weightsInPatch, (std::size_t)(nxp * nyp * nfp) *
sizeof(
float));
// Packs one arbor of non-shared weights into PVP patch format in the buffer. Despite
// its name, dataFromFile is the output here.
// Each patch header records that patch's geometry from mWeights->getPatch(k): nx, ny
// and offset. The data section still always holds all nxp*nyp*nfp values.
// The first loop writes byte-compressed values; the second writes raw floats.
// Presumably the hidden branch selects between them on `compressed`.
679 void WeightsFileIO::storeNonsharedPatches(
680 std::vector<unsigned char> &dataFromFile,
685 int const nxp = mWeights->getPatchSizeX();
686 int const nyp = mWeights->getPatchSizeY();
687 int const nfp = mWeights->getPatchSizeF();
688 int const numDataPatches = mWeights->getNumDataPatches();
690 auto const patchSizePvpFormat = BufferUtils::weightPatchSize(nxp * nyp * nfp, compressed);
691 std::size_t
const patchHeaderSize =
sizeof(std::uint32_t) + 2UL *
sizeof(std::uint16_t);
693 for (
int k = 0; k < numDataPatches; k++) {
694 std::size_t
const offsetInFile = patchSizePvpFormat * (std::size_t)k;
695 unsigned char *patchFromFile = &dataFromFile[offsetInFile];
697 Patch const &patch = mWeights->getPatch(k);
698 std::uint16_t shortDim;
699 shortDim = (std::uint16_t)patch.nx;
700 memcpy(patchFromFile, &shortDim,
sizeof(shortDim));
701 shortDim = (std::uint16_t)patch.ny;
702 memcpy(&patchFromFile[
sizeof(shortDim)], &shortDim,
sizeof(shortDim));
703 std::uint32_t offset = (std::uint32_t)patch.offset;
704 memcpy(&patchFromFile[2UL *
sizeof(shortDim)], &offset,
sizeof(offset));
705 patchFromFile += patchHeaderSize;
706 float const *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
707 compressPatch(patchFromFile, weightsInPatch, nxp * nyp * nfp, minValue, maxValue);
711 for (
int k = 0; k < numDataPatches; k++) {
712 std::size_t
const offsetInFile = patchSizePvpFormat * (std::size_t)k;
713 unsigned char *patchFromFile = &dataFromFile[offsetInFile];
715 Patch const &patch = mWeights->getPatch(k);
716 std::uint16_t shortDim;
717 shortDim = (std::uint16_t)patch.nx;
718 memcpy(patchFromFile, &shortDim,
sizeof(shortDim));
719 shortDim = (std::uint16_t)patch.ny;
720 memcpy(&patchFromFile[
sizeof(shortDim)], &shortDim,
sizeof(shortDim));
721 std::uint32_t offset = (std::uint32_t)patch.offset;
722 memcpy(&patchFromFile[2UL *
sizeof(shortDim)], &offset,
sizeof(offset));
723 patchFromFile += patchHeaderSize;
724 float const *weightsInPatch = mWeights->getDataFromDataIndex(arbor, k);
725 memcpy(patchFromFile, weightsInPatch, (std::size_t)(nxp * nyp * nfp) *
sizeof(
float));
// Encodes `count` float weights as bytes by mapping [minValue, maxValue] linearly onto
// [0, 255] and flooring.
// NOTE(review): if maxValue == minValue (e.g. all weights equal), the division yields
// inf or NaN. Converting a NaN or out-of-range float to unsigned char is undefined
// behavior, so this case likely needs a guard.
730 void WeightsFileIO::compressPatch(
731 unsigned char *dataForFile,
732 float const *sourceWeights,
736 for (
int k = 0; k < count; k++) {
737 float compressedWeight = (sourceWeights[k] - minValue) / (maxValue - minValue);
738 dataForFile[k] = (
unsigned char)std::floor(255.0f * compressedWeight);