LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/tests - XLSnnModelTest.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 1 1 100.0 %
Date: 2024-05-06 04:51:23 Functions: 1 1 100.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #ifndef TESTS_XLSNNMODELTEST_H_
      24             : #define TESTS_XLSNNMODELTEST_H_
      25             : 
      26             : #include "XLCoreQueue.h"
      27             : #include "XLApplication.h"
      28             : #include "XLSnnModel.h"
      29             : #include "XLSnnModelProcessor.h"
      30             : 
      31             : namespace stappler::xenolith::shadernn {
      32             : 
      33             : struct MnistTrainData : Ref { // Ref-derived holder for an image set plus its label vectors (MNIST-style, judging by the name)
      34             :         struct ImagesHeader { // presumably mirrors the header of the images file - confirm against loadImages()
      35             :                 uint32_t magic;
      36             :                 uint32_t images;
      37             :                 uint32_t rows;
      38             :                 uint32_t columns;
      39             :         };
      40             : 
      41             :         struct VectorsHeader { // presumably mirrors the header of the labels file - confirm against loadVectors()
      42             :                 uint32_t magic;
      43             :                 uint32_t items;
      44             :         };
      45             : 
      46             :         MnistTrainData(StringView path); // NOTE(review): meaning of path is defined in the .cpp; presumably where the data files live
      47             :         ~MnistTrainData(); // NOTE(review): presumably releases imagesData / vectorsData - ownership is not visible in this header
      48             : 
      49             :         void loadVectors(StringView ipath); // loader for the label vectors; ipath semantics are defined in the .cpp
      50             :         void loadImages(StringView ipath); // loader for the image data; ipath semantics are defined in the .cpp
      51             :         void loadIndexes(); // presumably (re)builds the indexes vector - confirm
      52             : 
      53             :         void readImages(uint8_t *ptr, uint64_t size, size_t offset); // ptr is writable here, unlike validateImages; units of size/offset: TODO confirm
      54             :         void readLabels(uint8_t *iptr, uint64_t size, size_t offset); // label counterpart of readImages; units of size/offset: TODO confirm
      55             : 
      56             :         bool validateImages(const uint8_t *ptr, uint64_t size, size_t offset); // takes a read-only buffer and reports a bool verdict
      57             : 
      58             :         void shuffle(Random &rnd); // takes the random source by reference; presumably permutes indexes - confirm
      59             : 
      60             :         ImagesHeader imagesHeader{}; // value-initialized so the fields are zero until a loader fills them
      61             :         VectorsHeader vectorsHeader{}; // value-initialized for the same reason
      62             : 
      63             :         float *imagesData = nullptr; // null until a buffer is assigned; nullptr instead of the integer literal 0
      64             :         float *vectorsData = nullptr; // null until a buffer is assigned
      65             : 
      66             :         size_t imagesSize = 0; // NOTE(review): bytes or element count? confirm in the .cpp
      67             :         size_t vectorsSize = 0; // NOTE(review): bytes or element count? confirm in the .cpp
      68             : 
      69             :         std::vector<uint32_t> indexes; // presumably the sample visiting order - confirm against loadIndexes()/shuffle()
      70             : };
      71             : 
      72             : struct CsvData : Ref { // Ref-derived result type returned by ModelQueue::readCsv
      73             :         Vector<String> fields; // presumably the CSV column names - confirm against readCsv
      74             :         Vector<Value> data; // presumably one Value per parsed row - confirm against readCsv
      75             : };
      76             : 
      77             : class ModelQueue : public core::Queue { // core::Queue subclass that holds a Model, its ModelProcessor and the training/CSV data handles
      78             : public:
      79             :         static Rc<CsvData> readCsv(StringView); // returns a ref-counted CsvData; the argument is presumably a path or CSV text - confirm
      80             : 
      81             :         virtual ~ModelQueue();
      82             : 
      83             :         bool init(StringView model, ModelFlags, StringView input); // this overload hides core::Queue::init; see the using-declaration below
      84             : 
      85             :         const Vector<const AttachmentData *> &getInputAttachments() const { return _inputAttachments; } // returns the stored list by const reference
      86        1200 :         const AttachmentData *getOutputAttachment() const { return _outputAttachment; } // returns the stored pointer; nullptr until it is assigned
      87             : 
      88             :         void run(Application *);
      89             : 
      90             : protected:
      91             :         using core::Queue::init; // keeps the base-class init overloads visible (with protected access) next to the overload above
      92             : 
      93             :         void onComplete(Application *, BytesView data); // NOTE(review): presumably the completion callback for run() - confirm in the .cpp
      94             : 
      95             :         Application *_app = nullptr; // raw pointer, null by default
      96             :         String _image;
      97             :         Vector<const AttachmentData *> _inputAttachments; // exposed read-only through getInputAttachments()
      98             :         const AttachmentData *_outputAttachment = nullptr; // exposed through getOutputAttachment()
      99             : 
     100             :         Rc<MnistTrainData> _trainData;
     101             :         Rc<CsvData> _csvData;
     102             : 
     103             :         Rc<Model> _model;
     104             :         Rc<ModelProcessor> _processor;
     105             : 
     106             :         size_t _endEpoch = 2; // defaults to 2; presumably the epoch at which training stops - confirm
     107             :         size_t _epoch = 0; // presumably the current epoch counter - confirm
     108             :         size_t _loadOffset = 0; // NOTE(review): presumably the read position within the training data - confirm
     109             :         float _epochLoss = 0.0f; // presumably the loss accumulated over the current epoch - confirm
     110             : };
     111             : 
     112             : }
     113             : 
     114             : #endif /* TESTS_XLSNNMODELTEST_H_ */

Generated by: LCOV version 1.14