LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/tests - XLSnnModelTest.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 173 357 48.5 %
Date: 2024-05-06 04:51:23 Functions: 21 30 70.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnModelTest.h"
      24             : #include "XLSnnModel.h"
      25             : #include "XLSnnLayer.h"
      26             : #include "XLSnnAttachment.h"
      27             : #include "XLSnnVkInputLayer.h"
      28             : #include "XLVkAttachment.h"
      29             : #include "XLCoreFrameRequest.h"
      30             : #include "SPFilesystem.h"
      31             : #include "SPBitmap.h"
      32             : #include "SPValid.h"
      33             : 
      34             : #include <byteswap.h>
      35             : 
      36             : namespace stappler::xenolith::shadernn {
      37             : 
      38           1 : void MnistTrainData::loadVectors(StringView ipath) {
      39           1 :         auto path = filepath::merge<Interface>(ipath, "train-labels.idx1-ubyte");
      40           1 :         auto vectors = ::fopen(path.data(), "r");
      41           1 :         if (vectors) {
      42           1 :                 ::fseek(vectors, 0, SEEK_END);
      43           1 :                 auto fsize = ::ftell(vectors);
      44           1 :                 ::fseek(vectors, 0, SEEK_SET);
      45             : 
      46           1 :                 ::fread(&vectorsHeader, sizeof(VectorsHeader), 1, vectors);
      47             : 
      48           1 :                 vectorsHeader.magic = bswap_32(vectorsHeader.magic);
      49           1 :                 vectorsHeader.items = bswap_32(vectorsHeader.items);
      50             : 
      51           1 :                 auto dataSize = fsize - sizeof(ImagesHeader);
      52             : 
      53           1 :                 uint8_t *buf = (uint8_t *)::malloc(dataSize);
      54           1 :                 ::fread(buf, dataSize, 1, vectors);
      55             : 
      56           1 :                 vectorsData = (float *)malloc(dataSize * sizeof(float) * 10);
      57           1 :                 ::memset(vectorsData, 0, dataSize * sizeof(float));
      58           1 :                 auto ptr = vectorsData;
      59       59993 :                 for (size_t i = 0; i < dataSize; ++ i) {
      60      659912 :                         for (size_t j = 0; j < 10; ++ j) {
      61      599920 :                                 ptr[j] = 0.0f;
      62             :                         }
      63       59992 :                         ptr[buf[i]] = 1.0f;
      64       59992 :                         ptr += 10;
      65             :                 }
      66             : 
      67           1 :                 ::free(buf);
      68           1 :                 ::fclose(vectors);
      69             :         }
      70           1 : }
      71             : 
      72           1 : void MnistTrainData::loadImages(StringView ipath) {
      73           1 :         auto path = filepath::merge<Interface>(ipath, "train-images.idx3-ubyte");
      74           1 :         auto images = ::fopen(path.data(), "r");
      75           1 :         if (images) {
      76           1 :                 ::fseek(images, 0, SEEK_END);
      77           1 :                 auto fsize = ::ftell(images);
      78           1 :                 ::fseek(images, 0, SEEK_SET);
      79             : 
      80           1 :                 ::fread(&imagesHeader, sizeof(ImagesHeader), 1, images);
      81             : 
      82           1 :                 imagesHeader.magic = bswap_32(imagesHeader.magic);
      83           1 :                 imagesHeader.images = bswap_32(imagesHeader.images);
      84           1 :                 imagesHeader.rows = bswap_32(imagesHeader.rows);
      85           1 :                 imagesHeader.columns = bswap_32(imagesHeader.columns);
      86             : 
      87           1 :                 auto dataSize = fsize - sizeof(ImagesHeader);
      88             : 
      89             : 
      90           1 :                 uint8_t *buf = (uint8_t *)::malloc(dataSize);
      91             : 
      92           1 :                 ::fread(buf, dataSize, 1, images);
      93             : 
      94           1 :                 imagesData = (float *)malloc(dataSize * sizeof(float));
      95           1 :                 auto ptr = imagesData;
      96    47040001 :                 for (size_t i = 0; i < dataSize; ++ i) {
      97    47040000 :                         *ptr ++ = float(buf[i]) / 255.0f;
      98             :                 }
      99             : 
     100           1 :                 ::free(buf);
     101           1 :                 ::fclose(images);
     102             :         }
     103           1 : }
     104             : 
     105           1 : void MnistTrainData::loadIndexes() {
     106           1 :         indexes.resize(imagesHeader.images);
     107       60001 :         for (size_t i = 0; i < imagesHeader.images; ++ i) {
     108       60000 :                 indexes[i] = i;
     109             :         }
     110           1 : }
     111             : 
     112        1200 : void MnistTrainData::readImages(uint8_t *iptr, uint64_t size, size_t offset) {
     113        1200 :         auto count = size / (imagesHeader.rows * imagesHeader.columns * sizeof(float));
     114             : 
     115        1200 :         float *ptr = (float *)iptr;
     116             : 
     117      121200 :         for (size_t i = 0; i < count; ++ i) {
     118      120000 :         auto idx = indexes[offset + i];
     119      120000 :         auto blockSize = imagesHeader.rows * imagesHeader.columns;
     120      120000 :         ::memcpy(ptr + i * blockSize, imagesData + idx * blockSize, blockSize * sizeof(float));
     121             :         }
     122        1200 : }
     123             : 
     124        1200 : void MnistTrainData::readLabels(uint8_t *iptr, uint64_t size, size_t offset) {
     125        1200 :         auto count = size / (10 * sizeof(float));
     126             : 
     127        1200 :         float *ptr = (float *)iptr;
     128             : 
     129      121200 :         for (size_t i = 0; i < count; ++ i) {
     130      120000 :         auto idx = indexes[offset + i];
     131      120000 :         auto blockSize = 10;
     132      120000 :         ::memcpy(ptr + i * blockSize, vectorsData + idx * blockSize, blockSize * sizeof(float));
     133             :         }
     134        1200 : }
     135             : 
     136           0 : bool MnistTrainData::validateImages(const uint8_t *iptr, uint64_t size, size_t offset) {
     137           0 :         auto count = size / (imagesHeader.rows * imagesHeader.columns * sizeof(float));
     138             : 
     139           0 :         const float *ptr = (const float *)iptr;
     140             : 
     141           0 :         for (size_t i = 0; i < count; ++ i) {
     142           0 :         auto idx = indexes[offset + i];
     143           0 :         auto blockSize = imagesHeader.rows * imagesHeader.columns;
     144             : 
     145           0 :         auto sourcePtr = ptr + i * blockSize;
     146           0 :         auto targetPtr = imagesData + idx * blockSize;
     147             : 
     148           0 :         if (::memcmp(sourcePtr, targetPtr, blockSize * sizeof(float)) != 0) {
     149           0 :                 for (size_t j = 0; j < imagesHeader.rows; ++ j) {
     150           0 :                         for (size_t k = 0; k < imagesHeader.columns; ++ k) {
     151           0 :                                 std::cout << " " << *(sourcePtr + (j * imagesHeader.rows) + k);
     152             :                         }
     153           0 :                         std::cout << "\n";
     154             :                 }
     155           0 :                 std::cout << "\n";
     156           0 :                 for (size_t j = 0; j < imagesHeader.rows; ++ j) {
     157           0 :                         for (size_t k = 0; k < imagesHeader.columns; ++ k) {
     158           0 :                                 std::cout << " " << *(targetPtr + (j * imagesHeader.rows) + k);
     159             :                         }
     160           0 :                         std::cout << "\n";
     161             :                 }
     162           0 :                 std::cout << "\n";
     163           0 :                 return false;
     164             :         }
     165             :         }
     166           0 :         return true;
     167             : }
     168             : 
     169             : struct StdRandom {
     170             :         using result_type = unsigned int;
     171             : 
     172             :         static constexpr auto min() { return std::numeric_limits<unsigned int>::min(); };
     173             :         static constexpr auto max() { return std::numeric_limits<unsigned int>::max(); };
     174             : 
     175       71432 :         unsigned int operator() () {
     176       71432 :                 return rnd->next();
     177             :         }
     178             : 
     179             :         Random *rnd;
     180             : };
     181             : 
     182           2 : void MnistTrainData::shuffle(Random &rnd) {
     183           2 :     std::shuffle(indexes.begin(), indexes.end(), StdRandom{&rnd});
     184           2 : }
     185             : 
     186           1 : MnistTrainData::MnistTrainData(StringView path) {
     187           1 :         loadVectors(path);
     188           1 :         loadImages(path);
     189           1 :         loadIndexes();
     190           1 : }
     191             : 
     192           2 : MnistTrainData::~MnistTrainData() {
     193           1 :         if (imagesData) {
     194           1 :                 ::free(imagesData);
     195           1 :                 imagesData = nullptr;
     196             :         }
     197           1 :         if (vectorsData) {
     198           1 :                 ::free(vectorsData);
     199           1 :                 vectorsData = nullptr;
     200             :         }
     201           2 : }
     202             : 
     203           0 : Rc<CsvData> ModelQueue::readCsv(StringView data) {
     204           0 :         Rc<CsvData> ret = Rc<CsvData>::alloc();
     205             : 
     206           0 :         auto readQuoted = [] (StringView &r) {
     207           0 :                 StringView tmp(r);
     208             : 
     209           0 :                 while (r.is("\"\"")) {
     210           0 :                         r += 2;
     211             :                 }
     212             : 
     213           0 :                 while (!r.empty() && !r.is('"')) {
     214           0 :                         r.skipUntil<StringView::Chars<'"', '\\'>>();
     215           0 :                         if (r.is('\\')) {
     216           0 :                                 r += 2;
     217             :                         }
     218           0 :                         while (r.is("\"\"")) {
     219           0 :                                 r += 2;
     220             :                         }
     221             :                 }
     222             : 
     223           0 :                 tmp = StringView(tmp.data(), r.data() - tmp.data());
     224             : 
     225           0 :                 if (r.is('"')) {
     226           0 :                         ++ r;
     227             :                 }
     228             : 
     229           0 :                 return tmp;
     230             :         };
     231             : 
     232           0 :         auto readHeader = [&] (StringView &r) {
     233           0 :                 while (!r.empty() && !r.is('\n') && !r.is("\r")) {
     234           0 :                         r.skipChars<StringView::WhiteSpace>();
     235           0 :                         StringView tmp(r);
     236           0 :                         if (r.is('"')) {
     237           0 :                                 ++ r;
     238           0 :                                 ret->fields.emplace_back(readQuoted(r).str<Interface>());
     239           0 :                                 r.skipUntil<StringView::Chars<',', '\n', '\r'>>();
     240             :                         } else {
     241           0 :                                 auto tmp = r.readUntil<StringView::Chars<',', '\n', '\r'>>();
     242           0 :                                 tmp.trimChars<StringView::WhiteSpace>();
     243           0 :                                 ret->fields.emplace_back(readQuoted(r).str<Interface>());
     244             :                         }
     245           0 :                         if (r.is(',')) {
     246           0 :                                 ++ r;
     247             :                         }
     248             :                 }
     249           0 :                 if (r.is('\n') || r.is('\r')) {
     250           0 :                         r.skipChars<StringView::WhiteSpace>();
     251             :                 }
     252           0 :         };
     253             : 
     254           0 :         auto validateFloat = [] (StringView str) {
     255           0 :                 if (str.empty()) {
     256           0 :                         return false;
     257             :                 }
     258             : 
     259           0 :                 StringView r(str);
     260           0 :                 if (r.is('-')) { ++ r; }
     261           0 :                 r.skipChars<chars::Range<char, '0', '9'>, StringView::Chars<'.'>>();
     262           0 :                 if (!r.empty()) {
     263           0 :                         return false;
     264             :                 }
     265             : 
     266           0 :                 return true;
     267             :         };
     268             : 
     269           0 :         auto readLine = [&] (StringView &r) {
     270           0 :                 Value ret;
     271           0 :                 while (!r.empty() && !r.is('\n') && !r.is("\r")) {
     272           0 :                         r.skipChars<StringView::WhiteSpace>();
     273           0 :                         StringView tmp(r);
     274           0 :                         if (r.is('"')) {
     275           0 :                                 ++ r;
     276           0 :                                 auto data = readQuoted(r);
     277           0 :                                 if (valid::validateNumber(data)) {
     278           0 :                                         ret.addInteger(data.readInteger(10).get(0));
     279           0 :                                 } else if (validateFloat(data)) {
     280           0 :                                         ret.addInteger(data.readInteger(10).get(0));
     281             :                                 } else {
     282           0 :                                         ret.addString(data.str<Interface>());
     283             :                                 }
     284           0 :                                 r.skipUntil<StringView::Chars<',', '\n', '\r'>>();
     285             :                         } else {
     286           0 :                                 auto tmp = r.readUntil<StringView::Chars<',', '\n', '\r'>>();
     287           0 :                                 tmp.trimChars<StringView::WhiteSpace>();
     288           0 :                                 if (valid::validateNumber(tmp)) {
     289           0 :                                         ret.addInteger(tmp.readInteger(10).get(0));
     290           0 :                                 } else if (validateFloat(data)) {
     291           0 :                                         ret.addInteger(data.readInteger(10).get(0));
     292             :                                 } else {
     293           0 :                                         ret.addString(tmp.str<Interface>());
     294             :                                 }
     295             :                         }
     296           0 :                         if (r.is(',')) {
     297           0 :                                 ++ r;
     298             :                         }
     299             :                 }
     300           0 :                 if (r.is('\n') || r.is('\r')) {
     301           0 :                         r.skipChars<StringView::WhiteSpace>();
     302             :                 }
     303           0 :                 return ret;
     304           0 :         };
     305             : 
     306           0 :         while (!data.empty()) {
     307           0 :                 if (ret->fields.empty()) {
     308           0 :                         readHeader(data);
     309             :                 } else {
     310           0 :                         if (auto val = readLine(data)) {
     311           0 :                                 if (!val.getValue(0).isInteger()) {
     312           0 :                                         std::cout << val << "\n";
     313             :                                 }
     314           0 :                                 ret->data.emplace_back(move(val));
     315           0 :                         }
     316             :                 }
     317             :         }
     318             : 
     319           0 :         return ret;
     320           0 : }
     321             : 
     322           2 : ModelQueue::~ModelQueue() { }
     323             : 
     324           1 : bool ModelQueue::init(StringView modelPath, ModelFlags flags, StringView input) {
     325           1 :         _processor = Rc<ModelProcessor>::create();
     326           1 :         _model = _processor->load(FilePath(modelPath), flags);
     327             : 
     328           1 :         if (!_model) {
     329           0 :                 return false;
     330             :         }
     331             : 
     332           1 :         core::Queue::Builder builder(filepath::name(modelPath));
     333             : 
     334           1 :         Extent3 frameExtent(1, 1, 1);
     335           1 :         if (input.starts_with("mnist:")) {
     336           1 :                 _trainData = Rc<MnistTrainData>::alloc(input.sub(6));
     337           0 :         } else if (input.starts_with("csv:")) {
     338           0 :                 auto data = filesystem::readIntoMemory<Interface>(input.sub(4));
     339           0 :                 if (!data.empty()) {
     340           0 :                         _csvData = readCsv(StringView((const char *)data.data(), data.size()));
     341             :                 } else {
     342           0 :                         return false;
     343             :                 }
     344           0 :         } else {
     345           0 :                 if (!bitmap::getImageSize(input, frameExtent.width, frameExtent.height)) {
     346           0 :                         log::error("InputQueue", "fail to read image: ", input);
     347           0 :                         return false;
     348             :                 }
     349             :         }
     350             : 
     351           1 :         _image = input.str<Interface>();
     352             : 
     353           1 :         Map<Layer *, const core::AttachmentData *> inputs;
     354           1 :         Map<Attachment *, const core::AttachmentData *> attachments;
     355             : 
     356           1 :         auto modelInputs = _model->getInputs();
     357           3 :         for (auto &it : modelInputs) {
     358           2 :                 inputs.emplace(it, it->makeInputAttachment(builder));
     359             :         }
     360             : 
     361           1 :         size_t i = 0;
     362           7 :         for (auto &it : _model->getSortedLayers()) {
     363           6 :                 auto output = it->getOutput();
     364           6 :                 auto attachment = it->makeOutputAttachment(builder, _model->getSortedLayers().size() == i + 1);
     365           6 :                 attachments.emplace(output, attachment);
     366           6 :                 if (_model->getSortedLayers().size() == i + 1) {
     367           1 :                         _outputAttachment = attachment;
     368             :                 }
     369           6 :                 ++ i;
     370             :         }
     371             : 
     372           7 :         for (auto &it : _model->getSortedLayers()) {
     373           6 :                 auto ret = it->prepare(builder, inputs, attachments);
     374           6 :                 if (!ret) {
     375           0 :                         return false;
     376             :                 }
     377           6 :                 if (it->isInput()) {
     378           2 :                         _inputAttachments.emplace_back(static_cast<vk::shadernn::InputLayer *>(ret->pass.get())->getDataAttachment());
     379             :                 }
     380             :         }
     381             : 
     382           1 :         return Queue::init(move(builder));
     383           1 : }
     384             : 
     385        1200 : void ModelQueue::run(Application *app) {
     386        1200 :         _app = app;
     387             : 
     388        1200 :         Extent3 frameExtent(1, 1, 1);
     389        1200 :         if (!_trainData && !_csvData) {
     390           0 :                 if (!bitmap::getImageSize(StringView(_image), frameExtent.width, frameExtent.height)) {
     391           0 :                         log::error("InputQueue", "fail to read image: ", _image);
     392           0 :                         return;
     393             :                 }
     394        1200 :         } else if (_csvData) {
     395           0 :                 frameExtent.width = _csvData->fields.size();
     396           0 :                 frameExtent.height = _csvData->data.size();
     397             :         }
     398             : 
     399        1200 :         auto req = Rc<core::FrameRequest>::create(Rc<core::Queue>(this), core::FrameContraints{frameExtent});
     400             : 
     401        1200 :         ModelSpecialization spec = _processor->specializeModel(_model, frameExtent);
     402             : 
     403        8400 :         for (auto &it : spec.attachments) {
     404        7200 :                 if (auto a = getAttachment(it.first->getName())) {
     405           0 :                         if (auto img = dynamic_cast<vk::ImageAttachment *>(a->attachment.get())) {
     406           0 :                                 core::ImageInfoData info = img->getImageInfo();
     407           0 :                                 info.extent = it.second;
     408           0 :                                 log::debug("ModelQueue", "Specialize attachment ", it.first->getName(), " for extent ", info.extent);
     409           0 :                                 req->addImageSpecialization(img, move(info));
     410             :                         }
     411             :                 }
     412             :         }
     413             : 
     414        3600 :         for (auto &it : _inputAttachments) {
     415        2400 :                 if (it->type == core::AttachmentType::Image) {
     416           0 :                         auto inputData = Rc<vk::shadernn::InputDataInput>::alloc();
     417           0 :                         inputData->norm = vk::shadernn::NormData{
     418             :                                 Vec4(-0.5f, -0.5f, -0.5f, -0.5f), // to -0.5 - 0.5
     419             :                                 Vec4(2.0f, 2.0f, 2.0f, 2.0f) // to -1.0 - 1.0
     420             :                         };
     421           0 :                         inputData->image.extent = frameExtent;
     422           0 :                         inputData->image.stdCallback = [path = _image] (uint8_t *ptr, uint64_t size, const core::ImageData::DataCallback &dcb) {
     423           0 :                                 core::Resource::loadImageFileData(ptr, size, path, core::ImageFormat::R8G8B8A8_UNORM, dcb);
     424           0 :                         };
     425             : 
     426           0 :                         req->addInput(it, move(inputData));
     427        2400 :                 } else if (_trainData) {
     428        2400 :                         if (it->key == "input_samples_buffer") {
     429        1200 :                                 auto inputData = Rc<vk::shadernn::InputBufferDataInput>::alloc();
     430        4800 :                                 inputData->buffer.stdCallback = [data = _trainData, offset = _loadOffset] (uint8_t *ptr, uint64_t size, const core::BufferData::DataCallback &cb) {
     431        1200 :                                         data->readImages(ptr, size, offset);
     432        1200 :                                 };
     433        1200 :                                 req->addInput(it, move(inputData));
     434        2400 :                         } else if (it->key == "input_labels_buffer") {
     435        1200 :                                 auto inputData = Rc<vk::shadernn::InputBufferDataInput>::alloc();
     436        4800 :                                 inputData->buffer.stdCallback = [data = _trainData, offset = _loadOffset] (uint8_t *ptr, uint64_t size, const core::BufferData::DataCallback &cb) {
     437        1200 :                                         data->readLabels(ptr, size, offset);
     438        1200 :                                 };
     439        1200 :                                 req->addInput(it, move(inputData));
     440        1200 :                         }
     441           0 :                 } else if (_csvData) {
     442           0 :                         auto inputData = Rc<vk::shadernn::InputCsvInput>::alloc();
     443           0 :                         inputData->data = _csvData->data;
     444           0 :                         req->addInput(it, move(inputData));
     445           0 :                 }
     446             :         }
     447             : 
     448        1200 :         req->setOutput(getOutputAttachment(), [this, app, trainData = _trainData, csvData = _csvData] (core::FrameAttachmentData &data, bool success, Ref *) {
     449        1200 :                 if (data.image) {
     450           0 :                         app->getGlLoop()->captureImage([app] (core::ImageInfoData info, BytesView view) {
     451           0 :                                 if (!view.empty()) {
     452           0 :                                         Vector<Bitmap> planes;
     453           0 :                                         for (size_t i = 0; i < info.extent.depth; ++ i) {
     454           0 :                                                 auto &b = planes.emplace_back(Bitmap());
     455           0 :                                                 b.alloc(info.extent.width, info.extent.height, bitmap::PixelFormat::RGBA8888, bitmap::AlphaFormat::Premultiplied);
     456             :                                         }
     457             : 
     458           0 :                                         Vector<uint8_t *> ptrs;
     459           0 :                                         for (auto &it : planes) {
     460           0 :                                                 ptrs.emplace_back(it.dataPtr());
     461             :                                         }
     462             : 
     463           0 :                                         std::cout << view.size() << " ";
     464             : 
     465           0 :                                         for (size_t i = 0; i < info.extent.width; ++ i) {
     466           0 :                                                 for (size_t j = 0; j < info.extent.height; ++ j) {
     467           0 :                                                         for (size_t k = 0; k < info.extent.depth; ++ k) {
     468           0 :                                                                 for (size_t f = 0; f < 4; ++ f) {
     469           0 :                                                                         *ptrs[k] = uint8_t((view.readFloat32() + 1.0f) * 127.5f);
     470             :                                                                         if (f == 3) {
     471             :                                                                                 //*ptrs[k] = 255;
     472             :                                                                         }
     473           0 :                                                                         ptrs[k] ++;
     474             :                                                                 }
     475             :                                                         }
     476             :                                                 }
     477             :                                         }
     478             : 
     479           0 :                                         std::cout << view.size() << "\n";
     480             : 
     481           0 :                                         for (size_t i = 0; i < info.extent.depth; ++ i) {
     482           0 :                                                 planes[i].save(toString(i, "_", Time::now().toMicros(), ".png"));
     483             :                                         }
     484           0 :                                 }
     485           0 :                                 app->end();
     486           0 :                         }, data.image->getImage(), core::AttachmentLayout::General);
     487        1200 :                 } else if (auto buf = dynamic_cast<vk::BufferAttachmentHandle *>(data.handle.get())) {
     488        1200 :                         if (trainData) {
     489        1200 :                                 auto target = buf->getBuffers().back().buffer;
     490        1200 :                                 app->getGlLoop()->captureBuffer([this, app] (const core::BufferInfo &info, BytesView view) {
     491             :                                         //front->synchronizeParameters(SpanView<float>((float *)view.data(), view.size() / sizeof(float)));
     492        1200 :                                         onComplete(app, view);
     493        1200 :                                 }, target);
     494             : 
     495        1200 :                         } else if (csvData) {
     496           0 :                                 auto target = buf->getBuffers().front().buffer;
     497           0 :                                 app->getGlLoop()->captureBuffer([app, trainData] (core::BufferInfo info, BytesView view) {
     498           0 :                                         std::cout << view.size() / (sizeof(float) * 4) << "\n";
     499           0 :                                         std::cout << "0";
     500             : 
     501           0 :                                         size_t row = 1;
     502           0 :                                         size_t i = 0;
     503           0 :                                         while (!view.empty()) {
     504           0 :                                                 switch (i) {
     505           0 :                                                 case 0: {
     506           0 :                                                         auto u = view.readUnsigned32();
     507           0 :                                                         if (u > uint32_t(maxOf<int32_t>())) {
     508           0 :                                                                 std::cout << ", " << static_cast<int32_t>(u);
     509             :                                                         } else {
     510           0 :                                                                 std::cout << ", " << u;
     511             :                                                         }
     512           0 :                                                         break;
     513             :                                                 }
     514           0 :                                                 case 1:
     515           0 :                                                         std::cout << ", " << view.readUnsigned32();
     516           0 :                                                         break;
     517           0 :                                                 case 2:
     518             :                                                 case 3:
     519           0 :                                                         std::cout << ", " << view.readFloat32();
     520           0 :                                                         break;
     521           0 :                                                 default:
     522           0 :                                                         view.readUnsigned32();
     523           0 :                                                         break;
     524             :                                                 }
     525           0 :                                                 ++ i;
     526           0 :                                                 if (i > 3) {
     527           0 :                                                         i = 0;
     528           0 :                                                         std::cout << "\n" << row ++;
     529             :                                                 }
     530             :                                         }
     531           0 :                                         std::cout << "\n";
     532           0 :                                 }, target);
     533           0 :                         }
     534             :                 }
     535        1200 :                 return true;
     536             :         });
     537             : 
     538        1200 :         app->getGlLoop()->runRenderQueue(move(req), 0);
     539        1200 : }
     540             : 
     541        1200 : void ModelQueue::onComplete(Application *app, BytesView data) {
     542             :         //auto iter = _loadOffset / 100;
     543        1200 :         _loadOffset += 100;
     544             : 
     545        1200 :         data.readFloat32();
     546        1200 :         auto loss = data.readFloat32();
     547             : 
     548        1200 :         if (_loadOffset < 60000) {
     549             :                 // 5  std::cout << iter << " Loss: " << loss << "\n";
     550        1198 :                 _epochLoss += loss;
     551        1198 :                 app->performOnMainThread([this, app] {
     552        1198 :                         run(_app);
     553        1198 :                 });
     554             :         } else {
     555           2 :                 std::cout << _epoch << ": avg loss: " << _epochLoss / 600.0f << "\n";
     556             : 
     557           2 :                 _trainData->shuffle(_model->getRand());
     558           2 :                 _epochLoss = 0;
     559           2 :                 _loadOffset = 0;
     560           2 :                 ++ _epoch;
     561             : 
     562           2 :                 if (_epoch < _endEpoch) {
     563           1 :                         app->performOnMainThread([this, app] {
     564           1 :                                 run(app);
     565           1 :                         });
     566             :                 } else {
     567           1 :                         release(0);
     568           1 :                         app->end();
     569             :                 }
     570             :         }
     571        1200 : }
     572             : 
     573             : }

Generated by: LCOV version 1.14