Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnVkStatPercentLayer.h"
24 : #include "XLSnnVkShaders.h"
25 :
26 : namespace stappler::xenolith::vk::shadernn {
27 :
28 0 : StatPercentLayer::~StatPercentLayer() { }
29 :
30 0 : bool StatPercentLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
31 : const AttachmentData *input, const AttachmentData *output) {
32 : using namespace core;
33 :
34 0 : auto classesBuffer = queueBuilder.addAttachemnt("StatPercentLayerClassesBuffer", [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
35 0 : return Rc<BufferAttachment>::create(builder, BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer));
36 : });
37 :
38 0 : auto passInput = builder.addAttachment(input);
39 0 : auto passOutput = builder.addAttachment(output);
40 0 : auto passClasses = builder.addAttachment(classesBuffer);
41 :
42 0 : auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
43 0 : layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
44 0 : setBuilder.addDescriptor(passOutput, DescriptorType::StorageBuffer);
45 0 : setBuilder.addDescriptor(passInput, DescriptorType::StorageBuffer);
46 0 : setBuilder.addDescriptor(passClasses, DescriptorType::StorageBuffer);
47 0 : });
48 0 : });
49 :
50 0 : builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
51 0 : subpassBuilder.addComputePipeline(StatPercentLayerClassesPipeline, layout,
52 0 : queueBuilder.addProgramByRef("StatPercentLayerClassesPProgram", getShader(LayerShader::StatClassMap, Precision::Unknown)));
53 0 : subpassBuilder.addComputePipeline(StatPercentLayerPercentPipeline, layout,
54 0 : queueBuilder.addProgramByRef("StatPercentLayerPercentProgram", getShader(LayerShader::StatClassPercent, Precision::Unknown)));
55 0 : });
56 :
57 0 : _inputAttachment = input;
58 0 : _outputAttachment = output;
59 0 : _classesAttachment = classesBuffer;
60 0 : _front = front;
61 :
62 0 : _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
63 0 : return Rc<LayerHandle>::create(pass, q);
64 0 : };
65 :
66 0 : return QueuePass::init(builder);
67 : }
68 :
69 0 : bool StatPercentLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
70 0 : auto pass = (StatPercentLayer *)_queuePass.get();
71 :
72 0 : if (auto imageAttachment = q.getAttachment(pass->getInputAttachment())) {
73 0 : _inputBuffer = (vk::BufferAttachmentHandle *)imageAttachment->handle.get();
74 : }
75 :
76 0 : if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
77 0 : _outputBuffer = (vk::BufferAttachmentHandle *)imageAttachment->handle.get();
78 : }
79 :
80 0 : if (auto bufferAttachment = q.getAttachment(pass->getClassesAttachment())) {
81 0 : _classesBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
82 : }
83 :
84 0 : _front = pass->getFront();
85 :
86 0 : auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
87 0 : auto &pool = handle->getMemPool(nullptr);
88 0 : auto extent = handle->getFrameConstraints().extent;
89 :
90 0 : _classesSizes = pool->spawnPersistent(AllocationUsage::DeviceLocal,
91 0 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
92 0 : size_t(_front->getClassCount() * sizeof(uint32_t))
93 0 : ));
94 0 : _classesIndexes = pool->spawn(AllocationUsage::DeviceLocal,
95 0 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
96 0 : size_t(_front->getClassCount() * extent.height * sizeof(uint32_t))
97 0 : ));
98 0 : _output = pool->spawnPersistent(AllocationUsage::DeviceLocal,
99 0 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
100 0 : size_t(_front->getClassCount() * (sizeof(float) * 4 + sizeof(uint32_t) * 4))
101 0 : ));
102 :
103 0 : _classesBuffer->addBufferView(_classesSizes);
104 0 : _classesBuffer->addBufferView(_classesIndexes);
105 0 : _outputBuffer->addBufferView(_output);
106 :
107 0 : return vk::QueuePassHandle::prepare(q, move(cb));
108 : }
109 :
110 0 : Vector<const vk::CommandBuffer *> StatPercentLayer::LayerHandle::doPrepareCommands(FrameHandle &handle) {
111 0 : auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
112 0 : auto pass = _data->impl.cast<vk::RenderPass>().get();
113 0 : pass->perform(*this, buf, [&] {
114 : struct ClassesInputInfo {
115 : int size;
116 : int fields;
117 : int fieldClass;
118 : int classMin;
119 : int classMax;
120 : int fieldSource;
121 : int fieldTarget;
122 : int classCount;
123 : };
124 :
125 0 : auto extent = handle.getFrameConstraints().extent;
126 :
127 : ClassesInputInfo pcb1;
128 0 : pcb1.size = extent.height;
129 0 : pcb1.fields = _inputBuffer->getBuffers().front().buffer->getSize() / (sizeof(uint64_t) * pcb1.size);
130 0 : pcb1.fieldClass = _front->getFieldClass();
131 0 : pcb1.classMin = _front->getClassMin();
132 0 : pcb1.classMax = _front->getClassMin() + _front->getClassCount() - 1;
133 0 : pcb1.fieldSource = _front->getFieldSource();
134 0 : pcb1.fieldTarget = _front->getFieldTarget();
135 0 : pcb1.classCount = _front->getClassCount();
136 :
137 0 : buf.cmdFillBuffer(_classesIndexes, 0);
138 0 : buf.cmdFillBuffer(_classesSizes, 0);
139 :
140 : BufferMemoryBarrier b[2] = {
141 : BufferMemoryBarrier(_classesSizes, VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_WRITE_BIT),
142 : BufferMemoryBarrier(_classesSizes, VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_WRITE_BIT)
143 0 : };
144 :
145 0 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(b, 2));
146 :
147 0 : buf.cmdBindDescriptorSets(pass, 0);
148 0 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&pcb1), sizeof(ClassesInputInfo)));
149 :
150 0 : vk::ComputePipeline *classesPipeline = nullptr;
151 0 : auto classesPipelineIt = _data->subpasses[0]->computePipelines.find(StatPercentLayerClassesPipeline);
152 0 : if (classesPipelineIt != _data->subpasses[0]->computePipelines.end()) {
153 0 : classesPipeline = static_cast<vk::ComputePipeline *>((*classesPipelineIt)->pipeline.get());
154 : }
155 :
156 0 : buf.cmdBindPipeline(classesPipeline);
157 0 : buf.cmdDispatch(1, (pcb1.size - 1) / classesPipeline->getLocalY() + 1, 1);
158 :
159 0 : vk::ComputePipeline *percentPipeline = nullptr;
160 0 : auto percentPipelineeIt = _data->subpasses[0]->computePipelines.find(StatPercentLayerPercentPipeline);
161 0 : if (percentPipelineeIt != _data->subpasses[0]->computePipelines.end()) {
162 0 : percentPipeline = static_cast<vk::ComputePipeline *>((*percentPipelineeIt)->pipeline.get());
163 : }
164 :
165 0 : buf.cmdBindPipeline(percentPipeline);
166 0 : buf.cmdDispatch((pcb1.classCount - 1) / percentPipeline->getLocalX() + 1, 1, 1);
167 0 : }, true);
168 0 : return true;
169 : });
170 0 : return Vector<const vk::CommandBuffer *>{buf};
171 : }
172 :
173 0 : void StatPercentLayer::LayerHandle::doSubmitted(FrameHandle &h, Function<void(bool)> &&cb, bool s, Rc<Fence> &&fence) {
174 0 : vk::QueuePassHandle::doSubmitted(h, move(cb), s, move(fence));
175 :
176 : /*h.getLoop()->captureBuffer([] (const BufferInfo &info, BytesView view) {
177 : std::cout << view.size() / (sizeof(float) * 4) << "\n";
178 :
179 : std::cout << "0";
180 :
181 : size_t row = 1;
182 : size_t i = 0;
183 : while (!view.empty()) {
184 : switch (i) {
185 : case 0:
186 : case 1:
187 : case 2:
188 : case 3:
189 : std::cout << ", " << view.readFloat32();
190 : break;
191 : case 4:
192 : case 5:
193 : case 6:
194 : std::cout << ", " << view.readUnsigned32();
195 : break;
196 : default:
197 : view.readUnsigned32();
198 : break;
199 : }
200 : ++ i;
201 : if (i > 7) {
202 : i = 0;
203 : std::cout << "\n" << row ++;
204 : }
205 : }
206 : std::cout << "\n";
207 : }, _output);*/
208 0 : }
209 :
210 0 : StatAnalysisLayer::~StatAnalysisLayer() { }
211 :
212 0 : bool StatAnalysisLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
213 : const AttachmentData *inputData, const AttachmentData *inputClasses, const AttachmentData *output) {
214 : using namespace core;
215 :
216 0 : auto passInputData = builder.addAttachment(inputData);
217 0 : auto passInputClasses = builder.addAttachment(inputClasses);
218 0 : auto passOutput = builder.addAttachment(output);
219 :
220 0 : auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
221 0 : layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
222 0 : setBuilder.addDescriptor(passOutput, DescriptorType::StorageBuffer);
223 0 : setBuilder.addDescriptor(passInputData, DescriptorType::StorageBuffer);
224 0 : setBuilder.addDescriptor(passInputClasses, DescriptorType::StorageBuffer);
225 0 : });
226 0 : });
227 :
228 0 : builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
229 0 : subpassBuilder.addComputePipeline("StatAnalysisLayerProgram", layout,
230 0 : queueBuilder.addProgramByRef("StatAnalysisLayerProgram", getShader(LayerShader::StatAnalysis, Precision::Unknown)));
231 0 : });
232 :
233 0 : _inputDataAttachment = inputData;
234 0 : _inputClassesAttachment = inputClasses;
235 0 : _outputAttachment = output;
236 0 : _front = front;
237 :
238 0 : _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
239 0 : return Rc<LayerHandle>::create(pass, q);
240 0 : };
241 :
242 0 : return QueuePass::init(builder);
243 : }
244 :
245 0 : bool StatAnalysisLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
246 0 : auto pass = (StatAnalysisLayer *)_queuePass.get();
247 :
248 0 : if (auto attachment = q.getAttachment(pass->getInputDataAttachment())) {
249 0 : _inputDataBuffer = (vk::BufferAttachmentHandle *)attachment->handle.get();
250 : }
251 :
252 0 : if (auto attachment = q.getAttachment(pass->getInputClassesAttachment())) {
253 0 : _inputClassesBuffer = (vk::BufferAttachmentHandle *)attachment->handle.get();
254 : }
255 :
256 0 : if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
257 0 : _outputBuffer = (vk::BufferAttachmentHandle *)imageAttachment->handle.get();
258 : }
259 :
260 0 : _front = pass->getFront();
261 :
262 0 : auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
263 0 : auto &pool = handle->getMemPool(nullptr);
264 0 : auto extent = handle->getFrameConstraints().extent;
265 :
266 0 : _output = pool->spawnPersistent(AllocationUsage::DeviceLocal,
267 0 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
268 0 : size_t(extent.height * (sizeof(float) * 4))
269 0 : ));
270 :
271 0 : _outputBuffer->addBufferView(_output);
272 :
273 0 : return vk::QueuePassHandle::prepare(q, move(cb));
274 : }
275 :
276 0 : Vector<const vk::CommandBuffer *> StatAnalysisLayer::LayerHandle::doPrepareCommands(FrameHandle &handle) {
277 0 : auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
278 0 : auto pass = _data->impl.cast<vk::RenderPass>().get();
279 0 : pass->perform(*this, buf, [&] {
280 : struct InputInfo {
281 : int size;
282 : int fields;
283 : int fieldClass;
284 : int classMin;
285 : int classMax;
286 : int fieldSource;
287 : int fieldTarget;
288 : int classCount;
289 : float threshold;
290 : };
291 :
292 0 : auto extent = handle.getFrameConstraints().extent;
293 :
294 : InputInfo pcb1;
295 0 : pcb1.size = extent.height;
296 0 : pcb1.fields = _inputDataBuffer->getBuffers().front().buffer->getSize() / (sizeof(uint64_t) * pcb1.size);
297 0 : pcb1.fieldClass = _front->getFieldClass();
298 0 : pcb1.classMin = _front->getClassMin();
299 0 : pcb1.classMax = _front->getClassMin() + _front->getClassCount() - 1;
300 0 : pcb1.fieldSource = _front->getFieldSource();
301 0 : pcb1.fieldTarget = _front->getFieldTarget();
302 0 : pcb1.classCount = _front->getClassCount();
303 0 : pcb1.threshold = _front->getThreshold();
304 :
305 0 : buf.cmdBindDescriptorSets(pass, 0);
306 0 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&pcb1), sizeof(InputInfo)));
307 :
308 0 : auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
309 :
310 0 : buf.cmdBindPipeline(pipeline);
311 0 : buf.cmdDispatch((pcb1.size - 1) / pipeline->getLocalX() + 1, 1, 1);
312 0 : }, true);
313 0 : return true;
314 : });
315 0 : return Vector<const vk::CommandBuffer *>{buf};
316 : }
317 :
318 0 : void StatAnalysisLayer::LayerHandle::doSubmitted(FrameHandle &h, Function<void(bool)> &&cb, bool s, Rc<Fence> &&fence) {
319 0 : vk::QueuePassHandle::doSubmitted(h, move(cb), s, move(fence));
320 :
321 : /*h.getLoop()->captureBuffer([] (const BufferInfo &info, BytesView view) {
322 : std::cout << view.size() / (sizeof(float) * 4) << "\n";
323 : }, _output);*/
324 0 : }
325 :
326 : }
|