LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkTrainableLayer.h (source / functions)
Test: coverage.info
Date: 2024-05-06 04:51:23

              Hit   Total   Coverage
Lines:          6       7     85.7 %
Functions:      6       7     85.7 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #ifndef SRC_BACKEND_VK_XLSNNVKTRAINABLELAYER_H_
      24             : #define SRC_BACKEND_VK_XLSNNVKTRAINABLELAYER_H_
      25             : 
      26             : #include "XLVkQueuePass.h"
      27             : #include "XLVkAttachment.h"
      28             : 
      29             : namespace stappler::xenolith::vk::shadernn {
      30             : 
      31             : class TrainableLayer : public vk::QueuePass {
      32             : public:
      33             :         enum VariableIndex {
      34             :                 TV_MomentDecayRateVar = 0,
      35             :                 TV_OpMomentDecayRateVar,
      36             :                 TV_OpRegL2MomentDecayRateVar,
      37             :                 TV_RateVar,
      38             :                 TV_L1Threshold,
      39             :                 TV_L1Mult,
      40             :                 TV_Count
      41             :         };
      42             : 
      43             :         virtual ~TrainableLayer();
      44             : 
      45             :         virtual void initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &,
      46             :                         const core::AttachmentData *source, uint32_t bufferIndex);
      47             : 
      48             :         virtual void initPropagationSubpass(Queue::Builder &queueBuilder, core::QueuePassBuilder &,
      49             :                         core::SubpassBuilder &, const core::PipelineLayoutData *);
      50             : 
      51           0 :         virtual uint32_t getPropagationSubpassBufferCount() const { return _fullPropagationBuffers; }
      52             : 
      53             :         virtual Vector<const core::BufferData *> getTrainableGradients(Queue::Builder &queueBuilder) const;
      54             : 
      55             :         virtual bool isBackwardNeeded() const;
      56             : 
      57        7200 :         const AttachmentData *getInputAttachment() const { return _inputAttachment; }
      58       10800 :         const AttachmentData *getOutputAttachment() const { return _outputAttachment; }
      59       10800 :         const AttachmentData *getWeightsAttachment() const { return _weightsAttachment; }
      60        7200 :         const AttachmentData *getPropagationAttachment() const { return _propagationAttachment; }
      61             :         uint32_t getTargetPropagationBufferIdx() const { return _targetPropagationIdx; }
      62             : 
      63        3600 :         const AttachmentData *getExternalPropagationDataSource() const { return _externalPropagationDataSource; }
      64        3600 :         uint32_t getExternalPropagationBufferIdx() const { return _externalPropagationBufferIdx; }
      65             : 
      66             : protected:
      67             :         using QueuePass::init;
      68             : 
      69             :         const AttachmentData *_inputAttachment = nullptr;
      70             :         const AttachmentData *_outputAttachment = nullptr;
      71             :         const AttachmentData *_weightsAttachment = nullptr;
      72             : 
      73             :         const AttachmentData *_propagationAttachment = nullptr;
      74             :         uint32_t _targetPropagationIdx = 2;
      75             : 
      76             :         const AttachmentData *_externalPropagationDataSource = nullptr;
      77             :         uint32_t _externalPropagationBufferIdx = 0;
      78             : 
      79             :         uint32_t _staticPropagationBuffers = 1;
      80             :         uint32_t _fullPropagationBuffers = 2;
      81             : 
      82             :         float _momentDecayRate = 0.9f;
      83             :         float _learningRate = 0.01f;
      84             :         float _regularizationL2 = 0.f;
      85             :         float _regularizationL1 = 0.f;
      86             :         float _maxGradientNorm = -1.f;
      87             : };
      88             : 
      89             : }
      90             : 
      91             : #endif /* SRC_BACKEND_VK_XLSNNVKTRAINABLELAYER_H_ */

Generated by: LCOV version 1.14