LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/processor - XLSnnModelProcessor.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 72 99 72.7 %
Date: 2024-05-06 04:51:23 Functions: 9 15 60.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnModelProcessor.h"
      24             : #include "XLSnnInputLayer.h"
      25             : #include "XLSnnConvLayer.h"
      26             : #include "XLSnnSubpixelLayer.h"
      27             : #include "XLSnnStatPercentLayer.h"
      28             : #include "XLSnnMatrixMulLayer.h"
      29             : #include "XLSnnLossLayer.h"
      30             : 
      31             : namespace stappler::xenolith::shadernn {
      32             : 
      33           1 : bool ModelProcessor::init() {
      34           1 :         _layers.emplace("inputlayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      35           0 :                 return Rc<InputLayer>::create(m, tag, idx, data);
      36             :         });
      37             : 
      38           1 :         _layers.emplace("inputbufferlayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      39           4 :                 return Rc<InputBufferLayer>::create(m, tag, idx, data);
      40             :         });
      41             : 
      42           1 :         _layers.emplace("inputcsvintlayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      43           0 :                 return Rc<InputCsvIntLayer>::create(m, tag, idx, data);
      44             :         });
      45             : 
      46           1 :         _layers.emplace("conv2d", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      47           0 :                 return Rc<Conv2DLayer>::create(m, tag, idx, data);
      48             :         });
      49             : 
      50           1 :         _layers.emplace("subpixel", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      51           0 :                 return Rc<SubpixelLayer>::create(m, tag, idx, data);
      52             :         });
      53             : 
      54           1 :         _layers.emplace("statpercentlayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      55           0 :                 return Rc<StatPercentLayer>::create(m, tag, idx, data);
      56             :         });
      57             : 
      58           1 :         _layers.emplace("statanalysislayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      59           0 :                 return Rc<StatAnalysisLayer>::create(m, tag, idx, data);
      60             :         });
      61             : 
      62           1 :         _layers.emplace("matrixmullayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      63           6 :                 return Rc<MatrixMulLayer>::create(m, tag, idx, data);
      64             :         });
      65             : 
      66           1 :         _layers.emplace("crossentropylosslayer", [] (Model *m, StringView tag, size_t idx, const Value &data) -> Rc<Layer> {
      67           2 :                 return Rc<CrossEntropyLossLayer>::create(m, tag, idx, data);
      68             :         });
      69             : 
      70           1 :         return true;
      71             : }
      72             : 
      73           1 : Rc<Model> ModelProcessor::load(FilePath modelPath, ModelFlags flags) {
      74           1 :         Value data;
      75           1 :         if (filesystem::exists(modelPath.get())) {
      76           1 :                 data = data::readFile<Interface>(modelPath.get());
      77           0 :         } else if (modelPath.get().at(0) != '/') {
      78           0 :                 auto path = filesystem::currentDir<Interface>(modelPath.get());
      79           0 :                 if (filesystem::exists(path)) {
      80           0 :                         data = data::readFile<Interface>(path);
      81             :                 }
      82           0 :         }
      83           1 :         if (!data) {
      84           0 :                 return nullptr;
      85             :         }
      86             : 
      87           1 :         auto &numNode = data.getValue("numLayers");
      88           1 :         auto numLayers = numNode.getInteger("count");
      89           1 :         if (numLayers == 0) {
      90           0 :                 return nullptr;
      91             :         }
      92             : 
      93           1 :         String dataFilePath;
      94           1 :         if (numNode.isString("bin_file_name")) {
      95           0 :                 auto dataFile = numNode.getString("bin_file_name");
      96           0 :                 dataFilePath = filepath::merge<Interface>(filepath::root(modelPath.get()), dataFile);
      97           0 :         }
      98             : 
      99           1 :         auto m = Rc<Model>::create(flags, data, numLayers, dataFilePath);
     100           1 :         if (loadFromJson(m, move(data)) && m->link()) {
     101           1 :                 return m;
     102             :         }
     103           0 :         return nullptr;
     104           1 : }
     105             : 
     106        1200 : ModelSpecialization ModelProcessor::specializeModel(Model *model, Extent3 extent) {
     107        1200 :         Map<const Layer *, Extent3> inputs;
     108        3600 :         for (auto &it : model->getInputs()) {
     109        2400 :                 inputs.emplace(it, extent);
     110        1200 :         }
     111        2400 :         return specializeModel(model, move(inputs));
     112        1200 : }
     113             : 
     114        1200 : ModelSpecialization ModelProcessor::specializeModel(Model *model, Map<const Layer *, Extent3> &&inputs) {
     115        1200 :         ModelSpecialization ret;
     116        1200 :         ret.inputs = move(inputs);
     117             : 
     118        3600 :         for (auto &it : ret.inputs) {
     119        2400 :                 ret.attachments.emplace(it.first->getOutput(), it.second);
     120             :         }
     121             : 
     122        8400 :         for (auto &it : model->getSortedLayers()) {
     123        7200 :                 if (ret.inputs.find(it) == ret.inputs.end()) {
     124        4800 :                         auto ext = it->getOutputExtent(ret);
     125        4800 :                         ret.attachments.emplace(it->getOutput(), ext);
     126             :                 }
     127             :         }
     128             : 
     129             :         /*for (auto &it : model->getSortedLayers()) {
     130             :                 auto iit = ret.attachments.find(it->getOutput());
     131             :                 if (iit != ret.attachments.end()) {
     132             :                         std::cout << "Specialization: Layer " << it->getName() << " (" << it->getTag() << ") " << iit->second << "\n";
     133             :                 }
     134             :         }*/
     135             : 
     136        1200 :         return ret;
     137           0 : }
     138             : 
     139           1 : bool ModelProcessor::loadFromJson(Model *m, Value &&data) const {
     140           9 :         for (auto &it : data.asDict()) {
     141           8 :                 if (!StringView(it.first).starts_with("Layer_")) {
     142           2 :                         continue;
     143             :                 }
     144             : 
     145           6 :                 StringView tag(it.first);
     146           6 :                 StringView r(tag);
     147           6 :                 r += "Layer_"_len;
     148             : 
     149           6 :                 auto idx = r.readInteger(10).get();
     150           6 :                 if (auto l = makeLayer(m, tag, idx, move(it.second))) {
     151           6 :                         m->addLayer(move(l));
     152             :                 } else {
     153           0 :                         log::error("ModelProcessor", "Fail to load layer: ", tag);
     154           0 :                         return false;
     155           6 :                 }
     156             :         }
     157           1 :         return true;
     158             : }
     159             : 
     160           6 : Rc<Layer> ModelProcessor::makeLayer(Model *m, StringView tag, size_t idx, Value &&data) const {
     161           6 :         auto numInbound = data.getInteger("numInputs");
     162           6 :         auto &inputs = data.getArray("inputId");
     163           6 :         if (numInbound != int64_t(inputs.size())) {
     164           0 :                 return nullptr;
     165             :         }
     166             : 
     167           6 :         auto layerType = data.getString("type");
     168           6 :         if (layerType == "Lambda") {
     169           0 :                 layerType = data.getString("name");
     170             :         }
     171             : 
     172           6 :         if (layerType == "DepthwiseConv2D" || layerType == "Depthwise") {
     173           0 :                 layerType = "SeparableConv2D";
     174             :         }
     175           6 :         if (layerType == "InstanceNormalization") {
     176           0 :                 layerType = "InstanceNorm";
     177             :         }
     178           6 :         if (layerType == "ZeroPadding2D") {
     179           0 :                 layerType = "Pad";
     180             :         }
     181           6 :         if (layerType == "subpixel" || layerType == "depth_to_space") {
     182           0 :                 layerType = "Subpixel";
     183             :         }
     184             : 
     185           6 :         layerType = string::tolower<Interface>(layerType);
     186             : 
     187           6 :         auto it = _layers.find(layerType);
     188           6 :         if (it != _layers.end()) {
     189           6 :                 return it->second(m, tag, idx, move(data));
     190             :         }
     191             : 
     192           0 :         return nullptr;
     193           6 : }
     194             : 
     195             : }

Generated by: LCOV version 1.14