LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkMatrixMulLayer.h (source / functions)
Test: coverage.info
Date: 2024-05-06 04:51:23
Coverage (Hit / Total / Coverage):
  Lines: 2 / 2 / 100.0 %
  Functions: 2 / 2 / 100.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #ifndef SRC_BACKEND_VK_XLSNNVKMATRIXMULLAYER_H_
      24             : #define SRC_BACKEND_VK_XLSNNVKMATRIXMULLAYER_H_
      25             : 
      26             : #include "XLVkRenderPass.h"
      27             : #include "XLSnnVkTrainableLayer.h"
      28             : #include "XLSnnMatrixMulLayer.h"
      29             : 
      30             : namespace stappler::xenolith::vk::shadernn {
      31             : 
      32             : class MatrixMulLayer : public TrainableLayer {  // vk-backend layer paired with a front-end xenolith::shadernn::MatrixMulLayer (see Front)
      33             : public:
      34             :         using Front = xenolith::shadernn::MatrixMulLayer;  // front-end layer object that init() receives and getFront() exposes
      35             :         // Destructor is only declared here; its definition lives outside this header.
      36             :         virtual ~MatrixMulLayer();
      37             :         // Builds this layer's pass for `Front`, reading `input` and writing `output`; presumably returns false on failure — TODO confirm.
      38             :         virtual bool init(Queue::Builder &queueBuilder, QueuePassBuilder &, Front *,
      39             :                         const AttachmentData *input, const AttachmentData *output);  // NOTE(review): no `override` — confirm this is a new overload, not a base override
      40             :         // Overridden base hook that configures the propagation subpass (presumably the training/backward step) for `layout`.
      41             :         virtual void initPropagationSubpass(core::Queue::Builder &builder, core::QueuePassBuilder &, core::SubpassBuilder &subpass,
      42             :                          const core::PipelineLayoutData *layout) override;
      43             :         // Returns the buffers holding this layer's trainable gradients; presumably weight and free-term diffs — TODO confirm.
      44             :         virtual Vector<const core::BufferData *> getTrainableGradients(Queue::Builder &queueBuilder) const;  // NOTE(review): virtual without `override` — verify against TrainableLayer
      45             :         // Raw, non-owning pointer to the object referenced by _front; callers must not outlive this layer's reference.
      46       25200 :         const Front *getFront() const { return _front; }
      47             :         // Buffer count bound by the propagation subpass. NOTE(review): 12 here vs. 9 _prop* indices below — confirm against initPropagationSubpass.
      48          42 :         virtual uint32_t getPropagationSubpassBufferCount() const override { return 12; }
      49             : protected:
      50             :         using QueuePass::init;  // re-exposes base init overloads that the init() declared above would otherwise hide (accessible as protected)
      51             :         // Front-end layer this pass was initialised from; returned by getFront().
      52             :         Rc<Front> _front;
      53             :         // Buffer slot indices (weights, free term, input, output); _nbuffers is the number of these slots.
      54             :         uint32_t _nbuffers = 4;
      55             :         uint32_t _weightsBufferIndex = 0;
      56             :         uint32_t _freeTermBufferIndex = 1;
      57             :         uint32_t _inputBufferIndex = 2;
      58             :         uint32_t _outputBufferIndex = 3;
      59             :         // Indices of the static buffers: parameters, weights history and terms history (meaning taken from member names).
      60             :         uint32_t _staticParams = 0;
      61             :         uint32_t _staticWeightsHistoryIndex = 1;
      62             :         uint32_t _staticTermsHistoryIndex = 2;
      63             :         // Propagation-subpass buffer indices; all default to 0 and are presumably assigned in initPropagationSubpass — TODO confirm.
      64             :         uint32_t _propWeightsIndex = 0;
      65             :         uint32_t _propTermsIndex = 0;
      66             :         uint32_t _propOriginalOutput = 0;
      67             :         uint32_t _propOriginalInput = 0;
      68             :         uint32_t _propOutputDiff = 0;
      69             :         uint32_t _propTargetIndex = 0;
      70             :         uint32_t _propWeightsDiff = 0;
      71             :         uint32_t _propTermsDiff = 0;
      72             :         uint32_t _propFeedback = 0;
      73             : };
      74             : 
      75             : }
      76             : 
      77             : #endif /* SRC_BACKEND_VK_XLSNNVKMATRIXMULLAYER_H_ */

Generated by: LCOV version 1.14