LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/layers - XLSnnStatPercentLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 61 0.0 %
Date: 2024-05-06 04:51:23 Functions: 0 11 0.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnStatPercentLayer.h"
      24             : #include "XLSnnVkStatPercentLayer.h"
      25             : 
      26             : namespace stappler::xenolith::shadernn {
      27             : 
      28           0 : bool StatPercentLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
      29           0 :         if (!Layer::init(m, tag, idx, data)) {
      30           0 :                 return false;
      31             :         }
      32             : 
      33           0 :         _fieldClass = data.getInteger("fieldClass");
      34           0 :         _fieldSource = data.getInteger("fieldSource");
      35           0 :         _fieldTarget = data.getInteger("fieldTarget");
      36           0 :         _classMin = data.getInteger("classMin");
      37           0 :         _classCount = data.getInteger("classCount");
      38             : 
      39           0 :         return true;
      40             : }
      41             : 
      42           0 : const core::QueuePassData *StatPercentLayer::prepare(core::Queue::Builder &builder,
      43             :                 Map<Layer *, const core::AttachmentData *> inputs,
      44             :                 Map<Attachment *, const core::AttachmentData *> attachments) {
      45           0 :         auto inputIt = attachments.find(_inputs.front().attachment);
      46           0 :         auto outputIt = attachments.find(getOutput());
      47             : 
      48           0 :         if (inputIt == attachments.end() || outputIt == attachments.end()) {
      49           0 :                 log::error("snn::InputLayer", "No attachments specified");
      50           0 :                 return nullptr;
      51             :         }
      52             : 
      53           0 :         return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
      54           0 :                         [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
      55           0 :                 return Rc<vk::shadernn::StatPercentLayer>::create(builder, passBuilder, this, inputIt->second, outputIt->second);
      56           0 :         });
      57             : }
      58             : 
      59           0 : const core::AttachmentData *StatPercentLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
      60           0 :         return builder.addAttachemnt(toString(getName(), "_output"),
      61           0 :                         [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
      62           0 :                 if (isGlobalOutput) {
      63           0 :                         attachmentBuilder.defineAsOutput();
      64             :                 }
      65           0 :                 return Rc<vk::BufferAttachment>::create(attachmentBuilder,
      66           0 :                         core::BufferInfo(size_t(_classCount * 4 * sizeof(float)),
      67           0 :                                         core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
      68           0 :                 );
      69           0 :         });
      70             : }
      71             : 
      72           0 : bool StatAnalysisLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
      73           0 :         if (!Layer::init(m, tag, idx, data)) {
      74           0 :                 return false;
      75             :         }
      76             : 
      77           0 :         _threshold = float(data.getDouble("threshold"));
      78             : 
      79           0 :         return true;
      80             : }
      81             : 
      82           0 : void StatAnalysisLayer::setInputExtent(uint32_t index, Attachment *a, Extent3 e) {
      83           0 :         Layer::setInputExtent(index, a, e);
      84             : 
      85           0 :         auto outBy = a->getOutputBy();
      86           0 :         if (auto l = dynamic_cast<StatPercentLayer *>(outBy)) {
      87           0 :                 _percent = l;
      88             :         }
      89           0 : }
      90             : 
      91           0 : const core::QueuePassData *StatAnalysisLayer::prepare(core::Queue::Builder &builder,
      92             :                 Map<Layer *, const core::AttachmentData *> inputs,
      93             :                 Map<Attachment *, const core::AttachmentData *> attachments) {
      94           0 :         auto inputDataIt = attachments.find(_inputs[0].attachment);
      95           0 :         auto inputClasseIt = attachments.find(_inputs[1].attachment);
      96           0 :         auto outputIt = attachments.find(getOutput());
      97             : 
      98           0 :         if (inputDataIt == attachments.end() || inputClasseIt == attachments.end() || outputIt == attachments.end()) {
      99           0 :                 log::error("snn::InputLayer", "No attachments specified");
     100           0 :                 return nullptr;
     101             :         }
     102             : 
     103           0 :         return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
     104           0 :                         [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
     105           0 :                 return Rc<vk::shadernn::StatAnalysisLayer>::create(builder, passBuilder, this, inputDataIt->second, inputClasseIt->second, outputIt->second);
     106           0 :         });
     107             : }
     108             : 
     109           0 : const core::AttachmentData *StatAnalysisLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
     110           0 :         return builder.addAttachemnt(toString(getName(), "_output"),
     111           0 :                         [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
     112           0 :                 if (isGlobalOutput) {
     113           0 :                         attachmentBuilder.defineAsOutput();
     114             :                 }
     115           0 :                 return Rc<vk::BufferAttachment>::create(attachmentBuilder,
     116           0 :                         core::BufferInfo(size_t(1 * 4 * sizeof(float)),
     117           0 :                                         core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
     118           0 :                 );
     119           0 :         });
     120             : }
     121             : 
     122             : }

Generated by: LCOV version 1.14