LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkConvLayer.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 142 0.0 %
Date: 2024-05-06 04:51:23 Functions: 0 17 0.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnVkConvLayer.h"
      24             : #include "XLCoreAttachment.h"
      25             : #include "XLCoreFrameQueue.h"
      26             : #include "XLCoreFrameRequest.h"
      27             : #include "XLSnnVkShaders.h"
      28             : #include "XLVkPipeline.h"
      29             : 
      30             : namespace stappler::xenolith::vk::shadernn {
      31             : 
      32           0 : static core::ImageFormat getPrecisionKernelFormat(Precision p) {
      33           0 :         switch (p) {
      34           0 :         case Precision::Unknown:
      35           0 :                 return core::ImageFormat::Undefined;
      36             :                 break;
      37           0 :         case Precision::F8:
      38           0 :                 return core::ImageFormat::R8G8B8A8_UNORM;
      39             :                 break;
      40           0 :         case Precision::F16:
      41           0 :                 return core::ImageFormat::R16G16B16A16_SFLOAT;
      42             :                 break;
      43           0 :         case Precision::F32:
      44           0 :                 return core::ImageFormat::R32G32B32A32_SFLOAT;
      45             :                 break;
      46           0 :         case Precision::F64:
      47           0 :                 return core::ImageFormat::R64G64B64A64_SFLOAT;
      48             :                 break;
      49             :         }
      50           0 :         return core::ImageFormat::Undefined;
      51             : }
      52             : 
      53           0 : Conv2DLayer::~Conv2DLayer() { }
      54             : 
      55           0 : bool Conv2DLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
      56             :                 const AttachmentData *input, const AttachmentData *output) {
      57             :         using namespace core;
      58             : 
      59           0 :         auto precision = getAttachmentPrecision(output);
      60             : 
      61           0 :         _front = front;
      62             : 
      63           0 :         auto kernelImage = queueBuilder.addImageByRef(toString(front->getName(), "_kernelImage"),
      64           0 :                         ImageInfo(front->getKernelExtent(), ImageUsage::Storage, ImageTiling::Optimal,
      65           0 :                                         getPrecisionKernelFormat(precision), PassType::Compute, ImageHints::Static),
      66           0 :                         front->getKernelImageData(), AttachmentLayout::General);
      67             : 
      68           0 :         auto biasBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_biasBuffer"),
      69           0 :                         BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
      70           0 :                         front->getBiasBufferData());
      71             : 
      72           0 :         auto betaBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_betaBuffer"),
      73           0 :                         BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
      74           0 :                         front->getNormBetaBufferData());
      75             : 
      76           0 :         auto gammaBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_gammaBuffer"),
      77           0 :                         BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
      78           0 :                         front->getNormGammaBufferData());
      79             : 
      80           0 :         auto meanBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_meanBuffer"),
      81           0 :                         BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
      82           0 :                         front->getNormMeanBufferData());
      83             : 
      84           0 :         auto varianceBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_varianceBuffer"),
      85           0 :                         BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
      86           0 :                         front->getNormVarianceBufferData());
      87             : 
      88           0 :         auto kernelAttachment = queueBuilder.addAttachemnt(toString(front->getName(), "_kernel"),
      89           0 :                         [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
      90           0 :                 return Rc<vk::ImageAttachment>::create(builder,
      91           0 :                         kernelImage,
      92           0 :                         ImageAttachment::AttachmentInfo{
      93             :                                 .initialLayout = AttachmentLayout::Ignored,
      94             :                                 .finalLayout = AttachmentLayout::Ignored,
      95             :                                 .clearOnLoad = false
      96             :                         }
      97           0 :                 );
      98             :         });
      99             : 
     100           0 :         auto dataAttachment = queueBuilder.addAttachemnt(toString(front->getName(), "_data"),
     101           0 :                         [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
     102           0 :                 return Rc<vk::BufferAttachment>::create(builder,
     103           0 :                         Vector<const BufferData *>{
     104           0 :                                 biasBuffer,
     105           0 :                                 betaBuffer,
     106           0 :                                 gammaBuffer,
     107           0 :                                 meanBuffer,
     108           0 :                                 varianceBuffer
     109             :                         }
     110           0 :                 );
     111             :         });
     112             : 
     113           0 :         auto passInput = builder.addAttachment(input, [] (AttachmentPassBuilder &builder) {
     114           0 :                 builder.setDependency(AttachmentDependencyInfo{
     115             :                         PipelineStage::ComputeShader, AccessType::ShaderRead,
     116             :                         PipelineStage::ComputeShader, AccessType::ShaderRead,
     117             :                         FrameRenderPassState::Submitted,
     118             :                 });
     119           0 :                 builder.setInitialLayout(AttachmentLayout::General);
     120           0 :                 builder.setFinalLayout(AttachmentLayout::General);
     121           0 :         });
     122             : 
     123           0 :         auto passOutput = builder.addAttachment(output, [] (AttachmentPassBuilder &builder) {
     124           0 :                 builder.setDependency(AttachmentDependencyInfo{
     125             :                         PipelineStage::ComputeShader, AccessType::ShaderWrite,
     126             :                         PipelineStage::ComputeShader, AccessType::ShaderWrite,
     127             :                         FrameRenderPassState::Submitted,
     128             :                 });
     129           0 :                 builder.setInitialLayout(AttachmentLayout::General);
     130           0 :                 builder.setFinalLayout(AttachmentLayout::General);
     131           0 :         });
     132             : 
     133           0 :         auto passKernel = builder.addAttachment(kernelAttachment, [] (AttachmentPassBuilder &builder) {
     134           0 :                 builder.setDependency(AttachmentDependencyInfo{
     135             :                         PipelineStage::ComputeShader, AccessType::ShaderRead,
     136             :                         PipelineStage::ComputeShader, AccessType::ShaderRead,
     137             :                         FrameRenderPassState::Submitted,
     138             :                 });
     139           0 :                 builder.setInitialLayout(AttachmentLayout::General);
     140           0 :                 builder.setFinalLayout(AttachmentLayout::General);
     141           0 :         });
     142             : 
     143           0 :         auto passData = builder.addAttachment(dataAttachment);
     144             : 
     145           0 :         auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
     146           0 :                 layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
     147           0 :                         setBuilder.addDescriptor(passOutput, DescriptorType::StorageImage, AttachmentLayout::General);
     148           0 :                         setBuilder.addDescriptor(passInput, DescriptorType::StorageImage, AttachmentLayout::General);
     149           0 :                         setBuilder.addDescriptor(passKernel, DescriptorType::StorageImage, AttachmentLayout::General);
     150           0 :                         setBuilder.addDescriptor(passData, DescriptorType::StorageBuffer);
     151           0 :                 });
     152           0 :         });
     153             : 
     154           0 :         builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
     155           0 :                 auto paddings = front->getPaddingOffset();
     156           0 :                 auto kernel = front->getKernelSize();
     157           0 :                 auto stride = front->getStride();
     158           0 :                 auto mode = front->getPaddingMode();
     159           0 :             uint32_t dilate = 1;
     160             : 
     161           0 :             uint32_t paddingMode = 0;
     162           0 :             if (mode == "constant") {
     163           0 :                 paddingMode = 1;
     164           0 :             } else if (mode == "replicate") {
     165           0 :                 paddingMode = 2;
     166           0 :             } else if (mode == "reflect") {
     167           0 :                 paddingMode = 3;
     168             :             }
     169             : 
     170           0 :                 SpecializationInfo spec;
     171           0 :                 spec.data = queueBuilder.addProgramByRef(toString(front->getName(), "_shader"), getShader(LayerShader::Conv2d, precision));
     172           0 :                 spec.constants.emplace_back(SpecializationConstant(paddings.x)); // 0
     173           0 :                 spec.constants.emplace_back(SpecializationConstant(paddings.z)); // 1
     174           0 :                 spec.constants.emplace_back(SpecializationConstant(kernel)); // 2
     175           0 :                 spec.constants.emplace_back(SpecializationConstant(kernel)); // 3
     176           0 :                 spec.constants.emplace_back(SpecializationConstant(stride)); // 4
     177           0 :                 spec.constants.emplace_back(SpecializationConstant(stride)); // 5
     178           0 :                 spec.constants.emplace_back(SpecializationConstant(dilate)); // 6
     179           0 :                 spec.constants.emplace_back(SpecializationConstant(dilate)); // 7
     180           0 :                 spec.constants.emplace_back(SpecializationConstant(4)); // 8
     181           0 :                 spec.constants.emplace_back(SpecializationConstant(front->getActivation())); // 9
     182           0 :                 spec.constants.emplace_back(SpecializationConstant(paddingMode)); // 10
     183           0 :                 spec.constants.emplace_back(SpecializationConstant(uint32_t(front->useBatchNormalization()))); // 11
     184           0 :                 spec.constants.emplace_back(SpecializationConstant(uint32_t(front->useBias()))); // 12
     185           0 :                 spec.constants.emplace_back(SpecializationConstant(front->getLeakyReluAlpha())); // 13
     186             : 
     187           0 :                 subpassBuilder.addComputePipeline(toString(front->getName(), "_pipeline"), layout, move(spec));
     188           0 :         });
     189             : 
     190           0 :         _inputAttachment = input;
     191           0 :         _outputAttachment = output;
     192           0 :         _kernelAttachment = kernelAttachment;
     193           0 :         _dataAttachment = dataAttachment;
     194             : 
     195           0 :         _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
     196           0 :                 return Rc<LayerHandle>::create(pass, q);
     197           0 :         };
     198             : 
     199           0 :         return QueuePass::init(builder);
     200             : }
     201             : 
     202           0 : bool Conv2DLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
     203           0 :         auto pass = (Conv2DLayer *)_queuePass.get();
     204             : 
     205           0 :         if (auto imageAttachment = q.getAttachment(pass->getInputAttachment())) {
     206           0 :                 _inputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
     207             :         }
     208             : 
     209           0 :         if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
     210           0 :                 _outputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
     211             :         }
     212             : 
     213           0 :         if (auto kernelAttachment = q.getAttachment(pass->getKernelAttachment())) {
     214           0 :                 _kernelImage = (const vk::ImageAttachmentHandle *)kernelAttachment->handle.get();
     215             :         }
     216             : 
     217           0 :         if (auto bufferAttachment = q.getAttachment(pass->getDataAttachment())) {
     218           0 :                 _dataHandle = bufferAttachment->handle.get();
     219             :         }
     220             : 
     221           0 :         _front = pass->getFront();
     222             : 
     223           0 :         return vk::QueuePassHandle::prepare(q, move(cb));
     224             : }
     225             : 
     226           0 : Vector<const vk::CommandBuffer *> Conv2DLayer::LayerHandle::doPrepareCommands(FrameHandle &frame) {
     227           0 :         auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
     228           0 :                 auto pass = _data->impl.cast<vk::RenderPass>().get();
     229           0 :                 pass->perform(*this, buf, [&] {
     230           0 :                         buf.cmdBindDescriptorSets(pass, 0);
     231             : 
     232           0 :                         auto extent = _outputImage->getQueueData()->image->getInfo().extent;
     233             : 
     234           0 :                         auto oc_4 = UP_DIV(_front->getNumOutputPlanes(), uint32_t(4));
     235             : 
     236           0 :                         auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
     237             : 
     238           0 :                         buf.cmdBindPipeline(pipeline);
     239           0 :                         buf.cmdDispatch((extent.width - 1) / pipeline->getLocalX() + 1,
     240           0 :                                         (extent.height - 1) / pipeline->getLocalY() + 1,
     241           0 :                                         (oc_4 - 1) / pipeline->getLocalZ() + 1);
     242           0 :                 }, true);
     243           0 :                 return true;
     244             :         });
     245           0 :         return Vector<const vk::CommandBuffer *>{buf};
     246             : }
     247             : 
     248             : }

Generated by: LCOV version 1.14