LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkStatPercentLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 160 0.0 %
Date: 2024-05-06 04:51:23 Functions: 0 25 0.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnVkStatPercentLayer.h"
      24             : #include "XLSnnVkShaders.h"
      25             : 
      26             : namespace stappler::xenolith::vk::shadernn {
      27             : 
      28           0 : StatPercentLayer::~StatPercentLayer() { }
      29             : 
      30           0 : bool StatPercentLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
      31             :                 const AttachmentData *input, const AttachmentData *output) {
      32             :         using namespace core;
      33             : 
      34           0 :         auto classesBuffer = queueBuilder.addAttachemnt("StatPercentLayerClassesBuffer", [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
      35           0 :                 return Rc<BufferAttachment>::create(builder, BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer));
      36             :         });
      37             : 
      38           0 :         auto passInput = builder.addAttachment(input);
      39           0 :         auto passOutput = builder.addAttachment(output);
      40           0 :         auto passClasses = builder.addAttachment(classesBuffer);
      41             : 
      42           0 :         auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
      43           0 :                 layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
      44           0 :                         setBuilder.addDescriptor(passOutput, DescriptorType::StorageBuffer);
      45           0 :                         setBuilder.addDescriptor(passInput, DescriptorType::StorageBuffer);
      46           0 :                         setBuilder.addDescriptor(passClasses, DescriptorType::StorageBuffer);
      47           0 :                 });
      48           0 :         });
      49             : 
      50           0 :         builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
      51           0 :                 subpassBuilder.addComputePipeline(StatPercentLayerClassesPipeline, layout,
      52           0 :                                 queueBuilder.addProgramByRef("StatPercentLayerClassesPProgram", getShader(LayerShader::StatClassMap, Precision::Unknown)));
      53           0 :                 subpassBuilder.addComputePipeline(StatPercentLayerPercentPipeline, layout,
      54           0 :                                 queueBuilder.addProgramByRef("StatPercentLayerPercentProgram", getShader(LayerShader::StatClassPercent, Precision::Unknown)));
      55           0 :         });
      56             : 
      57           0 :         _inputAttachment = input;
      58           0 :         _outputAttachment = output;
      59           0 :         _classesAttachment = classesBuffer;
      60           0 :         _front = front;
      61             : 
      62           0 :         _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
      63           0 :                 return Rc<LayerHandle>::create(pass, q);
      64           0 :         };
      65             : 
      66           0 :         return QueuePass::init(builder);
      67             : }
      68             : 
      69           0 : bool StatPercentLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
      70           0 :         auto pass = (StatPercentLayer *)_queuePass.get();
      71             : 
      72           0 :         if (auto imageAttachment = q.getAttachment(pass->getInputAttachment())) {
      73           0 :                 _inputBuffer = (vk::BufferAttachmentHandle *)imageAttachment->handle.get();
      74             :         }
      75             : 
      76           0 :         if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
      77           0 :                 _outputBuffer = (vk::BufferAttachmentHandle *)imageAttachment->handle.get();
      78             :         }
      79             : 
      80           0 :         if (auto bufferAttachment = q.getAttachment(pass->getClassesAttachment())) {
      81           0 :                 _classesBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
      82             :         }
      83             : 
      84           0 :         _front = pass->getFront();
      85             : 
      86           0 :         auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
      87           0 :         auto &pool = handle->getMemPool(nullptr);
      88           0 :         auto extent = handle->getFrameConstraints().extent;
      89             : 
      90           0 :         _classesSizes = pool->spawnPersistent(AllocationUsage::DeviceLocal,
      91           0 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
      92           0 :                                 size_t(_front->getClassCount() * sizeof(uint32_t))
      93           0 :         ));
      94           0 :         _classesIndexes = pool->spawn(AllocationUsage::DeviceLocal,
      95           0 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
      96           0 :                                 size_t(_front->getClassCount() * extent.height * sizeof(uint32_t))
      97           0 :         ));
      98           0 :         _output = pool->spawnPersistent(AllocationUsage::DeviceLocal,
      99           0 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
     100           0 :                                 size_t(_front->getClassCount() * (sizeof(float) * 4 + sizeof(uint32_t) * 4))
     101           0 :         ));
     102             : 
     103           0 :         _classesBuffer->addBufferView(_classesSizes);
     104           0 :         _classesBuffer->addBufferView(_classesIndexes);
     105           0 :         _outputBuffer->addBufferView(_output);
     106             : 
     107           0 :         return vk::QueuePassHandle::prepare(q, move(cb));
     108             : }
     109             : 
     110           0 : Vector<const vk::CommandBuffer *> StatPercentLayer::LayerHandle::doPrepareCommands(FrameHandle &handle) {
     111           0 :         auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
     112           0 :                 auto pass = _data->impl.cast<vk::RenderPass>().get();
     113           0 :                 pass->perform(*this, buf, [&] {
     114             :                         struct ClassesInputInfo {
     115             :                                 int size;
     116             :                                 int fields;
     117             :                                 int fieldClass;
     118             :                                 int classMin;
     119             :                                 int classMax;
     120             :                                 int fieldSource;
     121             :                                 int fieldTarget;
     122             :                                 int classCount;
     123             :                         };
     124             : 
     125           0 :                         auto extent = handle.getFrameConstraints().extent;
     126             : 
     127             :                         ClassesInputInfo pcb1;
     128           0 :                         pcb1.size = extent.height;
     129           0 :                         pcb1.fields = _inputBuffer->getBuffers().front().buffer->getSize() / (sizeof(uint64_t) * pcb1.size);
     130           0 :                         pcb1.fieldClass = _front->getFieldClass();
     131           0 :                         pcb1.classMin = _front->getClassMin();
     132           0 :                         pcb1.classMax = _front->getClassMin() + _front->getClassCount() - 1;
     133           0 :                         pcb1.fieldSource = _front->getFieldSource();
     134           0 :                         pcb1.fieldTarget = _front->getFieldTarget();
     135           0 :                         pcb1.classCount = _front->getClassCount();
     136             : 
     137           0 :                         buf.cmdFillBuffer(_classesIndexes, 0);
     138           0 :                         buf.cmdFillBuffer(_classesSizes, 0);
     139             : 
     140             :                         BufferMemoryBarrier b[2] = {
     141             :                                 BufferMemoryBarrier(_classesSizes, VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_WRITE_BIT),
     142             :                                 BufferMemoryBarrier(_classesSizes, VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_WRITE_BIT)
     143           0 :                         };
     144             : 
     145           0 :                         buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(b, 2));
     146             : 
     147           0 :                         buf.cmdBindDescriptorSets(pass, 0);
     148           0 :                         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&pcb1), sizeof(ClassesInputInfo)));
     149             : 
     150           0 :                         vk::ComputePipeline *classesPipeline = nullptr;
     151           0 :                         auto classesPipelineIt = _data->subpasses[0]->computePipelines.find(StatPercentLayerClassesPipeline);
     152           0 :                         if (classesPipelineIt != _data->subpasses[0]->computePipelines.end()) {
     153           0 :                                 classesPipeline = static_cast<vk::ComputePipeline *>((*classesPipelineIt)->pipeline.get());
     154             :                         }
     155             : 
     156           0 :                         buf.cmdBindPipeline(classesPipeline);
     157           0 :                         buf.cmdDispatch(1, (pcb1.size - 1) / classesPipeline->getLocalY() + 1, 1);
     158             : 
     159           0 :                         vk::ComputePipeline *percentPipeline = nullptr;
     160           0 :                         auto percentPipelineeIt = _data->subpasses[0]->computePipelines.find(StatPercentLayerPercentPipeline);
     161           0 :                         if (percentPipelineeIt != _data->subpasses[0]->computePipelines.end()) {
     162           0 :                                 percentPipeline = static_cast<vk::ComputePipeline *>((*percentPipelineeIt)->pipeline.get());
     163             :                         }
     164             : 
     165           0 :                         buf.cmdBindPipeline(percentPipeline);
     166           0 :                         buf.cmdDispatch((pcb1.classCount - 1) / percentPipeline->getLocalX() + 1, 1, 1);
     167           0 :                 }, true);
     168           0 :                 return true;
     169             :         });
     170           0 :         return Vector<const vk::CommandBuffer *>{buf};
     171             : }
     172             : 
     173           0 : void StatPercentLayer::LayerHandle::doSubmitted(FrameHandle &h, Function<void(bool)> &&cb, bool s, Rc<Fence> &&fence) {
     174           0 :         vk::QueuePassHandle::doSubmitted(h, move(cb), s, move(fence));
     175             : 
     176             :         /*h.getLoop()->captureBuffer([] (const BufferInfo &info, BytesView view) {
     177             :                 std::cout << view.size() / (sizeof(float) * 4) << "\n";
     178             : 
     179             :                 std::cout << "0";
     180             : 
     181             :                 size_t row = 1;
     182             :                 size_t i = 0;
     183             :                 while (!view.empty()) {
     184             :                         switch (i) {
     185             :                         case 0:
     186             :                         case 1:
     187             :                         case 2:
     188             :                         case 3:
     189             :                                 std::cout << ", " << view.readFloat32();
     190             :                                 break;
     191             :                         case 4:
     192             :                         case 5:
     193             :                         case 6:
     194             :                                 std::cout << ", " << view.readUnsigned32();
     195             :                                 break;
     196             :                         default:
     197             :                                 view.readUnsigned32();
     198             :                                 break;
     199             :                         }
     200             :                         ++ i;
     201             :                         if (i > 7) {
     202             :                                 i = 0;
     203             :                                 std::cout << "\n" << row ++;
     204             :                         }
     205             :                 }
     206             :                 std::cout << "\n";
     207             :         }, _output);*/
     208           0 : }
     209             : 
     210           0 : StatAnalysisLayer::~StatAnalysisLayer() { }
     211             : 
     212           0 : bool StatAnalysisLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
     213             :                 const AttachmentData *inputData, const AttachmentData *inputClasses, const AttachmentData *output) {
     214             :         using namespace core;
     215             : 
     216           0 :         auto passInputData = builder.addAttachment(inputData);
     217           0 :         auto passInputClasses = builder.addAttachment(inputClasses);
     218           0 :         auto passOutput = builder.addAttachment(output);
     219             : 
     220           0 :         auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
     221           0 :                 layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
     222           0 :                         setBuilder.addDescriptor(passOutput, DescriptorType::StorageBuffer);
     223           0 :                         setBuilder.addDescriptor(passInputData, DescriptorType::StorageBuffer);
     224           0 :                         setBuilder.addDescriptor(passInputClasses, DescriptorType::StorageBuffer);
     225           0 :                 });
     226           0 :         });
     227             : 
     228           0 :         builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
     229           0 :                 subpassBuilder.addComputePipeline("StatAnalysisLayerProgram", layout,
     230           0 :                                 queueBuilder.addProgramByRef("StatAnalysisLayerProgram", getShader(LayerShader::StatAnalysis, Precision::Unknown)));
     231           0 :         });
     232             : 
     233           0 :         _inputDataAttachment = inputData;
     234           0 :         _inputClassesAttachment = inputClasses;
     235           0 :         _outputAttachment = output;
     236           0 :         _front = front;
     237             : 
     238           0 :         _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
     239           0 :                 return Rc<LayerHandle>::create(pass, q);
     240           0 :         };
     241             : 
     242           0 :         return QueuePass::init(builder);
     243             : }
     244             : 
     245           0 : bool StatAnalysisLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
     246           0 :         auto pass = (StatAnalysisLayer *)_queuePass.get();
     247             : 
     248           0 :         if (auto attachment = q.getAttachment(pass->getInputDataAttachment())) {
     249           0 :                 _inputDataBuffer = (vk::BufferAttachmentHandle *)attachment->handle.get();
     250             :         }
     251             : 
     252           0 :         if (auto attachment = q.getAttachment(pass->getInputClassesAttachment())) {
     253           0 :                 _inputClassesBuffer = (vk::BufferAttachmentHandle *)attachment->handle.get();
     254             :         }
     255             : 
     256           0 :         if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
     257           0 :                 _outputBuffer = (vk::BufferAttachmentHandle *)imageAttachment->handle.get();
     258             :         }
     259             : 
     260           0 :         _front = pass->getFront();
     261             : 
     262           0 :         auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
     263           0 :         auto &pool = handle->getMemPool(nullptr);
     264           0 :         auto extent = handle->getFrameConstraints().extent;
     265             : 
     266           0 :         _output = pool->spawnPersistent(AllocationUsage::DeviceLocal,
     267           0 :                         BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
     268           0 :                                 size_t(extent.height * (sizeof(float) * 4))
     269           0 :         ));
     270             : 
     271           0 :         _outputBuffer->addBufferView(_output);
     272             : 
     273           0 :         return vk::QueuePassHandle::prepare(q, move(cb));
     274             : }
     275             : 
     276           0 : Vector<const vk::CommandBuffer *> StatAnalysisLayer::LayerHandle::doPrepareCommands(FrameHandle &handle) {
     277           0 :         auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
     278           0 :                 auto pass = _data->impl.cast<vk::RenderPass>().get();
     279           0 :                 pass->perform(*this, buf, [&] {
     280             :                         struct InputInfo {
     281             :                                 int size;
     282             :                                 int fields;
     283             :                                 int fieldClass;
     284             :                                 int classMin;
     285             :                                 int classMax;
     286             :                                 int fieldSource;
     287             :                                 int fieldTarget;
     288             :                                 int classCount;
     289             :                                 float threshold;
     290             :                         };
     291             : 
     292           0 :                         auto extent = handle.getFrameConstraints().extent;
     293             : 
     294             :                         InputInfo pcb1;
     295           0 :                         pcb1.size = extent.height;
     296           0 :                         pcb1.fields = _inputDataBuffer->getBuffers().front().buffer->getSize() / (sizeof(uint64_t) * pcb1.size);
     297           0 :                         pcb1.fieldClass = _front->getFieldClass();
     298           0 :                         pcb1.classMin = _front->getClassMin();
     299           0 :                         pcb1.classMax = _front->getClassMin() + _front->getClassCount() - 1;
     300           0 :                         pcb1.fieldSource = _front->getFieldSource();
     301           0 :                         pcb1.fieldTarget = _front->getFieldTarget();
     302           0 :                         pcb1.classCount = _front->getClassCount();
     303           0 :                         pcb1.threshold = _front->getThreshold();
     304             : 
     305           0 :                         buf.cmdBindDescriptorSets(pass, 0);
     306           0 :                         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&pcb1), sizeof(InputInfo)));
     307             : 
     308           0 :                         auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
     309             : 
     310           0 :                         buf.cmdBindPipeline(pipeline);
     311           0 :                         buf.cmdDispatch((pcb1.size - 1) / pipeline->getLocalX() + 1, 1, 1);
     312           0 :                 }, true);
     313           0 :                 return true;
     314             :         });
     315           0 :         return Vector<const vk::CommandBuffer *>{buf};
     316             : }
     317             : 
     318           0 : void StatAnalysisLayer::LayerHandle::doSubmitted(FrameHandle &h, Function<void(bool)> &&cb, bool s, Rc<Fence> &&fence) {
     319           0 :         vk::QueuePassHandle::doSubmitted(h, move(cb), s, move(fence));
     320             : 
     321             :         /*h.getLoop()->captureBuffer([] (const BufferInfo &info, BytesView view) {
     322             :                 std::cout << view.size() / (sizeof(float) * 4) << "\n";
     323             :         }, _output);*/
     324           0 : }
     325             : 
     326             : }

Generated by: LCOV version 1.14