Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnVkConvLayer.h"
24 : #include "XLCoreAttachment.h"
25 : #include "XLCoreFrameQueue.h"
26 : #include "XLCoreFrameRequest.h"
27 : #include "XLSnnVkShaders.h"
28 : #include "XLVkPipeline.h"
29 :
30 : namespace stappler::xenolith::vk::shadernn {
31 :
32 0 : static core::ImageFormat getPrecisionKernelFormat(Precision p) {
33 0 : switch (p) {
34 0 : case Precision::Unknown:
35 0 : return core::ImageFormat::Undefined;
36 : break;
37 0 : case Precision::F8:
38 0 : return core::ImageFormat::R8G8B8A8_UNORM;
39 : break;
40 0 : case Precision::F16:
41 0 : return core::ImageFormat::R16G16B16A16_SFLOAT;
42 : break;
43 0 : case Precision::F32:
44 0 : return core::ImageFormat::R32G32B32A32_SFLOAT;
45 : break;
46 0 : case Precision::F64:
47 0 : return core::ImageFormat::R64G64B64A64_SFLOAT;
48 : break;
49 : }
50 0 : return core::ImageFormat::Undefined;
51 : }
52 :
53 0 : Conv2DLayer::~Conv2DLayer() { }
54 :
/**
 * Builds the compute pass for a single 2D convolution layer.
 *
 * Registers the layer's static resources (a kernel image and five parameter buffers) with
 * the queue, declares this pass's attachments and descriptor layout, and adds one subpass
 * containing the Conv2d compute pipeline specialized from the layer's parameters.
 *
 * @param queueBuilder  queue builder receiving the layer's images, buffers, attachments and shader program
 * @param builder       builder for this pass (attachments, descriptor layout, subpass)
 * @param front         layer description; stored in `_front` and queried for weights and settings
 * @param input         attachment the shader reads
 * @param output        attachment the shader writes; its precision selects the kernel format and shader variant
 * @return the result of QueuePass::init(builder)
 */
bool Conv2DLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
		const AttachmentData *input, const AttachmentData *output) {
	using namespace core;

	// The output attachment's precision drives both the kernel image format and the shader variant.
	auto precision = getAttachmentPrecision(output);

	_front = front;

	// Kernel weights as a storage image, filled from front->getKernelImageData() and marked
	// ImageHints::Static.
	auto kernelImage = queueBuilder.addImageByRef(toString(front->getName(), "_kernelImage"),
			ImageInfo(front->getKernelExtent(), ImageUsage::Storage, ImageTiling::Optimal,
					getPrecisionKernelFormat(precision), PassType::Compute, ImageHints::Static),
			front->getKernelImageData(), AttachmentLayout::General);

	// Persistent storage buffers holding bias and batch-normalization parameters
	// (beta, gamma, mean, variance), each filled from the corresponding front getter.
	auto biasBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_biasBuffer"),
			BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
			front->getBiasBufferData());

	auto betaBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_betaBuffer"),
			BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
			front->getNormBetaBufferData());

	auto gammaBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_gammaBuffer"),
			BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
			front->getNormGammaBufferData());

	auto meanBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_meanBuffer"),
			BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
			front->getNormMeanBufferData());

	auto varianceBuffer = queueBuilder.addBufferByRef(toString(front->getName(), "_varianceBuffer"),
			BufferInfo(BufferUsage::StorageBuffer, BufferPersistent(true), PassType::Compute),
			front->getNormVarianceBufferData());

	// Queue-level attachment wrapping the kernel image; layouts are Ignored and it is not cleared on load.
	// NOTE(review): the [&] capture assumes addAttachemnt invokes the factory before init returns — confirm.
	// (`addAttachemnt` is the builder API's spelling.)
	auto kernelAttachment = queueBuilder.addAttachemnt(toString(front->getName(), "_kernel"),
			[&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
		return Rc<vk::ImageAttachment>::create(builder,
			kernelImage,
			ImageAttachment::AttachmentInfo{
				.initialLayout = AttachmentLayout::Ignored,
				.finalLayout = AttachmentLayout::Ignored,
				.clearOnLoad = false
			}
		);
	});

	// Single buffer attachment bundling all five parameter buffers, in this order:
	// bias, beta, gamma, mean, variance.
	auto dataAttachment = queueBuilder.addAttachemnt(toString(front->getName(), "_data"),
			[&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
		return Rc<vk::BufferAttachment>::create(builder,
			Vector<const BufferData *>{
				biasBuffer,
				betaBuffer,
				gammaBuffer,
				meanBuffer,
				varianceBuffer
			}
		);
	});

	// Input: compute-shader read, General layout on entry and exit.
	// (The lambda parameter named `builder` shadows the outer pass builder.)
	auto passInput = builder.addAttachment(input, [] (AttachmentPassBuilder &builder) {
		builder.setDependency(AttachmentDependencyInfo{
			PipelineStage::ComputeShader, AccessType::ShaderRead,
			PipelineStage::ComputeShader, AccessType::ShaderRead,
			FrameRenderPassState::Submitted,
		});
		builder.setInitialLayout(AttachmentLayout::General);
		builder.setFinalLayout(AttachmentLayout::General);
	});

	// Output: compute-shader write, General layout on entry and exit.
	auto passOutput = builder.addAttachment(output, [] (AttachmentPassBuilder &builder) {
		builder.setDependency(AttachmentDependencyInfo{
			PipelineStage::ComputeShader, AccessType::ShaderWrite,
			PipelineStage::ComputeShader, AccessType::ShaderWrite,
			FrameRenderPassState::Submitted,
		});
		builder.setInitialLayout(AttachmentLayout::General);
		builder.setFinalLayout(AttachmentLayout::General);
	});

	// Kernel image: compute-shader read, General layout on entry and exit.
	auto passKernel = builder.addAttachment(kernelAttachment, [] (AttachmentPassBuilder &builder) {
		builder.setDependency(AttachmentDependencyInfo{
			PipelineStage::ComputeShader, AccessType::ShaderRead,
			PipelineStage::ComputeShader, AccessType::ShaderRead,
			FrameRenderPassState::Submitted,
		});
		builder.setInitialLayout(AttachmentLayout::General);
		builder.setFinalLayout(AttachmentLayout::General);
	});

	// Parameter buffers are attached without an explicit dependency configuration.
	auto passData = builder.addAttachment(dataAttachment);

	// One descriptor set, with descriptors added in this order: output image, input image,
	// kernel image, parameter buffers.
	// NOTE(review): presumably this order must match the bindings in the Conv2d shader — confirm.
	auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
		layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
			setBuilder.addDescriptor(passOutput, DescriptorType::StorageImage, AttachmentLayout::General);
			setBuilder.addDescriptor(passInput, DescriptorType::StorageImage, AttachmentLayout::General);
			setBuilder.addDescriptor(passKernel, DescriptorType::StorageImage, AttachmentLayout::General);
			setBuilder.addDescriptor(passData, DescriptorType::StorageBuffer);
		});
	});

	// Single subpass with one compute pipeline; layer parameters are baked in as
	// specialization constants (indices noted on each line).
	builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
		auto paddings = front->getPaddingOffset();
		auto kernel = front->getKernelSize();
		auto stride = front->getStride();
		auto mode = front->getPaddingMode();
		uint32_t dilate = 1; // dilation is fixed at 1 for both axes

		// Padding mode string -> shader code; any unrecognized mode maps to 0.
		uint32_t paddingMode = 0;
		if (mode == "constant") {
			paddingMode = 1;
		} else if (mode == "replicate") {
			paddingMode = 2;
		} else if (mode == "reflect") {
			paddingMode = 3;
		}

		SpecializationInfo spec;
		spec.data = queueBuilder.addProgramByRef(toString(front->getName(), "_shader"), getShader(LayerShader::Conv2d, precision));
		// NOTE(review): x and z components are used as the two padding values — presumably
		// the horizontal/vertical offsets; confirm against getPaddingOffset().
		spec.constants.emplace_back(SpecializationConstant(paddings.x)); // 0
		spec.constants.emplace_back(SpecializationConstant(paddings.z)); // 1
		// The same kernel size, stride and dilation are passed for both axes.
		spec.constants.emplace_back(SpecializationConstant(kernel)); // 2
		spec.constants.emplace_back(SpecializationConstant(kernel)); // 3
		spec.constants.emplace_back(SpecializationConstant(stride)); // 4
		spec.constants.emplace_back(SpecializationConstant(stride)); // 5
		spec.constants.emplace_back(SpecializationConstant(dilate)); // 6
		spec.constants.emplace_back(SpecializationConstant(dilate)); // 7
		// NOTE(review): constant 4 presumably reflects 4 channels per RGBA texel — confirm with the shader.
		spec.constants.emplace_back(SpecializationConstant(4)); // 8
		spec.constants.emplace_back(SpecializationConstant(front->getActivation())); // 9
		spec.constants.emplace_back(SpecializationConstant(paddingMode)); // 10
		spec.constants.emplace_back(SpecializationConstant(uint32_t(front->useBatchNormalization()))); // 11
		spec.constants.emplace_back(SpecializationConstant(uint32_t(front->useBias()))); // 12
		spec.constants.emplace_back(SpecializationConstant(front->getLeakyReluAlpha())); // 13

		subpassBuilder.addComputePipeline(toString(front->getName(), "_pipeline"), layout, move(spec));
	});

	// Remembered so LayerHandle::prepare can resolve the per-frame handles.
	_inputAttachment = input;
	_outputAttachment = output;
	_kernelAttachment = kernelAttachment;
	_dataAttachment = dataAttachment;

	// The pass handle created for each frame queue is a Conv2DLayer::LayerHandle.
	_frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
		return Rc<LayerHandle>::create(pass, q);
	};

	return QueuePass::init(builder);
}
201 :
202 0 : bool Conv2DLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
203 0 : auto pass = (Conv2DLayer *)_queuePass.get();
204 :
205 0 : if (auto imageAttachment = q.getAttachment(pass->getInputAttachment())) {
206 0 : _inputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
207 : }
208 :
209 0 : if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
210 0 : _outputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
211 : }
212 :
213 0 : if (auto kernelAttachment = q.getAttachment(pass->getKernelAttachment())) {
214 0 : _kernelImage = (const vk::ImageAttachmentHandle *)kernelAttachment->handle.get();
215 : }
216 :
217 0 : if (auto bufferAttachment = q.getAttachment(pass->getDataAttachment())) {
218 0 : _dataHandle = bufferAttachment->handle.get();
219 : }
220 :
221 0 : _front = pass->getFront();
222 :
223 0 : return vk::QueuePassHandle::prepare(q, move(cb));
224 : }
225 :
226 0 : Vector<const vk::CommandBuffer *> Conv2DLayer::LayerHandle::doPrepareCommands(FrameHandle &frame) {
227 0 : auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
228 0 : auto pass = _data->impl.cast<vk::RenderPass>().get();
229 0 : pass->perform(*this, buf, [&] {
230 0 : buf.cmdBindDescriptorSets(pass, 0);
231 :
232 0 : auto extent = _outputImage->getQueueData()->image->getInfo().extent;
233 :
234 0 : auto oc_4 = UP_DIV(_front->getNumOutputPlanes(), uint32_t(4));
235 :
236 0 : auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
237 :
238 0 : buf.cmdBindPipeline(pipeline);
239 0 : buf.cmdDispatch((extent.width - 1) / pipeline->getLocalX() + 1,
240 0 : (extent.height - 1) / pipeline->getLocalY() + 1,
241 0 : (oc_4 - 1) / pipeline->getLocalZ() + 1);
242 0 : }, true);
243 0 : return true;
244 : });
245 0 : return Vector<const vk::CommandBuffer *>{buf};
246 : }
247 :
248 : }
|