Line data Source code
1 : /** 2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev> 3 : 4 : Permission is hereby granted, free of charge, to any person obtaining a copy 5 : of this software and associated documentation files (the "Software"), to deal 6 : in the Software without restriction, including without limitation the rights 7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 : copies of the Software, and to permit persons to whom the Software is 9 : furnished to do so, subject to the following conditions: 10 : 11 : The above copyright notice and this permission notice shall be included in 12 : all copies or substantial portions of the Software. 13 : 14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 : THE SOFTWARE. 21 : **/ 22 : 23 : #ifndef TESTS_XLSNNMODELTEST_H_ 24 : #define TESTS_XLSNNMODELTEST_H_ 25 : 26 : #include "XLCoreQueue.h" 27 : #include "XLApplication.h" 28 : #include "XLSnnModel.h" 29 : #include "XLSnnModelProcessor.h" 30 : 31 : namespace stappler::xenolith::shadernn { 32 : 33 : struct MnistTrainData : Ref { 34 : struct ImagesHeader { 35 : uint32_t magic; 36 : uint32_t images; 37 : uint32_t rows; 38 : uint32_t columns; 39 : }; 40 : 41 : struct VectorsHeader { 42 : uint32_t magic; 43 : uint32_t items; 44 : }; 45 : 46 : MnistTrainData(StringView path); 47 : ~MnistTrainData(); 48 : 49 : void loadVectors(StringView ipath); 50 : void loadImages(StringView ipath); 51 : void loadIndexes(); 52 : 53 : void readImages(uint8_t *ptr, uint64_t size, size_t offset); 54 : void readLabels(uint8_t *iptr, uint64_t size, size_t offset); 55 : 56 : bool validateImages(const uint8_t *ptr, uint64_t size, size_t offset); 57 : 58 : void shuffle(Random &rnd); 59 : 60 : ImagesHeader imagesHeader; 61 : VectorsHeader vectorsHeader; 62 : 63 : float *imagesData = 0; 64 : float *vectorsData = 0; 65 : 66 : size_t imagesSize = 0; 67 : size_t vectorsSize = 0; 68 : 69 : std::vector<uint32_t> indexes; 70 : }; 71 : 72 : struct CsvData : Ref { 73 : Vector<String> fields; 74 : Vector<Value> data; 75 : }; 76 : 77 : class ModelQueue : public core::Queue { 78 : public: 79 : static Rc<CsvData> readCsv(StringView); 80 : 81 : virtual ~ModelQueue(); 82 : 83 : bool init(StringView model, ModelFlags, StringView input); 84 : 85 : const Vector<const AttachmentData *> &getInputAttachments() const { return _inputAttachments; } 86 1200 : const AttachmentData *getOutputAttachment() const { return _outputAttachment; } 87 : 88 : void run(Application *); 89 : 90 : protected: 91 : using core::Queue::init; 92 : 93 : void onComplete(Application *, BytesView data); 94 : 95 : Application *_app = nullptr; 96 : String _image; 97 : Vector<const AttachmentData *> _inputAttachments; 98 : const AttachmentData *_outputAttachment = nullptr; 99 : 100 : Rc<MnistTrainData> _trainData; 101 : Rc<CsvData> _csvData; 102 : 103 : Rc<Model> _model; 104 : Rc<ModelProcessor> _processor; 105 : 106 : size_t _endEpoch = 2; 107 : size_t _epoch = 0; 108 : size_t _loadOffset = 0; 109 : float _epochLoss = 0.0f; 110 : }; 111 : 112 : } 113 : 114 : #endif /* TESTS_XLSNNMODELTEST_H_ */