LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkGenerationLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 60 0.0 %
Date: 2024-05-06 04:51:23 Functions: 0 18 0.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLVkPipeline.h"
      24             : #include "XLCoreAttachment.h"
      25             : #include "XLCoreFrameQueue.h"
      26             : #include "XLCoreFrameRequest.h"
      27             : #include "XLSnnVkGenerationLayer.h"
      28             : #include "XLSnnVkShaders.h"
      29             : 
      30             : namespace stappler::xenolith::vk::shadernn {
      31             : 
      32           0 : GenerationLayer::~GenerationLayer() { }
      33             : 
      34           0 : bool GenerationLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, const AttachmentData *output) {
      35             :         using namespace core;
      36             : 
      37           0 :         auto dataBuffer = queueBuilder.addAttachemnt("GenerationLayerData", [] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
      38           0 :                 builder.defineAsInput();
      39           0 :                 auto a = Rc<GenericAttachment>::create(builder);
      40           0 :                 a->setValidateInputCallback([] (const Attachment &, const Rc<AttachmentInputData> &data) {
      41           0 :                         return dynamic_cast<GenerationDataInput *>(data.get()) != nullptr;
      42             :                 });
      43           0 :                 a->setFrameHandleCallback([] (Attachment &a, const FrameQueue &q) {
      44           0 :                         auto h = Rc<core::AttachmentHandle>::create(a, q);
      45           0 :                         h->setInputCallback([] (AttachmentHandle &handle, FrameQueue &queue, AttachmentInputData *input, Function<void(bool)> &&cb) {
      46           0 :                                 cb(true);
      47           0 :                         });
      48           0 :                         return h;
      49           0 :                 });
      50           0 :                 return a;
      51           0 :         });
      52             : 
      53           0 :         auto passOutput = builder.addAttachment(output, [] (AttachmentPassBuilder &builder) {
      54           0 :                 builder.setDependency(AttachmentDependencyInfo{
      55             :                         PipelineStage::ComputeShader, AccessType::ShaderWrite | AccessType::ShaderRead,
      56             :                         PipelineStage::ComputeShader, AccessType::ShaderWrite | AccessType::ShaderRead,
      57             :                         FrameRenderPassState::Submitted,
      58             :                 });
      59           0 :         });
      60             : 
      61           0 :         builder.addAttachment(dataBuffer);
      62             : 
      63           0 :         auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
      64           0 :                 layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
      65           0 :                         setBuilder.addDescriptor(passOutput, DescriptorType::StorageImage, AttachmentLayout::General);
      66           0 :                 });
      67           0 :         });
      68             : 
      69           0 :         auto precision = getAttachmentPrecision(output);
      70             : 
      71           0 :         builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
      72           0 :                 subpassBuilder.addComputePipeline("GenerationLayerPipeline", layout,
      73           0 :                                 queueBuilder.addProgramByRef("GenerationLayerPipeline", getShader(LayerShader::Gen, precision)));
      74           0 :         });
      75             : 
      76           0 :         _outputAttachment = output;
      77           0 :         _dataAttachment = dataBuffer;
      78             : 
      79           0 :         _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
      80           0 :                 return Rc<LayerHandle>::create(pass, q);
      81           0 :         };
      82             : 
      83           0 :         return QueuePass::init(builder);
      84             : }
      85             : 
      86           0 : GenerationLayer::LayerHandle::~LayerHandle() { }
      87             : 
      88           0 : bool GenerationLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
      89           0 :         auto pass = (GenerationLayer *)_queuePass.get();
      90             : 
      91           0 :         if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
      92           0 :                 _outputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
      93             :         }
      94             : 
      95           0 :         if (auto bufferAttachment = q.getAttachment(pass->getDataAttachment())) {
      96           0 :                 _dataBuffer = (const vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
      97             :         }
      98             : 
      99           0 :         return vk::QueuePassHandle::prepare(q, move(cb));
     100             : }
     101             : 
     102           0 : Vector<const vk::CommandBuffer *> GenerationLayer::LayerHandle::doPrepareCommands(FrameHandle &handle) {
     103           0 :         auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
     104           0 :                 auto pass = _data->impl.cast<vk::RenderPass>().get();
     105           0 :                 pass->perform(*this, buf, [&] {
     106           0 :                         auto extent = handle.getFrameConstraints().extent;
     107           0 :                         auto input = static_cast<GenerationDataInput *>(_dataBuffer->getInput());
     108             : 
     109           0 :                         buf.cmdBindDescriptorSets(pass, 0);
     110           0 :                         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&input->data), sizeof(GenerationData)));
     111             : 
     112           0 :                         auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
     113             : 
     114           0 :                         buf.cmdBindPipeline(pipeline);
     115           0 :                         buf.cmdDispatch((extent.width - 1) / pipeline->getLocalX() + 1,
     116           0 :                                         (extent.height - 1) / pipeline->getLocalY() + 1,
     117           0 :                                         (extent.depth - 1) / pipeline->getLocalZ() + 1);
     118           0 :                 }, true);
     119           0 :                 return true;
     120             :         });
     121           0 :         return Vector<const vk::CommandBuffer *>{buf};
     122             : }
     123             : 
     124             : }

Generated by: LCOV version 1.14