LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/layers - XLSnnLossLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 44 51 86.3 %
Date: 2024-05-06 04:51:23 Functions: 8 9 88.9 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnLossLayer.h"
      24             : #include "XLSnnAttachment.h"
      25             : #include "XLSnnVkLossLayer.h"
      26             : 
      27             : namespace stappler::xenolith::shadernn {
      28             : 
      29             : static constexpr float MaxGradient = 1e+06;
      30             : 
      31           6 : void LossLayer::setParameter(ParameterIndex idx, float val) {
      32           6 :         if (idx < P_Count) {
      33           6 :                 _params[idx] = val;
      34             :         }
      35           6 : }
      36             : 
      37           2 : float LossLayer::getParameter(ParameterIndex idx) const {
      38           2 :         if (idx < P_Count) {
      39           2 :                 return _params[idx];
      40             :         }
      41           0 :         return 0.0f;
      42             : }
      43             : 
      44           0 : void LossLayer::synchronizeParameters(SpanView<float> data) {
      45           0 :         memcpy(_params.data(), data.data(), _params.size() * sizeof(float));
      46           0 : }
      47             : 
      48           1 : bool CrossEntropyLossLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
      49           1 :         if (!LossLayer::init(m, tag, idx, data)) {
      50           0 :                 return false;
      51             :         }
      52             : 
      53           1 :         _labelsInputName = data.getString("labels");
      54           1 :         _batchSize = data.getInteger("batch_size");
      55           1 :         _classesCount = data.getInteger("classes_count");
      56             : 
      57           1 :         setParameter(P_LossWeight, 1.0f);
      58           1 :         setParameter(P_Loss, 0.0f);
      59           1 :         setParameter(P_LossDivider, 1.f / float(_batchSize));
      60           1 :         setParameter(P_LossGradientDivider, getParameter(P_LossDivider) * getParameter(P_LossWeight));
      61           1 :         setParameter(P_MinGradient, -MaxGradient);
      62           1 :         setParameter(P_MaxGradient, MaxGradient);
      63             : 
      64           1 :         return true;
      65             : }
      66             : 
      67           2 : void CrossEntropyLossLayer::setInputExtent(uint32_t index, Attachment *a, Extent3 e) {
      68           2 :         Layer::setInputExtent(index, a, e);
      69           2 :         if (a->getOutputBy()->getName() == _labelsInputName) {
      70           1 :                 _inputLabels = a->getOutputBy();
      71             :         } else {
      72           1 :                 _inputNetwork = a->getOutputBy();
      73             :         }
      74           2 : }
      75             : 
      76           1 : const core::QueuePassData *CrossEntropyLossLayer::prepare(core::Queue::Builder &builder,
      77             :                         Map<Layer *, const core::AttachmentData *> inputs,
      78             :                         Map<Attachment *, const core::AttachmentData *> attachments) {
      79           1 :         auto inputLabelsIt = attachments.find(_inputs[0].attachment);
      80           1 :         auto inputNetworkIt = attachments.find(_inputs[1].attachment);
      81           1 :         auto outputIt = attachments.find(getOutput());
      82             : 
      83           1 :         if (inputLabelsIt == attachments.end() || inputNetworkIt == attachments.end() || outputIt == attachments.end()) {
      84           0 :                 log::error("snn::InputLayer", "No attachments specified");
      85           0 :                 return nullptr;
      86             :         }
      87             : 
      88           2 :         return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
      89           1 :                         [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
      90           2 :                 return Rc<vk::shadernn::CrossEntropyLossLayer>::create(builder, passBuilder, this, inputLabelsIt->second, inputNetworkIt->second, outputIt->second);
      91           1 :         });
      92             : }
      93             : 
      94           1 : const core::AttachmentData *CrossEntropyLossLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
      95           2 :         return builder.addAttachemnt(toString(getName(), "_output"),
      96           1 :                         [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
      97           1 :                 if (isGlobalOutput) {
      98           1 :                         attachmentBuilder.defineAsOutput(core::FrameRenderPassState::Complete);
      99             :                 }
     100           2 :                 return Rc<vk::BufferAttachment>::create(attachmentBuilder,
     101           1 :                         core::BufferInfo(size_t(1 * 4 * sizeof(float)),
     102           2 :                                         core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
     103           2 :                 );
     104           2 :         });
     105             : }
     106             : 
     107             : }

Generated by: LCOV version 1.14