LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkTrainableLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 43 47 91.5 %
Date: 2024-05-06 04:51:23 Functions: 8 11 72.7 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnVkTrainableLayer.h"
      24             : 
      25             : namespace stappler::xenolith::vk::shadernn {
      26             : 
      27           3 : TrainableLayer::~TrainableLayer() { }
      28             : 
/** Builds the backward-propagation (training) pass for this layer.
 *
 * Creates the layer's gradient/parameter buffer attachment, registers a
 * compute subpass that consumes it, and then recursively asks the upstream
 * layer (if it is also a TrainableLayer) to build its own propagation pass,
 * wiring this layer's attachment as that layer's external data source.
 *
 * @param queueBuilder global queue builder used to register buffers/attachments
 * @param builder pass builder for the propagation queue pass
 * @param source attachment produced by the downstream layer (gradient input)
 * @param idx buffer index within `source` that feeds this layer
 */
void TrainableLayer::initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &builder, const core::AttachmentData *source, uint32_t idx) {
	using namespace core;

	// Storage buffer holding per-layer training hyperparameters, laid out as
	// floats indexed by the TV_* enum; filled lazily by the callback below.
	auto paramsBuf = queueBuilder.addBuffer(toString(getName(), "_trainingParams"), BufferInfo(BufferUsage::StorageBuffer, size_t(TV_Count * sizeof(float))),
			[this] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &) {
		auto target = (float *)buf;

		float rate = _learningRate;
		float regL1 = _regularizationL1;
		float regL2 = _regularizationL2;

		// Precomputed shader-side constants: decay and its complement for the
		// moment update, combined L2/rate factor, negated rate for descent,
		// and the L1 threshold/multiplier pair.
		target[TV_MomentDecayRateVar] = _momentDecayRate;
		target[TV_OpMomentDecayRateVar] = 1 - _momentDecayRate;
		target[TV_OpRegL2MomentDecayRateVar] = -rate * regL2;
		target[TV_RateVar] = (-rate);
		target[TV_L1Threshold] = regL1;
		target[TV_L1Mult] = -rate;
	});

	// Gradient buffers from the subclass, with the params buffer prepended
	// so it always occupies slot 0 of the attachment.
	auto trainable = getTrainableGradients(queueBuilder);
	trainable.emplace(trainable.begin(), paramsBuf);

	_staticPropagationBuffers = trainable.size();

	// NOTE: `addAttachemnt` is the project API's spelling (sic).
	_propagationAttachment = queueBuilder.addAttachemnt(toString(getName(), "_BackwardOnce_data"),
			[&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
		return Rc<vk::BufferAttachment>::create(builder, move(trainable));
	});

	// Remember where the externally-produced gradient data comes from.
	_externalPropagationDataSource = source;
	_externalPropagationBufferIdx = idx;

	auto propagationPassAttachment = builder.addAttachment(_propagationAttachment);

	// Single descriptor set: an array of storage buffers sized by the
	// subclass-reported subpass buffer count.
	auto l = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layout) {
		layout.addSet([&] (DescriptorSetBuilder &set) {
			set.addDescriptorArray(propagationPassAttachment, getPropagationSubpassBufferCount(), DescriptorType::StorageBuffer);
		});
	});

	// Subclass fills in the actual compute pipelines for this subpass.
	builder.addSubpass([&] (SubpassBuilder &subpass) {
		initPropagationSubpass(queueBuilder, builder, subpass, l);
	});

	// Chain backward: if the pass producing our input belongs to another
	// TrainableLayer (i.e. not this layer itself), build its propagation
	// pass too, feeding it our propagation attachment.
	const core::QueuePassData *pass = _inputAttachment->passes.front()->pass;
	if (_inputAttachment->passes.front()->pass->pass != this) {
		auto trainable = dynamic_cast<TrainableLayer *>(pass->pass.get());
		if (trainable) {
			trainable->initPropagation(queueBuilder, builder, _propagationAttachment, _targetPropagationIdx);
		}
	}
}
      81             : 
/** Hook for subclasses to populate the propagation subpass with pipelines.
 *
 * The base implementation is intentionally empty; concrete layers override
 * this to attach their backward-pass compute pipelines using the descriptor
 * layout created by initPropagation().
 */
void TrainableLayer::initPropagationSubpass(Queue::Builder &queueBuilder, core::QueuePassBuilder &,
		core::SubpassBuilder &, const core::PipelineLayoutData *) {

}
      86             : 
      87           0 : Vector<const core::BufferData *> TrainableLayer::getTrainableGradients(Queue::Builder &queueBuilder) const {
      88           0 :         return Vector<const core::BufferData *>();
      89             : }
      90             : 
      91           3 : bool TrainableLayer::isBackwardNeeded() const {
      92           3 :         const core::QueuePassData *pass = _inputAttachment->passes.front()->pass;
      93           3 :         if (_inputAttachment->passes.front()->pass->pass != this && dynamic_cast<TrainableLayer *>(pass->pass.get())) {
      94           2 :                 return true;
      95             :         }
      96           1 :         return false;
      97             : }
      98             : 
      99             : }

Generated by: LCOV version 1.14