LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkLossLayer.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 7 7 100.0 %
Date: 2024-05-06 04:51:23 Functions: 6 6 100.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #ifndef SRC_BACKEND_VK_XLSNNVKLOSSLAYER_H_
      24             : #define SRC_BACKEND_VK_XLSNNVKLOSSLAYER_H_
      25             : 
      26             : #include "XLVkRenderPass.h"
      27             : #include "XLVkQueuePass.h"
      28             : #include "XLVkAttachment.h"
      29             : #include "XLSnnLossLayer.h"
      30             : 
      31             : namespace stappler::xenolith::vk::shadernn {
      32             : 
      33             : class CrossEntropyLossLayer : public vk::QueuePass {
      34             : public:
      35             :         using Front = xenolith::shadernn::CrossEntropyLossLayer;
      36             : 
      37             :         using BufferView = BufferAttachmentHandle::BufferView;
      38             : 
      39             :         using PipelineOpFn = Function<void(Front *, CommandBuffer &, ComputePipeline *, SpanView<BufferView>)>;
      40             : 
      41             :         enum class PipelineOpIndex {
      42             :                 MatrixSoftmaxByRows, // softmax to activation
      43             :                 VectorNegLog, // transform activation
      44             :                 VectorEltwiseMultiply, // multiply activation on labels
      45             :                 SumMatrixColumnsToResult, // sum activation columns to result
      46             :                 VectorSub, // compute diff fro activation to label
      47             :                 SumMatrixColumnsLabels, // sum labels
      48             :                 MultiplyDiagMatrixByMatrix, // calc gradient
      49             :                 VectorDotProduct, // loss value function
      50             :                 MultiplyDiagMatrixByMatrixForInput // prepare error propagation
      51             :         };
      52             : 
      53             :         struct PipelineOp {
      54             :                 PipelineOpIndex idx;
      55             :                 const core::ComputePipelineData *pipeline;
      56             :                 PipelineOpFn command;
      57             : 
      58           9 :                 PipelineOp(PipelineOpIndex i, const core::ComputePipelineData *p, PipelineOpFn &&f)
      59           9 :                 : idx(i), pipeline(p), command(move(f)) { };
      60             :         };
      61             : 
      62             :         static constexpr uint32_t DescriptorArraySize = 8;
      63             :         static constexpr uint32_t ParamsIdx = 0;
      64             :         static constexpr uint32_t WeightsIdx = 1;
      65             :         static constexpr uint32_t LossValueIdx = 2;
      66             :         static constexpr uint32_t LossGradientIdx = 3;
      67             :         static constexpr uint32_t InputNetworkIdx = 4;
      68             :         static constexpr uint32_t InputLabelsIdx = 5;
      69             :         static constexpr uint32_t ActivationIdx = 6;
      70             :         static constexpr uint32_t ActivationEltwiseMulIdx = 7;
      71             : 
      72             :         virtual ~CrossEntropyLossLayer();
      73             : 
      74             :         virtual bool init(Queue::Builder &queueBuilder, QueuePassBuilder &, Front *,
      75             :                         const AttachmentData *inputLabels, const AttachmentData *inputNetwork, const AttachmentData *output);
      76             : 
      77             :         void initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &);
      78             : 
      79             :         void runAll(CommandBuffer &, SpanView<BufferView>);
      80             : 
      81        1200 :         const AttachmentData *getInputLabelsAttachment() const { return _inputLabelsAttachment; }
      82        1200 :         const AttachmentData *getInputNetworkAttachment() const { return _inputNetworkAttachment; }
      83        3600 :         const AttachmentData *getWeightsAttachment() const { return _weightAttachment; }
      84        1200 :         const AttachmentData *getOutputAttachment() const { return _outputAttachment; }
      85             : 
      86        1200 :         const Front *getFront() const { return _front; }
      87             : 
      88             : protected:
      89             :         using QueuePass::init;
      90             : 
      91             :         const AttachmentData *_inputLabelsAttachment = nullptr;
      92             :         const AttachmentData *_inputNetworkAttachment = nullptr;
      93             :         const AttachmentData *_weightAttachment = nullptr;
      94             :         const AttachmentData *_outputAttachment = nullptr;
      95             : 
      96             :         Rc<Front> _front;
      97             :         Map<PipelineOpIndex, PipelineOp> _pipelineOps;
      98             : };
      99             : 
     100             : }
     101             : 
     102             : #endif /* SRC_BACKEND_VK_XLSNNVKLOSSLAYER_H_ */

Generated by: LCOV version 1.14