LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/processor - XLSnnModel.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 63 198 31.8 %
Date: 2024-05-06 04:51:23 Functions: 8 19 42.1 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnModel.h"
      24             : #include "XLSnnLossLayer.h"
      25             : #include "XLSnnAttachment.h"
      26             : 
      27             : namespace stappler::xenolith::shadernn {
      28             : 
      29           0 : void Model::saveBlob(const char *path, const uint8_t *buf, size_t size) {
      30           0 :         ::unlink(path);
      31             : 
      32           0 :         auto f = fopen(path, "w");
      33           0 :         ::fwrite(&size, sizeof(size), 1, f);
      34           0 :         ::fwrite(buf, size, 1, f);
      35           0 :         ::fclose(f);
      36           0 : }
      37             : 
      38           0 : void Model::loadBlob(const char *path, const std::function<void(const uint8_t *, size_t)> &cb) {
      39           0 :         auto f = fopen(path, "r");
      40           0 :         if (!f) {
      41           0 :                 return;
      42             :         }
      43             : 
      44           0 :         size_t size = 0;
      45           0 :         ::fread(&size, sizeof(size), 1, f);
      46             : 
      47           0 :         auto buf = new uint8_t[size];
      48           0 :         if (size > 0) {
      49           0 :                 ::fread(buf, size, 1, f);
      50             : 
      51           0 :                 cb(buf, size);
      52             :         }
      53             : 
      54           0 :         ::fclose(f);
      55           0 :         delete[] buf;
      56             : }
      57             : 
      58           0 : bool Model::compareBlob(const uint8_t *a, size_t na, const uint8_t *b, size_t nb, float v) {
      59           0 :         return compareBlob((const float *)a, na / sizeof(float), (const float *)b, nb / sizeof(float), v);
      60             : }
      61             : 
      62           0 : bool Model::compareBlob(const float *a, size_t na, const float *b, size_t nb, float v) {
      63           0 :         if (na != nb) {
      64           0 :                 return false;
      65             :         }
      66             : 
      67           0 :         while (na > 0) {
      68           0 :                 if (std::abs(*a - *b) > v) {
      69           0 :                         return false;
      70             :                 }
      71           0 :                 ++ a; ++ b; -- na;
      72             :         }
      73           0 :         return true;
      74             : }
      75             : 
      76           2 : Model::~Model() { }
      77             : 
      78           1 : bool Model::init(ModelFlags f, const Value &val, uint32_t numLayers, StringView dataFilePath) {
      79           1 :         _flags = f;
      80           1 :         _numLayers = numLayers;
      81             : 
      82           1 :         if (!dataFilePath.empty()) {
      83           0 :                 _dataFile = filesystem::openForReading(dataFilePath);
      84           0 :                 if (!_dataFile) {
      85           0 :                         log::error("snn::Model", "Fail to open model data file: ", dataFilePath);
      86           0 :                         return false;
      87             :                 }
      88             :         }
      89             : 
      90           1 :         if (val.isBool("trainable")) {
      91           1 :                 if (val.getBool("trainable")) {
      92           1 :                         _flags |= ModelFlags::Trainable;
      93             :                 }
      94             :         }
      95             : 
      96           1 :         auto range = val.getString("inputRange");
      97           1 :         if (range == "[0,1]") {
      98           0 :                 _flags |= ModelFlags::Range01;
      99             :         }
     100             : 
     101           1 :         if (auto &block = val.getValue("block_0")) {
     102           0 :                 for (auto &it : block.asDict()) {
     103           0 :                         if (it.first == "Input Height") {
     104           0 :                                 _inputHeight = it.second.getInteger();
     105           0 :                         } else if (it.first == "Input Width") {
     106           0 :                                 _inputWidth = it.second.getInteger();
     107             :                         }
     108             :                 }
     109             :         }
     110             : 
     111           1 :         if (auto &node = val.getValue("node")) {
     112           0 :                 for (auto &it : node.asDict()) {
     113           0 :                         if (it.first == "inputChannels") {
     114           0 :                                 _inputChannels = it.second.getInteger();
     115           0 :                         } else if (it.first == "upscale") {
     116           0 :                                 _upscale = it.second.getInteger();
     117           0 :                         } else if (it.first == "useSubpixel") {
     118           0 :                                 _useSubPixel = it.second.getBool();
     119             :                         }
     120             :                 }
     121             :         }
     122             : 
     123           1 :         return true;
     124           1 : }
     125             : 
     126           6 : void Model::addLayer(Rc<Layer> &&l) {
     127           6 :         _layers.emplace(l->getInputIndex(), move(l));
     128           6 : }
     129             : 
     130           1 : bool Model::link() {
     131             :         // find inputs
     132           1 :         Vector<Layer *> linkedLayers;
     133           7 :         for (auto &it : _layers) {
     134           6 :                 if (it.second->isInput()) {
     135           2 :                         auto attachment = Rc<Attachment>::create(_attachments.size(), it.second->getOutputExtent(), it.second);
     136           2 :                         it.second->setOutput(attachment);
     137           2 :                         linkedLayers.emplace_back(it.second.get());
     138           2 :                         linkInput(linkedLayers, it.second, _attachments.emplace_back(move(attachment)));
     139           2 :                 }
     140             :         }
     141             : 
     142           1 :         if (linkedLayers.size() == _layers.size()) {
     143           1 :                 _sortedLayers = move(linkedLayers);
     144           1 :                 return true;
     145             :         }
     146             : 
     147           0 :         log::error("snn::Model", "Fail to link model: potential loop in execution tree");
     148           0 :         return false;
     149           1 : }
     150             : 
     151           0 : float Model::readFloatData() {
     152           0 :         float value = 0.0f;
     153           0 :         if (_dataFile.read((uint8_t*) &value, sizeof(float)) == sizeof(float)) {
     154           0 :                 if (isHalfPrecision()) {
     155           0 :                         value = convertToMediumPrecision(value);
     156             :                 }
     157             :         }
     158           0 :         return value;
     159             : }
     160             : 
     161           0 : float Model::getLastLoss() const {
     162           0 :         if (auto l = dynamic_cast<LossLayer *>(_sortedLayers.back())) {
     163           0 :                 return l->getParameter(LossLayer::P_Loss);
     164             :         }
     165           0 :         return 0.0f;
     166             : }
     167             : 
     168        1201 : Vector<Layer *> Model::getInputs() const {
     169        1201 :         Vector<Layer *> ret;
     170        8407 :         for (auto &it : _sortedLayers) {
     171        7206 :                 if (it->isInput()) {
     172        2402 :                         ret.emplace_back(it);
     173             :                 }
     174             :         }
     175        1201 :         return ret;
     176           0 : }
     177             : 
     178           6 : void Model::linkInput(Vector<Layer *> &layers, Layer *inputLayer, Attachment *attachment) {
     179             :         // find usage
     180          42 :         for (auto &it : _layers) {
     181          66 :                 for (auto &input : it.second->getInputs()) {
     182          30 :                         if (input.layer == inputLayer->getInputIndex()) {
     183           5 :                                 attachment->addInputBy(it.second);
     184           5 :                                 it.second->setInputExtent(input.index, attachment, inputLayer->getOutputExtent());
     185           5 :                                 if (it.second->isInputDefined()) {
     186           4 :                                         if (std::find(layers.begin(), layers.end(), it.second.get()) == layers.end()) {
     187           4 :                                                 auto attachment = Rc<Attachment>::create(_attachments.size(), it.second->getOutputExtent(), it.second);
     188           4 :                                                 it.second->setOutput(attachment);
     189           4 :                                                 layers.emplace_back(it.second.get());
     190           4 :                                                 std::cout << "Layer " << it.second->getName() << " (" << it.second->getTag() << ") " << it.second->getOutputExtent() << "\n";
     191           4 :                                                 linkInput(layers, it.second, _attachments.emplace_back(move(attachment)));
     192           4 :                                         }
     193             :                                 }
     194             :                         }
     195             :                 }
     196             :         }
     197           6 : }
     198             : 
     199           3 : Activation getActivationValue(StringView istr) {
     200           3 :         auto str = string::toupper<Interface>(istr);
     201             : 
     202           3 :         if (str == "RELU") {
     203           2 :                 return Activation::RELU;
     204           1 :         } else if (str == "RELU6") {
     205           0 :                 return Activation::RELU6;
     206           1 :         } else if (str == "TANH") {
     207           0 :                 return Activation::TANH;
     208           1 :         } else if (str == "SIGMOID") {
     209           0 :                 return Activation::SIGMOID;
     210           1 :         } else if (str == "LEAKYRELU") {
     211           0 :                 return Activation::LEAKYRELU;
     212           1 :         } else if (str == "SILU") {
     213           0 :                 return Activation::SILU;
     214           1 :         } else if (str == "LINEAR") {
     215           1 :                 return Activation::None;
     216             :         }
     217           0 :         log::error("snn::Model", "Unknown activation: ", istr);
     218           0 :         return Activation::None;
     219           3 : }
     220             : 
     221           0 : float convertToMediumPrecision(float in) {
     222             :         union tmp {
     223             :                 unsigned int unsint;
     224             :                 float flt;
     225             :         };
     226             : 
     227             :         tmp _16to32, _32to16;
     228             : 
     229           0 :         _32to16.flt = in;
     230             : 
     231           0 :         unsigned short sign = (_32to16.unsint & 0x80000000) >> 31;
     232           0 :         unsigned short exponent = (_32to16.unsint & 0x7F800000) >> 23;
     233           0 :         unsigned int mantissa = _32to16.unsint & 0x7FFFFF;
     234             : 
     235           0 :         short newexp = exponent + (-127 + 15);
     236             : 
     237             :         unsigned int newMantissa;
     238           0 :         if (newexp >= 31) {
     239           0 :                 newexp = 31;
     240           0 :                 newMantissa = 0x00;
     241           0 :         } else if (newexp <= 0) {
     242           0 :                 newexp = 0;
     243           0 :                 newMantissa = 0;
     244             :         } else {
     245           0 :                 newMantissa = mantissa >> 13;
     246             :         }
     247             : 
     248           0 :         if (newexp == 0) {
     249           0 :                 if (newMantissa == 0) {
     250           0 :                         _16to32.unsint = sign << 31;
     251             :                 } else {
     252           0 :                         newexp = 0;
     253           0 :                         while ((newMantissa & 0x200) == 0) {
     254           0 :                                 newMantissa <<= 1;
     255           0 :                                 newexp++;
     256             :                         }
     257           0 :                         newMantissa <<= 1;
     258           0 :                         newMantissa &= 0x3FF;
     259           0 :                         _16to32.unsint = (sign << 31) | ((newexp + (-15 + 127)) << 23) | (newMantissa << 13);
     260             :                 }
     261           0 :         } else if (newexp == 31) {
     262           0 :                 _16to32.unsint = (sign << 31) | (0xFF << 23) | (newMantissa << 13);
     263             :         } else {
     264           0 :                 _16to32.unsint = (sign << 31) | ((newexp + (-15 + 127)) << 23) | (newMantissa << 13);
     265             :         }
     266             : 
     267           0 :         return _16to32.flt;
     268             : }
     269             : 
     270           0 : float convertToHighPrecision(uint16_t in) {
     271             :         union tmp {
     272             :                 unsigned int unsint;
     273             :                 float flt;
     274             :         };
     275             : 
     276           0 :         unsigned short sign = (in & 0x8000) >> 15;
     277           0 :         unsigned short exponent = (in & 0x7C00) >> 10;
     278           0 :         unsigned int mantissa = (in & 0x3FF);
     279             : 
     280             :         tmp _16to32;
     281             : 
     282           0 :         short newexp = exponent + 127 - 15;
     283           0 :         unsigned int newMantissa = mantissa;
     284           0 :         if (newexp <= 0) {
     285           0 :                 newexp = 0;
     286           0 :                 newMantissa = 0;
     287             :         }
     288             : 
     289           0 :         _16to32.unsint = (sign << 31) | (newexp << 23) | (newMantissa << 13);
     290           0 :         return _16to32.flt;
     291             : }
     292             : 
     293           0 : void convertToMediumPrecision(Vector<float> &in) {
     294           0 :         for (auto &val : in) {
     295           0 :                 val = convertToMediumPrecision(val);
     296             :         }
     297           0 : }
     298             : 
     299           0 : void convertToMediumPrecision(Vector<double> &in) {
     300           0 :         for (auto &val : in) {
     301           0 :                 val = convertToMediumPrecision(val);
     302             :         }
     303           0 : }
     304             : 
     305           0 : void getByteRepresentation(float in, Vector<unsigned char> &byteRep, bool fp16) {
     306           0 :         if (fp16) {
     307             :                 union tmp {
     308             :                         unsigned int unsint;
     309             :                         float flt;
     310             :                 };
     311             :                 tmp _32to16;
     312           0 :                 _32to16.flt = in;
     313             :                 unsigned short fp16Val;
     314             : 
     315           0 :                 unsigned short sign = (_32to16.unsint & 0x80000000) >> 31;
     316           0 :                 unsigned short exponent = (_32to16.unsint & 0x7F800000) >> 23;
     317           0 :                 unsigned int mantissa = (_32to16.unsint & 0x7FFFFF);
     318             : 
     319           0 :                 short newexp = exponent + (-127 + 15);
     320             : 
     321             :                 unsigned int newMantissa;
     322           0 :                 if (newexp >= 31) {
     323           0 :                         newexp = 31;
     324           0 :                         newMantissa = 0x00;
     325           0 :                 } else if (newexp <= 0) {
     326           0 :                         newexp = 0;
     327           0 :                         newMantissa = 0;
     328             :                 } else {
     329           0 :                         newMantissa = mantissa >> 13;
     330             :                 }
     331             : 
     332           0 :                 fp16Val = sign << 15;
     333           0 :                 fp16Val = fp16Val | (newexp << 10);
     334           0 :                 fp16Val = fp16Val | (newMantissa);
     335             : 
     336           0 :                 byteRep.push_back(fp16Val & 0xFF);
     337           0 :                 byteRep.push_back(fp16Val >> 8 & 0xFF);
     338             :         } else {
     339             :                 uint32_t fp32Val;
     340           0 :                 ::memcpy(&fp32Val, &in, 4);
     341           0 :                 for (int i = 0; i < 4; i++) {
     342           0 :                         byteRep.push_back((fp32Val >> 8 * i) & 0xFF);
     343             :                 }
     344             :         }
     345           0 : }
     346             : 
     347             : }

Generated by: LCOV version 1.14