LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkMatrixMulLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 348 349 99.7 %
Date: 2024-05-06 04:51:23 Functions: 24 24 100.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnVkMatrixMulLayer.h"
      24             : #include "XLSnnVkNmeMath.h"
      25             : #include "XLSnnMatrixMulLayer.h"
      26             : #include "XLSnnModel.h"
      27             : 
      28             : namespace stappler::xenolith::vk::shadernn {
      29             : 
      30           6 : MatrixMulLayer::~MatrixMulLayer() { }
      31             : 
      32           3 : bool MatrixMulLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
      33             :                 const AttachmentData *input, const AttachmentData *output) {
      34             :         using namespace core;
      35             : 
      36           3 :         _front = front;
      37             : 
      38          12 :         auto weightsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_weights_buffer"),
      39           6 :                         BufferInfo(front->getWeightBufferSize(), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
      40           3 :                         [front = _front] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
      41             :                 /*auto name = toString(front->getName(), ".", front->getInputIndex(), ".weights.bin");
      42             :                 xenolith::shadernn::Model::loadBlob(name.data(), [&] (const uint8_t *blob, size_t s) {
      43             :                         memcpy(buf, blob, size);
      44             :                 });*/
      45             : 
      46           3 :                 front->generateWeights(buf, size, cb);
      47           3 :         });
      48             : 
      49          12 :         auto freeTermsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_freeTerms_buffer"),
      50           6 :                         BufferInfo(front->getKernelSize() * sizeof(float), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
      51           3 :                         [front = _front] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
      52             :                 /*auto name = toString(front->getName(), ".", front->getInputIndex(), ".terms.bin");
      53             :                 xenolith::shadernn::Model::loadBlob(name.data(), [&] (const uint8_t *blob, size_t s) {
      54             :                         memcpy(buf, blob, size);
      55             :                 });*/
      56             : 
      57           3 :                 front->generateFreeTerms(buf, size, cb);
      58           3 :         });
      59             : 
      60           3 :         _nbuffers = 4;
      61           3 :         _weightsBufferIndex = 0;
      62           3 :         _freeTermBufferIndex = 1;
      63           3 :         _inputBufferIndex = 2;
      64           3 :         _outputBufferIndex = 3;
      65             : 
      66           3 :         const AttachmentData *weightsAttachment = nullptr;
      67             : 
      68           3 :         weightsAttachment = queueBuilder.addAttachemnt(toString(builder.getName(), "_weights"),
      69           3 :                         [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
      70           6 :                 return Rc<vk::BufferAttachment>::create(builder, Vector<const BufferData *>{
      71           3 :                         weightsBuffer,
      72           3 :                         freeTermsBuffer
      73           6 :                 });
      74             :         });
      75             : 
      76           3 :         builder.addAttachment(input, AttachmentDependencyInfo::make(PipelineStage::ComputeShader, AccessType::ShaderRead));
      77           3 :         builder.addAttachment(output);
      78           3 :         auto passWeights = builder.addAttachment(weightsAttachment);
      79             : 
      80           3 :         auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
      81           3 :                 layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
      82           6 :                         setBuilder.addDescriptorArray(passWeights, _nbuffers, DescriptorType::StorageBuffer);
      83           3 :                 });
      84           6 :         });
      85             : 
      86           3 :         builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
      87             : 
      88           3 :                 auto matMul = subpassBuilder.addComputePipeline(toString(builder.getName(), "_matMul_pipeline"), layout,
      89           9 :                         SpecializationInfo(
      90           6 :                                 queueBuilder.addProgramByRef(toString(builder.getName(), "_matMul_shader"),
      91           3 :                                         getShader(LayerShader::MultiplyMatrixByTransposedMatrix, Precision::Unknown)),
      92           6 :                                 Vector<SpecializationConstant>{
      93           3 :                                         SpecializationConstant(_nbuffers), // nbuffers
      94             :                                         SpecializationConstant(_outputBufferIndex), // output
      95             :                                         SpecializationConstant(_inputBufferIndex), // input
      96             :                                         SpecializationConstant(_weightsBufferIndex),  // weight
      97           3 :                                         SpecializationConstant(_front->getInputIndex())
      98             :                                 }));
      99             : 
     100           3 :                 auto matMulBorders = subpassBuilder.addComputePipeline(toString(builder.getName(), "_matMulBorders_pipeline"), layout,
     101           9 :                         SpecializationInfo(
     102           6 :                                 queueBuilder.addProgramByRef(toString(builder.getName(), "_matMulBorders_shader"),
     103           3 :                                         getShader(LayerShader::MultiplyMatrixByTransposedMatrixBorder, Precision::Unknown)),
     104           9 :                                 Vector<SpecializationConstant>{
     105             :                                         SpecializationConstant(_nbuffers), // nbuffers
     106             :                                         SpecializationConstant(_outputBufferIndex), // output
     107             :                                         SpecializationConstant(_inputBufferIndex), // input
     108             :                                         SpecializationConstant(_weightsBufferIndex),  // weight
     109           3 :                                         SpecializationConstant(_front->getInputIndex())
     110             :                                 }));
     111             : 
     112           3 :                 auto addVec = subpassBuilder.addComputePipeline(toString(builder.getName(), "_addVec_pipeline"), layout,
     113           9 :                         SpecializationInfo(
     114           6 :                                 queueBuilder.addProgramByRef(toString(builder.getName(), "_addVec_shader"),
     115           3 :                                         getShader(LayerShader::AddVectorToMatrixRows, Precision::Unknown)),
     116           6 :                                 Vector<SpecializationConstant>{
     117             :                                         SpecializationConstant(_nbuffers), // nbuffers
     118             :                                         SpecializationConstant(_outputBufferIndex), // output
     119             :                                         SpecializationConstant(_outputBufferIndex), // output
     120             :                                         SpecializationConstant(_freeTermBufferIndex)  // terms
     121             :                                 }));
     122             : 
     123           3 :                 auto relu = subpassBuilder.addComputePipeline(toString(builder.getName(), "_relu_pipeline"), layout,
     124           9 :                         SpecializationInfo(
     125           6 :                                 queueBuilder.addProgramByRef(toString(builder.getName(), "_relu_shader"),
     126           3 :                                         getShader(LayerShader::VectorReLU, Precision::Unknown)),
     127           6 :                                 Vector<SpecializationConstant>{
     128             :                                         SpecializationConstant(_nbuffers), // nbuffers
     129             :                                         SpecializationConstant(_outputBufferIndex), // output
     130             :                                         SpecializationConstant(_outputBufferIndex), // output
     131             :                                 }));
     132             : 
     133           3 :                 auto relu4 = subpassBuilder.addComputePipeline(toString(builder.getName(), "_relu4_pipeline"), layout,
     134           9 :                         SpecializationInfo(
     135           6 :                                 queueBuilder.addProgramByRef(toString(builder.getName(), "_relu4_shader"),
     136           3 :                                         getShader(LayerShader::VectorReLU4, Precision::Unknown)),
     137           6 :                                 Vector<SpecializationConstant>{
     138             :                                         SpecializationConstant(_nbuffers), // nbuffers
     139             :                                         SpecializationConstant(_outputBufferIndex), // output
     140             :                                         SpecializationConstant(_outputBufferIndex), // output
     141             :                                 }));
     142             : 
     143           3 :                 subpassBuilder.setPrepareCallback([this] (const SubpassData &subpass, FrameQueue &q) {
     144             :                         // log::debug("MatrixMulLayer", getName(), ": prepare");
     145        3600 :                         auto layer = (MatrixMulLayer *)subpass.pass->pass.get();
     146             : 
     147        3600 :                         vk::BufferAttachmentHandle *inputBuffer = nullptr;
     148        3600 :                         vk::BufferAttachmentHandle *outputBuffer = nullptr;
     149        3600 :                         vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     150             : 
     151        3600 :                         if (auto bufferAttachment = q.getAttachment(layer->getInputAttachment())) {
     152        3600 :                                 inputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     153             :                         }
     154             : 
     155        3600 :                         if (auto bufferAttachment = q.getAttachment(layer->getOutputAttachment())) {
     156        3600 :                                 outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     157             :                         }
     158             : 
     159        3600 :                         if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
     160        3600 :                                 weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     161             :                         }
     162             : 
     163        3600 :                         auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
     164        3600 :                         auto &pool = handle->getMemPool(nullptr);
     165             : 
     166        3600 :                         auto extent = layer->getFront()->getOutputExtent();
     167             : 
     168        3600 :                         auto input = inputBuffer->getBuffers().front().buffer;
     169             :                         auto output = pool->spawn(AllocationUsage::DeviceLocal,
     170        3600 :                                 BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
     171        3600 :                                         size_t(extent.depth * layer->getFront()->getKernelSize() * sizeof(float))
     172        7200 :                         ));
     173             : 
     174             :                         auto feedback = pool->spawn(AllocationUsage::DeviceLocal,
     175        3600 :                                 BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
     176        3600 :                                         size_t(output->getSize())
     177        7200 :                         ));
     178             : 
     179        3600 :                         weightsBuffer->addBufferView(input);
     180        3600 :                         weightsBuffer->addBufferView(output);
     181             : 
     182        3600 :                         outputBuffer->addBufferView(output);
     183        3600 :                         outputBuffer->addBufferView(feedback);
     184        3600 :                 });
     185             : 
     186           6 :                 subpassBuilder.setCommandsCallback(
     187           3 :                                 [this, outputBufferIndex = _outputBufferIndex, matMul, matMulBorders, addVec, relu4, relu]
     188       15600 :                                          (const SubpassData &subpass, FrameQueue &q, core::CommandBuffer &b) {
     189             :                         // log::debug("MatrixMulLayer", getName(), ": commands");
     190        3600 :                         auto &buf = static_cast<CommandBuffer &>(b);
     191        3600 :                         auto pass = static_cast<RenderPass *>(subpass.pass->impl.get());
     192        3600 :                         auto layer = static_cast<MatrixMulLayer *>(subpass.pass->pass.get());
     193             : 
     194        3600 :                         vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     195        3600 :                         if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
     196        3600 :                                 weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     197             :                         }
     198             : 
     199        3600 :                         vk::BufferAttachmentHandle *outputBuffer = nullptr;
     200        3600 :                         if (auto bufferAttachment = q.getAttachment(layer->getOutputAttachment())) {
     201        3600 :                                 outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     202             :                         }
     203             : 
     204        3600 :                         auto output = outputBuffer->getBuffers()[0].buffer;
     205        3600 :                         auto feedback = outputBuffer->getBuffers()[1].buffer;
     206             : 
     207        3600 :                         auto kernelSize = layer->getFront()->getWeightSize();
     208             : 
     209        3600 :                         const int secondHeight = kernelSize.height;
     210        3600 :                         const int secondWidth = kernelSize.width;
     211             : 
     212        3600 :                         auto input = layer->getFront()->getInput();
     213             : 
     214        3600 :                         const int firstHeight = input->getOutputExtent().depth;
     215        3600 :                         const int firstWidth = input->getOutputExtent().width;
     216        3600 :                         const int resultWidth = layer->getFront()->getOutputExtent().width;
     217             : 
     218             :                         /*std::cout << "secondHeight: " << secondHeight << "\n";
     219             :                         std::cout << "secondWidth: " << secondWidth << "\n";
     220             :                         std::cout << "firstHeight: " << firstHeight << "\n";
     221             :                         std::cout << "firstWidth: " << firstWidth << "\n";
     222             :                         std::cout << "resultWidth: " << resultWidth << "\n";
     223             :                         std::cout << "\n";*/
     224             : 
     225        3600 :                         buf.cmdBindDescriptorSets(pass, 0);
     226             : 
     227        3600 :                         auto flags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
     228             : 
     229             :                         BufferMemoryBarrier barriers[4] = {
     230        3600 :                                         BufferMemoryBarrier(weightsBuffer->getBuffers()[0].buffer, flags, flags),
     231        3600 :                                         BufferMemoryBarrier(weightsBuffer->getBuffers()[1].buffer, flags, flags),
     232        3600 :                                         BufferMemoryBarrier(weightsBuffer->getBuffers()[2].buffer, flags, flags),
     233        3600 :                                         BufferMemoryBarrier(weightsBuffer->getBuffers()[3].buffer, flags, flags)
     234       18000 :                         };
     235             : 
     236        3600 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 4));
     237             : 
     238        3600 :                         MultiplyMatrixByTransposedMatrix(
     239        3600 :                                         buf, static_cast<ComputePipeline *>(matMul->pipeline.get()), static_cast<ComputePipeline *>(matMulBorders->pipeline.get()),
     240             :                                         /*first inputData, */firstHeight, firstWidth, firstWidth,
     241             :                                         /*second weightData, */secondHeight, secondWidth,
     242             :                                         /*resultoutputData, */resultWidth, /*unused*/0 );
     243             : 
     244           0 :                         BufferMemoryBarrier barrier(output, VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead),
     245        3600 :                                         VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead));
     246        3600 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     247             : 
     248        3600 :                         AddVectorToMatrixRows(buf, static_cast<ComputePipeline *>(addVec->pipeline.get()),  /*batchSize*/1, firstHeight, resultWidth);
     249             : 
     250        3600 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 4));
     251             : 
     252             :                         // save original output
     253        3600 :                         buf.cmdCopyBuffer(output, feedback);
     254             : 
     255        3600 :                         if (layer->getFront()->getActivation() == Activation::RELU) {
     256        2400 :                                 buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     257        2400 :                                 VectorReLU(buf, static_cast<ComputePipeline *>(relu4->pipeline.get()), static_cast<ComputePipeline *>(relu->pipeline.get()),
     258        2400 :                                                 output->getSize() / sizeof(float), 0.0f);
     259             :                         }
     260             : 
     261             :                         BufferMemoryBarrier barrier1[1] = {
     262             :                                 BufferMemoryBarrier(feedback, flags, flags),
     263        3600 :                         };
     264             : 
     265        3600 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barrier1, 1));
     266        3600 :                 });
     267           3 :         });
     268             : 
     269           3 :         builder.addCompleteCallback([this] (const QueuePassData &, FrameQueue &q, bool success) {
     270             :                 // log::debug("MatrixMulLayer", getName(), ": submitted");
     271             :                 /*vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     272             :                 vk::BufferAttachmentHandle *outputBuffer = nullptr;
     273             : 
     274             :                 if (auto bufferAttachment = q.getAttachment(getWeightsAttachment())) {
     275             :                         weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     276             :                 }
     277             : 
     278             :                 if (auto bufferAttachment = q.getAttachment(getOutputAttachment())) {
     279             :                         outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     280             :                 }
     281             : 
     282             :                 auto sec = Time::now().toSeconds();
     283             : 
     284             :                 q.getFrame()->getLoop()->captureBuffer([this, sec] (const BufferInfo &info, BytesView view) {
     285             :                         xenolith::shadernn::Model::saveBlob(
     286             :                                         filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".weights.bin")).data(),
     287             :                                         view.data(), view.size());
     288             :                 }, weightsBuffer->getBuffers()[_weightsBufferIndex].buffer.get());
     289             : 
     290             :                 q.getFrame()->getLoop()->captureBuffer([this, sec] (const BufferInfo &info, BytesView view) {
     291             :                         xenolith::shadernn::Model::saveBlob(
     292             :                                         filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".terms.bin")).data(),
     293             :                                         view.data(), view.size());
     294             :                 }, weightsBuffer->getBuffers()[_freeTermBufferIndex].buffer.get());*/
     295             : 
     296             :                 /*outputBuffer->getBuffers()[1].buffer->map([this, sec] (uint8_t *data, VkDeviceSize size) {
     297             :                         std::cout << getName() << " ";
     298             :                         base16::encode(std::cout, BytesView(data, 64));
     299             :                         std::cout << "\n";
     300             :                         xenolith::shadernn::Model::saveBlob(
     301             :                                         filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".output.bin")).data(),
     302             :                                         data, size);
     303             :                 });*/
     304             : 
     305             :                 /*q.getFrame()->getLoop()->captureBuffer([this, sec] (const BufferInfo &info, BytesView view) {
     306             :                         xenolith::shadernn::Model::saveBlob(
     307             :                                         filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".input.bin")).data(),
     308             :                                         view.data(), view.size());
     309             :                 }, weightsBuffer->getBuffers()[_inputBufferIndex].buffer.get());
     310             : 
     311             :                 q.getFrame()->getLoop()->captureBuffer([this, weightsBuffer] (const BufferInfo &info, BytesView view) {
     312             :                         xenolith::shadernn::Model::saveBlob(
     313             :                                         filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".output.bin")).data(),
     314             :                                         view.data(), view.size());
     315             :                 }, weightsBuffer->getBuffers()[_outputBufferIndex].buffer.get());*/
     316        3600 :         });
     317             : 
     318           3 :         _inputAttachment = input;
     319           3 :         _outputAttachment = output;
     320           3 :         _weightsAttachment = weightsAttachment;
     321             : 
     322        3600 :         _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
     323        3600 :                 return Rc<vk::QueuePassHandle>::create(pass, q);
     324           3 :         };
     325             : 
     326           6 :         return QueuePass::init(builder);
     327             : }
     328             : 
     329           3 : void MatrixMulLayer::initPropagationSubpass(core::Queue::Builder &builder, core::QueuePassBuilder &queueBuilder,
     330             :                 core::SubpassBuilder &subpass, const core::PipelineLayoutData *layout) {
     331             : 
     332           3 :         auto backwardNeeded = isBackwardNeeded();
     333             : 
     334           3 :         _fullPropagationBuffers = _staticPropagationBuffers;
     335             : 
     336           3 :         _propWeightsIndex = _fullPropagationBuffers ++;
     337           3 :         _propTermsIndex = _fullPropagationBuffers ++;
     338           3 :         _propOriginalOutput = _fullPropagationBuffers ++;
     339           3 :         _propOriginalInput = _fullPropagationBuffers ++;
     340             : 
     341           3 :         _propOutputDiff = _fullPropagationBuffers ++;
     342           3 :         _propWeightsDiff = _fullPropagationBuffers ++;
     343           3 :         _propTermsDiff = _fullPropagationBuffers ++;
     344           3 :         _propFeedback = _fullPropagationBuffers ++;
     345           3 :         _propTargetIndex = _fullPropagationBuffers ++;
     346             : 
     347           3 :         const core::ComputePipelineData *matMul = nullptr;
     348           3 :         const core::ComputePipelineData *matMulBorders = nullptr;
     349           3 :         const core::ComputePipelineData *reluDiff = nullptr;
     350             : 
     351           3 :         subpass.setPrepareCallback([this, backwardNeeded, sourceWeights = _weightsBufferIndex, sourceTerms = _freeTermBufferIndex]
     352       43200 :                                                                 (const core::SubpassData &subpass, FrameQueue &q) {
     353        3600 :                 auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
     354        3600 :                 auto &pool = handle->getMemPool(nullptr);
     355             : 
     356        3600 :                 vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     357        3600 :                 vk::BufferAttachmentHandle *outputBuffer = nullptr;
     358        3600 :                 vk::BufferAttachmentHandle *inputBuffer = nullptr;
     359        3600 :                 vk::BufferAttachmentHandle *propagationBuffer = nullptr;
     360        3600 :                 vk::BufferAttachmentHandle *externalPropagationSource = nullptr;
     361             : 
     362        3600 :                 if (auto bufferAttachment = q.getAttachment(getWeightsAttachment())) {
     363        3600 :                         weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     364             :                 }
     365             : 
     366        3600 :                 if (auto bufferAttachment = q.getAttachment(getOutputAttachment())) {
     367        3600 :                         outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     368             :                 }
     369             : 
     370        3600 :                 if (auto bufferAttachment = q.getAttachment(getInputAttachment())) {
     371        3600 :                         inputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     372             :                 }
     373             : 
     374        3600 :                 if (auto bufferAttachment = q.getAttachment(getPropagationAttachment())) {
     375        3600 :                         propagationBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     376             :                 }
     377             : 
     378        3600 :                 if (auto bufferAttachment = q.getAttachment(getExternalPropagationDataSource())) {
     379        3600 :                         externalPropagationSource = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     380             :                 }
     381             : 
     382        3600 :                 propagationBuffer->addBufferView(weightsBuffer->getBuffers()[sourceWeights].buffer);
     383        3600 :                 propagationBuffer->addBufferView(weightsBuffer->getBuffers()[sourceTerms].buffer);
     384        3600 :                 propagationBuffer->addBufferView(outputBuffer->getBuffers().back().buffer); // use feedback, direct output transformed with activation
     385        3600 :                 propagationBuffer->addBufferView(inputBuffer->getBuffers().front().buffer);
     386             : 
     387             :                 // output from prev layer
     388        3600 :                 propagationBuffer->addBufferView(externalPropagationSource->getBuffers()[getExternalPropagationBufferIdx()].buffer);
     389             : 
     390             :                 auto weightsDiff = pool->spawn(AllocationUsage::DeviceLocal,
     391        3600 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::TransferDst | core::BufferUsage::StorageBuffer, PassType::Compute,
     392        3600 :                                 size_t(_front->getWeightBufferSize())
     393        7200 :                 ));
     394        3600 :                 propagationBuffer->addBufferView(weightsDiff);
     395             : 
     396             :                 auto termsDiff = pool->spawn(AllocationUsage::DeviceLocal,
     397        3600 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::TransferDst | core::BufferUsage::StorageBuffer, PassType::Compute,
     398        3600 :                                 size_t(_front->getKernelSize() * sizeof(float))
     399        7200 :                 ));
     400        3600 :                 propagationBuffer->addBufferView(termsDiff);
     401             : 
     402        3600 :                 auto weightExtent = _front->getWeightSize();
     403        3600 :                 auto outputExtent = _front->getOutputExtent();
     404             : 
     405        3600 :                 const int resultBufferSize = outputExtent.depth * weightExtent.width;
     406             : 
     407             :                 auto feedback = pool->spawn(AllocationUsage::DeviceLocal,
     408        3600 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::TransferDst | core::BufferUsage::StorageBuffer, PassType::Compute,
     409        3600 :                                 size_t(resultBufferSize * sizeof(float))
     410        7200 :                 ));
     411        3600 :                 propagationBuffer->addBufferView(feedback);
     412             : 
     413             :                 auto inputDiff = pool->spawn(AllocationUsage::DeviceLocal,
     414        3600 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
     415        3600 :                                 size_t(resultBufferSize * sizeof(float))
     416        7200 :                 ));
     417             : 
     418        3600 :                 propagationBuffer->addBufferView(inputDiff);
     419        3600 :         });
     420             : 
     421           3 :         if (backwardNeeded) {
     422           2 :                 matMul = subpass.addComputePipeline(toString(getName(), "_BackwardOnce_mul"), layout,
     423           6 :                         core::SpecializationInfo(
     424           4 :                                 builder.addProgramByRef(toString(getName(), "_BackwardOnce_mul"),
     425           2 :                                         getShader(LayerShader::MultiplyMatrixByMatrix, Precision::Unknown)),
     426           6 :                                 Vector<core::SpecializationConstant>{
     427           2 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     428             :                                         core::SpecializationConstant(_propTargetIndex), // output
     429             :                                         core::SpecializationConstant(_propOutputDiff), // input
     430             :                                         core::SpecializationConstant(_propWeightsIndex),  // weight
     431           2 :                                         core::SpecializationConstant(_front->getInputIndex()),
     432             :                                 }));
     433             : 
     434           2 :                 matMulBorders = subpass.addComputePipeline(toString(getName(), "_BackwardOnce_mulBorders"), layout,
     435           6 :                         core::SpecializationInfo(
     436           4 :                                 builder.addProgramByRef(toString(getName(), "_BackwardOnce_mulBorders"),
     437           2 :                                         getShader(LayerShader::MultiplyMatrixByMatrixBorder, Precision::Unknown)),
     438           4 :                                 Vector<core::SpecializationConstant>{
     439           2 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     440             :                                         core::SpecializationConstant(_propTargetIndex), // output
     441             :                                         core::SpecializationConstant(_propOutputDiff), // input
     442             :                                         core::SpecializationConstant(_propWeightsIndex)  // weight
     443             :                                 }));
     444             :         }
     445             : 
     446           3 :         if (_front->getActivation() == Activation::RELU) {
     447           2 :                 reluDiff = subpass.addComputePipeline(toString(getName(), "_BackwardOnce_reluDiff"), layout,
     448           6 :                         core::SpecializationInfo(
     449           4 :                                 builder.addProgramByRef(toString(getName(), "_BackwardOnce_reluDiff"),
     450           2 :                                         getShader(LayerShader::VectorReLUDiff, Precision::Unknown)),
     451           4 :                                 Vector<core::SpecializationConstant>{
     452           2 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     453             :                                         core::SpecializationConstant(_propOutputDiff), // input diff
     454             :                                         core::SpecializationConstant(_propOriginalOutput), // original output
     455             :                                         core::SpecializationConstant(_propOutputDiff)  // output diff
     456             :                                 }));
     457             :         }
     458             : 
     459           3 :         auto learnMatMul = subpass.addComputePipeline(toString(getName(), "_LearnOnce_MatMul"), layout,
     460           9 :                 core::SpecializationInfo(
     461           6 :                         builder.addProgramByRef(toString(getName(), "_LearnOnce_MatMul"),
     462           3 :                                 getShader(LayerShader::MultiplyTransposedMatrixByMatrix, Precision::Unknown)),
     463           6 :                         Vector<core::SpecializationConstant>{
     464           3 :                                 core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     465             :                                 core::SpecializationConstant(_propWeightsDiff),
     466             :                                 core::SpecializationConstant(_propOutputDiff),
     467             :                                 core::SpecializationConstant(_propOriginalInput)
     468             :                         }));
     469             : 
     470           3 :         auto learnMatMulBorder = subpass.addComputePipeline(toString(getName(), "_LearnOnce_MatMulBorder"), layout,
     471           9 :                 core::SpecializationInfo(
     472           6 :                         builder.addProgramByRef(toString(getName(), "_LearnOnce_MatMulBorder"),
     473           3 :                                 getShader(LayerShader::MultiplyTransposedMatrixByMatrixBorder, Precision::Unknown)),
     474           6 :                         Vector<core::SpecializationConstant>{
     475           3 :                                 core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     476             :                                 core::SpecializationConstant(_propWeightsDiff),
     477             :                                 core::SpecializationConstant(_propOutputDiff),
     478             :                                 core::SpecializationConstant(_propOriginalInput)
     479             :                         }));
     480             : 
     481           3 :         auto learnSum = subpass.addComputePipeline(toString(getName(), "_LearnOnce_Sum"), layout,
     482           9 :                         core::SpecializationInfo(
     483           6 :                                 builder.addProgramByRef(toString(getName(), "_LearnOnce_Sum"),
     484           3 :                                         getShader(LayerShader::SumMatrixRows, Precision::Unknown)),
     485           6 :                                 Vector<core::SpecializationConstant>{
     486           3 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     487             :                                         core::SpecializationConstant(_propTermsDiff),
     488             :                                         core::SpecializationConstant(_propOutputDiff),
     489             :                                 }));
     490             : 
     491             :         struct TrainPipelines {
     492             :                 const core::ComputePipelineData *decayHistory;
     493             :                 const core::ComputePipelineData *multHistory;
     494             :                 const core::ComputePipelineData *add4;
     495             :                 const core::ComputePipelineData *add1;
     496             :         };
     497             : 
     498           6 :         auto initTrainPipelines = [&] (uint32_t staticparam, uint32_t diff, uint32_t target) {
     499             :                 TrainPipelines ret;
     500             : 
     501           6 :                 ret.decayHistory = subpass.addComputePipeline(toString(getName(), "_trainDecayHistory", staticparam), layout,
     502          18 :                         core::SpecializationInfo(
     503          12 :                                 builder.addProgramByRef(toString(getName(), "_trainDecayHistory", staticparam),
     504           6 :                                         getShader(LayerShader::VectorEltwiseMultiply, Precision::Unknown)),
     505          12 :                                 Vector<core::SpecializationConstant>{
     506           6 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     507             :                                         core::SpecializationConstant(staticparam),
     508             :                                         core::SpecializationConstant(staticparam),
     509             :                                         core::SpecializationConstant(_staticParams),
     510             :                                         core::SpecializationConstant(TV_MomentDecayRateVar),
     511             :                                 }));
     512             : 
     513           6 :                 ret.multHistory = subpass.addComputePipeline(toString(getName(), "_trainHistoryAdd", staticparam), layout,
     514          18 :                         core::SpecializationInfo(
     515          12 :                                 builder.addProgramByRef(toString(getName(), "_trainHistoryAdd", staticparam),
     516           6 :                                         getShader(LayerShader::VectorMultiplyAndAdd, Precision::Unknown)),
     517          12 :                                 Vector<core::SpecializationConstant>{
     518           6 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     519             :                                         core::SpecializationConstant(staticparam),
     520             :                                         core::SpecializationConstant(staticparam),
     521             :                                         core::SpecializationConstant(diff),
     522             :                                         core::SpecializationConstant(_staticParams),
     523             :                                         core::SpecializationConstant(TV_RateVar),
     524             :                                 }));
     525             : 
     526           6 :                 ret.add4 = subpass.addComputePipeline(toString(getName(), "_trainAdd4_", staticparam), layout,
     527          18 :                         core::SpecializationInfo(
     528          12 :                                 builder.addProgramByRef(toString(getName(), "_trainAdd4_", staticparam),
     529           6 :                                         getShader(LayerShader::VectorAddFloat4, Precision::Unknown)),
     530          12 :                                 Vector<core::SpecializationConstant>{
     531           6 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     532             :                                         core::SpecializationConstant(target),
     533             :                                         core::SpecializationConstant(target),
     534             :                                         core::SpecializationConstant(staticparam)
     535             :                                 }));
     536             : 
     537           6 :                 ret.add1 = subpass.addComputePipeline(toString(getName(), "_trainAdd1_", staticparam), layout,
     538          18 :                         core::SpecializationInfo(
     539          12 :                                 builder.addProgramByRef(toString(getName(), "_trainAdd1_", staticparam),
     540           6 :                                         getShader(LayerShader::VectorAddFloat1, Precision::Unknown)),
     541          12 :                                 Vector<core::SpecializationConstant>{
     542           6 :                                         core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
     543             :                                         core::SpecializationConstant(target),
     544             :                                         core::SpecializationConstant(target),
     545             :                                         core::SpecializationConstant(staticparam)
     546             :                                 }));
     547             : 
     548           6 :                 return ret;
     549           3 :         };
     550             : 
     551           3 :         auto trainWeights = initTrainPipelines(_staticWeightsHistoryIndex, _propWeightsDiff, _propWeightsIndex);
     552           3 :         auto trainTerms = initTrainPipelines(_staticTermsHistoryIndex, _propTermsDiff, _propTermsIndex);
     553             : 
     554           3 :         subpass.setCommandsCallback([this, backwardNeeded, layoutIndex = layout->index, matMul, matMulBorders,
     555             :                                                                  reluDiff, learnMatMul, learnMatMulBorder, learnSum,
     556             :                                                                  trainWeights, trainTerms]
     557       76800 :                                                                  (const core::SubpassData &subpass, core::FrameQueue &q, core::CommandBuffer &b) {
     558        3600 :                 auto &buf = static_cast<CommandBuffer &>(b);
     559        3600 :                 auto pass = static_cast<RenderPass *>(subpass.pass->impl.get());
     560        3600 :                 auto front = getFront();
     561             : 
     562        3600 :                 auto weightExtent = front->getWeightSize();
     563        3600 :                 auto outputExtent = front->getOutputExtent();
     564             : 
     565       10800 :                 auto makeFullBarrier = [&] (Buffer *b) {
     566       10800 :                         auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
     567       10800 :                         BufferMemoryBarrier barrier(b, BufferAccessFlags, BufferAccessFlags);
     568       10800 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     569       10800 :                 };
     570             : 
     571        7200 :                 auto makeFullBarrier2 = [&] (Buffer *b1, Buffer *b2) {
     572        7200 :                         auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
     573             :                         BufferMemoryBarrier barriers[] = {
     574             :                                         BufferMemoryBarrier(b1, BufferAccessFlags, BufferAccessFlags),
     575             :                                         BufferMemoryBarrier(b2, BufferAccessFlags, BufferAccessFlags)
     576        7200 :                         };
     577        7200 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
     578        7200 :                 };
     579             : 
     580        3600 :                 auto makeFullBarrier4 = [&] (Buffer *b1, Buffer *b2, Buffer *b3, Buffer *b4) {
     581        3600 :                         auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
     582             :                         BufferMemoryBarrier barriers[] = {
     583             :                                         BufferMemoryBarrier(b1, BufferAccessFlags, BufferAccessFlags),
     584             :                                         BufferMemoryBarrier(b2, BufferAccessFlags, BufferAccessFlags),
     585             :                                         BufferMemoryBarrier(b3, BufferAccessFlags, BufferAccessFlags),
     586             :                                         BufferMemoryBarrier(b4, BufferAccessFlags, BufferAccessFlags)
     587        3600 :                         };
     588        3600 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 4));
     589        3600 :                 };
     590             : 
     591        3600 :                 vk::BufferAttachmentHandle *propagationBuffer = nullptr;
     592             : 
     593        3600 :                 if (auto bufferAttachment = q.getAttachment(getPropagationAttachment())) {
     594        3600 :                         propagationBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     595             :                 }
     596             : 
     597        3600 :                 buf.cmdBindDescriptorSets(pass, layoutIndex);
     598             : 
     599        3600 :                 makeFullBarrier(propagationBuffer->getBuffers()[_propOutputDiff].buffer);
     600             : 
     601        3600 :                 if (front->getActivation() == Activation::RELU) {
     602        2400 :                         const int inputBufferSize = outputExtent.depth * front->getKernelSize();
     603        2400 :                         VectorReLUDiff(buf, static_cast<ComputePipeline *>(reluDiff->pipeline.get()), inputBufferSize, 0.0f);
     604        2400 :                         makeFullBarrier(propagationBuffer->getBuffers()[_propOutputDiff].buffer);
     605             :                 }
     606             : 
     607        3600 :                 if (backwardNeeded) {
     608        2400 :                         const int secondWidth = weightExtent.width;
     609        2400 :                         const int firstHeight = outputExtent.depth;
     610        2400 :                         const int firstWidth = outputExtent.width;
     611        2400 :                         const int resultBufferSize = firstHeight * secondWidth;
     612             : 
     613        2400 :                         MultiplyMatrixByMatrix(buf, static_cast<ComputePipeline *>(matMul->pipeline.get()), static_cast<ComputePipeline *>(matMulBorders->pipeline.get()),
     614             :                                         1, firstHeight, firstWidth, secondWidth, resultBufferSize);
     615             : 
     616        2400 :                         makeFullBarrier(propagationBuffer->getBuffers()[_propTargetIndex].buffer);
     617        2400 :                         buf.cmdCopyBuffer(propagationBuffer->getBuffers()[_propTargetIndex].buffer, propagationBuffer->getBuffers()[_propFeedback].buffer);
     618        2400 :                         makeFullBarrier(propagationBuffer->getBuffers()[_propFeedback].buffer);
     619             :                 }
     620             : 
     621        3600 :                 MultiplyTransposedMatrixByMatrix(buf, static_cast<ComputePipeline *>(learnMatMul->pipeline.get()),
     622        3600 :                                 static_cast<ComputePipeline *>(learnMatMulBorder->pipeline.get()),
     623        3600 :                                 outputExtent.depth, weightExtent.height, weightExtent.height,
     624        3600 :                                 weightExtent.width, weightExtent.width, weightExtent.width,
     625        3600 :                                 weightExtent.width * weightExtent.height);
     626             : 
     627        3600 :                 SumMatrixRows(buf, static_cast<ComputePipeline *>(learnSum->pipeline.get()),
     628        3600 :                                 1, outputExtent.depth, weightExtent.height);
     629             : 
     630        3600 :                 auto weightsSize = front->getWeightBufferSize() / sizeof(float);
     631        3600 :                 auto termsSize = front->getKernelSize();
     632             : 
     633        3600 :                 VectorMultiply(buf, static_cast<ComputePipeline *>(trainWeights.decayHistory->pipeline.get()), weightsSize);
     634        3600 :                 VectorMultiply(buf, static_cast<ComputePipeline *>(trainTerms.decayHistory->pipeline.get()), termsSize);
     635             : 
     636        7200 :                 makeFullBarrier4(propagationBuffer->getBuffers()[_propWeightsDiff].buffer, propagationBuffer->getBuffers()[_propTermsDiff].buffer,
     637        7200 :                                 propagationBuffer->getBuffers()[_staticWeightsHistoryIndex].buffer, propagationBuffer->getBuffers()[_staticTermsHistoryIndex].buffer);
     638        3600 :                 VectorMultiplyAndAdd(buf, static_cast<ComputePipeline *>(trainWeights.multHistory->pipeline.get()), weightsSize);
     639        3600 :                 VectorMultiplyAndAdd(buf, static_cast<ComputePipeline *>(trainTerms.multHistory->pipeline.get()), termsSize);
     640             : 
     641        3600 :                 makeFullBarrier2(propagationBuffer->getBuffers()[_staticWeightsHistoryIndex].buffer,
     642        3600 :                                 propagationBuffer->getBuffers()[_staticTermsHistoryIndex].buffer);
     643             : 
     644        3600 :                 VectorAdd(buf, static_cast<ComputePipeline *>(trainWeights.add4->pipeline.get()),
     645        3600 :                                 static_cast<ComputePipeline *>(trainWeights.add1->pipeline.get()), weightsSize);
     646        3600 :                 VectorAdd(buf, static_cast<ComputePipeline *>(trainTerms.add4->pipeline.get()),
     647        3600 :                                 static_cast<ComputePipeline *>(trainTerms.add1->pipeline.get()), termsSize);
     648             : 
     649        3600 :                 makeFullBarrier2(propagationBuffer->getBuffers()[_propWeightsIndex].buffer,
     650        3600 :                                 propagationBuffer->getBuffers()[_propTermsIndex].buffer);
     651        3600 :         });
     652             : 
     653           3 :         queueBuilder.addCompleteCallback([this] (const core::QueuePassData &, FrameQueue &q, bool success) {
     654             :                 /*vk::BufferAttachmentHandle *propagationBuffer = nullptr;
     655             :                 if (auto bufferAttachment = q.getAttachment(getPropagationAttachment())) {
     656             :                         propagationBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     657             :                 }
     658             : 
     659             :                 q.getFrame()->getLoop()->captureBuffer([front = _front] (const BufferInfo &info, BytesView view) {
     660             :                         xenolith::shadernn::Model::saveBlob(
     661             :                                         filesystem::currentDir<Interface>(toString(front->getName(),".", front->getInputIndex(), ".weightsDiff.bin")).data(),
     662             :                                         view.data(), view.size());
     663             :                         std::cout << "Save: " << toString(front->getName(),".", front->getInputIndex(), ".weightsDiff.bin") << "\n";
     664             :                 }, propagationBuffer->getBuffers()[_propWeightsDiff].buffer.get());
     665             : 
     666             :                 q.getFrame()->getLoop()->captureBuffer([front = _front] (const BufferInfo &info, BytesView view) {
     667             :                         xenolith::shadernn::Model::saveBlob(
     668             :                                         filesystem::currentDir<Interface>(toString(front->getName(),".", front->getInputIndex(), ".termsDiff.bin")).data(),
     669             :                                         view.data(), view.size());
     670             :                         std::cout << "Save: " << toString(front->getName(),".", front->getInputIndex(), ".termsDiff.bin") << "\n";
     671             :                 }, propagationBuffer->getBuffers()[_propTermsDiff].buffer.get());
     672             : 
     673             :                 q.getFrame()->getLoop()->captureBuffer([front = _front] (const BufferInfo &info, BytesView view) {
     674             :                         xenolith::shadernn::Model::saveBlob(
     675             :                                         filesystem::currentDir<Interface>(toString(front->getName(),".", front->getInputIndex(), ".feedback.bin")).data(),
     676             :                                         view.data(), view.size());
     677             :                         std::cout << "Save: " << toString(front->getName(),".", front->getInputIndex(), ".output.bin") << "\n";
     678             :                 }, propagationBuffer->getBuffers()[_propFeedback].buffer.get());*/
     679        3600 :         });
     680             : 
     681           3 :         _targetPropagationIdx = _propTargetIndex;
     682           3 : }
     683             : 
     684           3 : Vector<const core::BufferData *> MatrixMulLayer::getTrainableGradients(Queue::Builder &queueBuilder) const {
     685           9 :         auto weightsGradientBuffer = queueBuilder.addBuffer(toString(getName(), "_weightsGradient_buffer"),
     686           6 :                         BufferInfo(_front->getWeightBufferSize(), core::BufferUsage::StorageBuffer, PassType::Compute),
     687           3 :                         [] (uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) {
     688           3 :                 FillFloatBuffer(buf, size, 0.0f);
     689           3 :         });
     690             : 
     691           9 :         auto freeTermsGradientBuffer = queueBuilder.addBuffer(toString(getName(), "_freeTermsGradient_buffer"),
     692           6 :                         BufferInfo(_front->getKernelSize() * sizeof(float), core::BufferUsage::StorageBuffer, PassType::Compute),
     693           3 :                         [] (uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) {
     694           3 :                 FillFloatBuffer(buf, size, 0.0f);
     695           3 :         });
     696             : 
     697             :         return Vector<const core::BufferData *>{
     698             :                 weightsGradientBuffer,
     699             :                 freeTermsGradientBuffer
     700           3 :         };
     701             : }
     702             : 
     703             : }

Generated by: LCOV version 1.14