LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkLossLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 208 216 96.3 %
Date: 2024-05-06 04:51:23 Functions: 30 30 100.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnVkLossLayer.h"
      24             : 
      25             : namespace stappler::xenolith::vk::shadernn {
      26             : 
      27             : static auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
      28             : 
      29           9 : static StringView getPipelineOpName(CrossEntropyLossLayer::PipelineOpIndex idx) {
      30           9 :         switch (idx) {
      31           1 :         case CrossEntropyLossLayer::PipelineOpIndex::MatrixSoftmaxByRows: return "MatrixSoftmaxByRows"; break;
      32           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorNegLog: return "VectorNegLog"; break;
      33           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorEltwiseMultiply: return "VectorEltwiseMultiply"; break;
      34           1 :         case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsToResult: return "SumMatrixColumnsToResult"; break;
      35           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorSub: return "VectorSub"; break;
      36           1 :         case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsLabels: return "SumMatrixColumns"; break;
      37           1 :         case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrix: return "MultiplyDiagMatrixByMatrix"; break;
      38           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorDotProduct: return "VectorDotProduct"; break;
      39           1 :         case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrixForInput: return "MultiplyDiagMatrixByMatrixForInput"; break;
      40             :         }
      41           0 :         return StringView();
      42             : }
      43             : 
      44           9 : static LayerShader getPipelineOpShader(CrossEntropyLossLayer::PipelineOpIndex idx) {
      45           9 :         switch (idx) {
      46           1 :         case CrossEntropyLossLayer::PipelineOpIndex::MatrixSoftmaxByRows: return LayerShader::MatrixSoftmaxByRows; break;
      47           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorNegLog: return LayerShader::VectorLog; break;
      48           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorEltwiseMultiply: return LayerShader::VectorEltwiseMultiply; break;
      49           1 :         case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsToResult: return LayerShader::SumMatrixColumns; break;
      50           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorSub: return LayerShader::VectorSub; break;
      51           1 :         case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsLabels: return LayerShader::SumMatrixColumns; break;
      52           1 :         case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrix: return LayerShader::MultiplyDiagMatrixByMatrix; break;
      53           1 :         case CrossEntropyLossLayer::PipelineOpIndex::VectorDotProduct: return LayerShader::VectorDotProduct; break;
      54           1 :         case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrixForInput: return LayerShader::MultiplyDiagMatrixByMatrix; break;
      55             :         }
      56           0 :         return LayerShader::Gen;
      57             : }
      58             : 
      59           2 : CrossEntropyLossLayer::~CrossEntropyLossLayer() { }
      60             : 
      61           1 : bool CrossEntropyLossLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front * front,
      62             :                 const AttachmentData *inputLabels, const AttachmentData *inputNetwork, const AttachmentData *output) {
      63             :         using namespace core;
      64             : 
      65           1 :         _front = front;
      66             : 
      67           3 :         auto paramsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_params_buffer"),
      68           2 :                         BufferInfo(front->getParameters().size() * sizeof(float), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
      69           1 :                         [front = front] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
      70           1 :                 memcpy(buf, front->getParameters().data(), size);
      71           1 :         });
      72             : 
      73           3 :         auto weightsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_weights_buffer"),
      74           2 :                         BufferInfo(front->getWeightBufferSize(), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
      75           1 :                         [] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
      76           1 :                 FillFloatBuffer(buf, size, 1.0f);
      77           1 :         });
      78             : 
      79           3 :         auto resultBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_result_buffer"),
      80           2 :                         BufferInfo(front->getResultBufferSize(), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
      81           1 :                         [] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
      82           1 :                 FillFloatBuffer(buf, size, 0.0f);
      83           1 :         });
      84             : 
      85           3 :         auto lossGradientBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_lossGradient_buffer"),
      86           2 :                         BufferInfo(front->getLossGradientBufferSize(), BufferUsage::StorageBuffer, PassType::Compute),
      87           1 :                         [] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
      88           1 :                 FillFloatBuffer(buf, size, 0.0f);
      89           1 :         });
      90             : 
      91           1 :         auto weightsAttachment = queueBuilder.addAttachemnt(toString(builder.getName(), "_weights"),
      92           1 :                         [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
      93           2 :                 return Rc<vk::BufferAttachment>::create(builder, Vector<const BufferData *>{
      94           1 :                         paramsBuffer,
      95           1 :                         weightsBuffer,
      96           1 :                         resultBuffer,
      97           1 :                         lossGradientBuffer,
      98           2 :                 });
      99             :         });
     100             : 
     101           1 :         builder.addAttachment(inputLabels, AttachmentDependencyInfo::make(PipelineStage::ComputeShader, AccessType::ShaderRead));
     102           1 :         builder.addAttachment(inputNetwork, AttachmentDependencyInfo::make(PipelineStage::ComputeShader, AccessType::ShaderRead));
     103           1 :         builder.addAttachment(output);
     104           1 :         auto passWeights = builder.addAttachment(weightsAttachment);
     105             : 
     106           1 :         auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
     107           1 :                 layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
     108           2 :                         setBuilder.addDescriptorArray(passWeights, DescriptorArraySize, DescriptorType::StorageBuffer);
     109           1 :                 });
     110           2 :         });
     111             : 
     112           1 :         builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
     113           4 :                 auto addPipeline2 = [&] (PipelineOpIndex idx, uint32_t output, uint32_t input, PipelineOpFn &&fn) {
     114           4 :                         auto name = getPipelineOpName(idx);
     115           4 :                         auto shader = getPipelineOpShader(idx);
     116             : 
     117           4 :                         auto data = subpassBuilder.addComputePipeline(toString(builder.getName(), "_", name), layout,
     118          12 :                                 SpecializationInfo(
     119           8 :                                         queueBuilder.addProgramByRef(toString(builder.getName(), "_", name, "_shader"),
     120           4 :                                                 getShader(shader, Precision::Unknown)),
     121           8 :                                         Vector<SpecializationConstant>{
     122             :                                                 SpecializationConstant(DescriptorArraySize), // nbuffers
     123             :                                                 SpecializationConstant(output), // output
     124             :                                                 SpecializationConstant(input), // input
     125             :                                         }));
     126             : 
     127           4 :                         _pipelineOps.emplace(idx, PipelineOp(
     128           4 :                                 idx, data, move(fn)
     129             :                         ));
     130           4 :                 };
     131             : 
     132           5 :                 auto addPipeline3 = [&] (PipelineOpIndex idx, uint32_t output, uint32_t inputA, uint32_t inputB, PipelineOpFn &&fn,
     133             :                                 Vector<SpecializationConstant> &&extra = Vector<SpecializationConstant>()) {
     134           5 :                         auto name = getPipelineOpName(idx);
     135           5 :                         auto shader = getPipelineOpShader(idx);
     136             : 
     137             :                         auto constants = Vector<SpecializationConstant>{
     138             :                                 SpecializationConstant(DescriptorArraySize), // nbuffers
     139             :                                 SpecializationConstant(output), // output
     140             :                                 SpecializationConstant(inputA), // input
     141             :                                 SpecializationConstant(inputB), // input
     142           5 :                         };
     143             : 
     144          13 :                         for (auto &it : extra) {
     145           8 :                                 constants.emplace_back(it);
     146             :                         }
     147             : 
     148           5 :                         auto data = subpassBuilder.addComputePipeline(toString(builder.getName(), "_", name), layout,
     149          15 :                                 SpecializationInfo(
     150          10 :                                         queueBuilder.addProgramByRef(toString(builder.getName(), "_", name, "_shader"),
     151           5 :                                                 getShader(shader, Precision::Unknown)),
     152             :                                                 constants));
     153             : 
     154           5 :                         _pipelineOps.emplace(idx, PipelineOp(
     155           5 :                                 idx, data, move(fn)
     156             :                         ));
     157             : 
     158           5 :                         return data;
     159           5 :                 };
     160             : 
     161           1 :                 addPipeline2(PipelineOpIndex::MatrixSoftmaxByRows, ActivationIdx, InputNetworkIdx,
     162        1200 :                                 [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     163        1200 :                         MatrixSoftmaxByRows( buf, pipeline, front->getBatchSize(), front->getClassesCount() );
     164        1200 :                 });
     165             : 
     166           1 :                 addPipeline2(PipelineOpIndex::VectorNegLog, ActivationEltwiseMulIdx, ActivationIdx,
     167        1200 :                                 [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     168             : 
     169        1200 :                         BufferMemoryBarrier barrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags);
     170             : 
     171        1200 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     172             : 
     173        1200 :                         VectorNegLog( buf, pipeline, front->getBatchSize() * front->getClassesCount() );
     174        1200 :                 });
     175             : 
     176           1 :                 addPipeline3(PipelineOpIndex::VectorEltwiseMultiply, ActivationEltwiseMulIdx, InputLabelsIdx, ActivationEltwiseMulIdx,
     177        1200 :                                 [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     178             : 
     179        1200 :                         BufferMemoryBarrier barrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags);
     180             : 
     181        1200 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     182             : 
     183        1200 :                         VectorEltwiseMultiply( buf, pipeline, front->getBatchSize() * front->getClassesCount() );
     184        1200 :                 });
     185             : 
     186           1 :                 addPipeline2(PipelineOpIndex::SumMatrixColumnsToResult, LossValueIdx, ActivationEltwiseMulIdx,
     187        1200 :                                 [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     188             : 
     189        1200 :                         BufferMemoryBarrier barrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags);
     190             : 
     191        1200 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     192             : 
     193        1200 :                         SumMatrixColumns( buf, pipeline, front->getBatchSize(), front->getClassesCount() );
     194        1200 :                 });
     195             : 
     196           1 :                 addPipeline3(PipelineOpIndex::VectorDotProduct, ParamsIdx, WeightsIdx, LossValueIdx,
     197        1200 :                                 [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     198             : 
     199             :                         BufferMemoryBarrier barriers[2] = {
     200           0 :                                 BufferMemoryBarrier(buffers[WeightsIdx].buffer, BufferAccessFlags, BufferAccessFlags),
     201           0 :                                 BufferMemoryBarrier(buffers[LossValueIdx].buffer, BufferAccessFlags, BufferAccessFlags),
     202        1200 :                         };
     203             : 
     204        1200 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
     205             : 
     206        1200 :                         VectorDotProduct( buf, pipeline, front->getBatchSize());
     207        1202 :                 }, Vector<SpecializationConstant>{
     208             :                         SpecializationConstant(Front::P_Loss),
     209             :                         SpecializationConstant(1),
     210             :                         SpecializationConstant(Front::P_LossDivider),
     211             :                 });
     212             : 
     213           1 :                 if (front->getModel()->isTrainable()) {
     214           1 :                         addPipeline3(PipelineOpIndex::VectorSub, ActivationEltwiseMulIdx, ActivationIdx, InputLabelsIdx,
     215        1200 :                                         [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     216             : 
     217             :                                 BufferMemoryBarrier barriers[2] = {
     218           0 :                                         BufferMemoryBarrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags),
     219           0 :                                         BufferMemoryBarrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags),
     220        1200 :                                 };
     221             : 
     222        1200 :                                 buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
     223             : 
     224        1200 :                                 VectorSub( buf, pipeline, front->getBatchSize() * front->getClassesCount() );
     225        1200 :                         });
     226             : 
     227           1 :                         addPipeline2(PipelineOpIndex::SumMatrixColumnsLabels, ActivationIdx, InputLabelsIdx,
     228        1200 :                                         [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     229             : 
     230        1200 :                                 BufferMemoryBarrier barrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags);
     231             : 
     232        1200 :                                 buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     233             : 
     234        1200 :                                 SumMatrixColumns( buf, pipeline, front->getBatchSize(), front->getClassesCount() );
     235        1200 :                         });
     236             : 
     237           1 :                         addPipeline3(PipelineOpIndex::MultiplyDiagMatrixByMatrix, LossGradientIdx, ActivationIdx, ActivationEltwiseMulIdx,
     238        1200 :                                         [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     239             : 
     240             :                                 BufferMemoryBarrier barriers[2] = {
     241           0 :                                         BufferMemoryBarrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags),
     242           0 :                                         BufferMemoryBarrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags),
     243        1200 :                                 };
     244             : 
     245        1200 :                                 buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
     246             : 
     247        1200 :                                 MultiplyDiagMatrixByMatrix( buf, pipeline, front->getBatchSize(), front->getClassesCount(), front->getBatchSize() * front->getClassesCount() );
     248        1200 :                         });
     249             : 
     250           1 :                         addPipeline3(PipelineOpIndex::MultiplyDiagMatrixByMatrixForInput, ActivationIdx, WeightsIdx, LossGradientIdx,
     251        1200 :                                         [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
     252             : 
     253        1200 :                                 BufferMemoryBarrier barrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags);
     254             : 
     255        1200 :                                 buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
     256             : 
     257        1200 :                                 MultiplyDiagMatrixByMatrix( buf, pipeline, front->getBatchSize(), front->getClassesCount(), front->getBatchSize() * front->getClassesCount() );
     258        1202 :                         }, Vector<SpecializationConstant>{
     259             :                                 SpecializationConstant(1), // MODIFIERS_ENABLED
     260             :                                 SpecializationConstant(ParamsIdx), // PARAMETERS_INDEX
     261             :                                 SpecializationConstant(Front::P_LossGradientDivider), // MULTIPLIER_PARAMETER_OFFSET
     262             :                                 SpecializationConstant(Front::P_MinGradient), // MIN_PARAMETER_OFFSET
     263             :                                 SpecializationConstant(Front::P_MaxGradient), // MAX_PARAMETER_OFFSET
     264             :                         });
     265             :                 }
     266             : 
     267           1 :                 subpassBuilder.setPrepareCallback([] (const core::SubpassData &subpass, core::FrameQueue &q) {
     268        1200 :                         auto layer = (CrossEntropyLossLayer *)subpass.pass->pass.get();
     269             : 
     270        1200 :                         vk::BufferAttachmentHandle *inputNetworkBuffer = nullptr;
     271        1200 :                         vk::BufferAttachmentHandle *inputLabelsBuffer = nullptr;
     272        1200 :                         vk::BufferAttachmentHandle *outputBuffer = nullptr;
     273        1200 :                         vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     274             : 
     275        1200 :                         if (auto bufferAttachment = q.getAttachment(layer->getInputNetworkAttachment())) {
     276        1200 :                                 inputNetworkBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     277             :                         }
     278             : 
     279        1200 :                         if (auto bufferAttachment = q.getAttachment(layer->getInputLabelsAttachment())) {
     280        1200 :                                 inputLabelsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     281             :                         }
     282             : 
     283        1200 :                         if (auto bufferAttachment = q.getAttachment(layer->getOutputAttachment())) {
     284        1200 :                                 outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     285             :                         }
     286             : 
     287        1200 :                         if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
     288        1200 :                                 weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     289             :                         }
     290             : 
     291        1200 :                         auto front = layer->getFront();
     292             : 
     293        1200 :                         auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
     294        1200 :                         auto &pool = handle->getMemPool(nullptr);
     295             : 
     296        1200 :                         auto batchSize = front->getBatchSize();
     297        1200 :                         auto vectorSize = front->getClassesCount();
     298        1200 :                         auto totalSize = batchSize * vectorSize;
     299             : 
     300             :                         auto activationBuffer = pool->spawn(AllocationUsage::DeviceLocal,
     301        1200 :                                         BufferInfo(size_t(totalSize * sizeof(float)), BufferUsage::StorageBuffer | BufferUsage::TransferSrc));
     302             :                         auto activationEltwiseMulBuffer = pool->spawn(AllocationUsage::DeviceLocal,
     303        1200 :                                         BufferInfo(size_t(totalSize * sizeof(float)), BufferUsage::StorageBuffer | BufferUsage::TransferSrc));
     304             : 
     305        1200 :                         outputBuffer->addBufferView(weightsBuffer->getBuffers()[2].buffer);
     306        1200 :                         outputBuffer->addBufferView(weightsBuffer->getBuffers()[0].buffer);
     307             : 
     308        1200 :                         weightsBuffer->addBufferView(inputNetworkBuffer->getBuffers().front().buffer);
     309        1200 :                         weightsBuffer->addBufferView(inputLabelsBuffer->getBuffers().front().buffer);
     310        1200 :                         weightsBuffer->addBufferView(activationBuffer);
     311        1200 :                         weightsBuffer->addBufferView(activationEltwiseMulBuffer);
     312        1200 :                 });
     313             : 
     314           1 :                 subpassBuilder.setCommandsCallback([] (const SubpassData &subpass, FrameQueue &q, core::CommandBuffer &b) {
     315        1200 :                         auto &buf = static_cast<CommandBuffer &>(b);
     316        1200 :                         auto pass = static_cast<RenderPass *>(subpass.pass->impl.get());
     317        1200 :                         auto layer = (CrossEntropyLossLayer *)subpass.pass->pass.get();
     318             : 
     319        1200 :                         vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     320        1200 :                         if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
     321        1200 :                                 weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     322             :                         }
     323             : 
     324        1200 :                         buf.cmdBindDescriptorSets(pass, 0);
     325             : 
     326        1200 :                         layer->runAll(buf, weightsBuffer->getBuffers());
     327        1200 :                 });
     328           1 :         });
     329             : 
     330           1 :         builder.addCompleteCallback([this, front = _front] (const QueuePassData &pass, FrameQueue &q, bool success) {
     331        1200 :                 auto layer = (CrossEntropyLossLayer *)pass.pass.get();
     332             : 
     333        1200 :                 vk::BufferAttachmentHandle *weightsBuffer = nullptr;
     334        1200 :                 if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
     335        1200 :                         weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
     336             :                 }
     337             : 
     338        1200 :                 auto params = weightsBuffer->getBuffers()[ParamsIdx].buffer;
     339             : 
     340             :                 if (success) {
     341             :                         /*q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
     342             :                                 auto name = toString(front->getName(),".", front->getInputIndex(), ".label.bin");
     343             :                                 xenolith::shadernn::Model::saveBlob(
     344             :                                                 filesystem::currentDir<Interface>(name).data(),
     345             :                                                 view.data(), view.size());
     346             :                                 std::cout << "Save " << name << "\n";
     347             :                         }, weightsBuffer->getBuffers()[InputLabelsIdx].buffer.get());
     348             : 
     349             :                         q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
     350             :                                 auto name = toString(front->getName(),".", front->getInputIndex(), ".activation.bin");
     351             :                                 xenolith::shadernn::Model::saveBlob(
     352             :                                                 filesystem::currentDir<Interface>(name).data(),
     353             :                                                 view.data(), view.size());
     354             :                                 std::cout << "Save " << name << "\n";
     355             :                         }, weightsBuffer->getBuffers()[ActivationIdx].buffer.get());
     356             : 
     357             :                         q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
     358             :                                 auto name = toString(front->getName(),".", front->getInputIndex(), ".activation.mul.bin");
     359             :                                 xenolith::shadernn::Model::saveBlob(
     360             :                                                 filesystem::currentDir<Interface>(name).data(),
     361             :                                                 view.data(), view.size());
     362             :                                 std::cout << "Save " << name << "\n";
     363             :                         }, weightsBuffer->getBuffers()[ActivationEltwiseMulIdx].buffer.get());
     364             : 
     365             :                         q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
     366             :                                 auto name = toString(front->getName(),".", front->getInputIndex(), ".value.bin");
     367             :                                 xenolith::shadernn::Model::saveBlob(
     368             :                                                 filesystem::currentDir<Interface>(name).data(),
     369             :                                                 view.data(), view.size());
     370             :                                 std::cout << "Save " << name << "\n";
     371             :                         }, weightsBuffer->getBuffers()[LossValueIdx].buffer.get());
     372             :                         */
     373             :                 }
     374        1200 :         });
     375             : 
     376           1 :         _inputLabelsAttachment = inputLabels;
     377           1 :         _inputNetworkAttachment = inputNetwork;
     378           1 :         _weightAttachment = weightsAttachment;
     379           1 :         _outputAttachment = output;
     380             : 
     381        1200 :         _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
     382        1200 :                 return Rc<vk::QueuePassHandle>::create(pass, q);
     383           1 :         };
     384             : 
     385           1 :         if (_front->getModel()->isTrainable()) {
     386           1 :                 initPropagation(queueBuilder, builder);
     387             :         }
     388             : 
     389           2 :         return QueuePass::init(builder);
     390             : }
     391             : 
     392           1 : void CrossEntropyLossLayer::initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &builder) {
     393           1 :         const core::QueuePassData *pass = _inputNetworkAttachment->passes.front()->pass;
     394           1 :         if (auto trainableLayer = dynamic_cast<TrainableLayer *>(pass->pass.get())) {
     395           1 :                 trainableLayer->initPropagation(queueBuilder, builder, _weightAttachment, ActivationIdx);
     396             :         }
     397           1 : }
     398             : 
     399        1200 : void CrossEntropyLossLayer::runAll(CommandBuffer &buf, SpanView<BufferView> buffers) {
     400       12000 :         for (auto &it : _pipelineOps) {
     401       10800 :                 auto pipeline = static_cast<ComputePipeline *>(it.second.pipeline->pipeline.get());
     402             : 
     403       10800 :                 it.second.command(_front, buf, pipeline, buffers);
     404             :         }
     405        1200 : }
     406             : 
     407             : }

Generated by: LCOV version 1.14