Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLVkPipeline.h"
24 : #include "XLCoreAttachment.h"
25 : #include "XLCoreFrameQueue.h"
26 : #include "XLCoreFrameRequest.h"
27 : #include "XLSnnVkGenerationLayer.h"
28 : #include "XLSnnVkShaders.h"
29 :
30 : namespace stappler::xenolith::vk::shadernn {
31 :
32 0 : GenerationLayer::~GenerationLayer() { }
33 :
34 0 : bool GenerationLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, const AttachmentData *output) {
35 : using namespace core;
36 :
37 0 : auto dataBuffer = queueBuilder.addAttachemnt("GenerationLayerData", [] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
38 0 : builder.defineAsInput();
39 0 : auto a = Rc<GenericAttachment>::create(builder);
40 0 : a->setValidateInputCallback([] (const Attachment &, const Rc<AttachmentInputData> &data) {
41 0 : return dynamic_cast<GenerationDataInput *>(data.get()) != nullptr;
42 : });
43 0 : a->setFrameHandleCallback([] (Attachment &a, const FrameQueue &q) {
44 0 : auto h = Rc<core::AttachmentHandle>::create(a, q);
45 0 : h->setInputCallback([] (AttachmentHandle &handle, FrameQueue &queue, AttachmentInputData *input, Function<void(bool)> &&cb) {
46 0 : cb(true);
47 0 : });
48 0 : return h;
49 0 : });
50 0 : return a;
51 0 : });
52 :
53 0 : auto passOutput = builder.addAttachment(output, [] (AttachmentPassBuilder &builder) {
54 0 : builder.setDependency(AttachmentDependencyInfo{
55 : PipelineStage::ComputeShader, AccessType::ShaderWrite | AccessType::ShaderRead,
56 : PipelineStage::ComputeShader, AccessType::ShaderWrite | AccessType::ShaderRead,
57 : FrameRenderPassState::Submitted,
58 : });
59 0 : });
60 :
61 0 : builder.addAttachment(dataBuffer);
62 :
63 0 : auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
64 0 : layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
65 0 : setBuilder.addDescriptor(passOutput, DescriptorType::StorageImage, AttachmentLayout::General);
66 0 : });
67 0 : });
68 :
69 0 : auto precision = getAttachmentPrecision(output);
70 :
71 0 : builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
72 0 : subpassBuilder.addComputePipeline("GenerationLayerPipeline", layout,
73 0 : queueBuilder.addProgramByRef("GenerationLayerPipeline", getShader(LayerShader::Gen, precision)));
74 0 : });
75 :
76 0 : _outputAttachment = output;
77 0 : _dataAttachment = dataBuffer;
78 :
79 0 : _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
80 0 : return Rc<LayerHandle>::create(pass, q);
81 0 : };
82 :
83 0 : return QueuePass::init(builder);
84 : }
85 :
86 0 : GenerationLayer::LayerHandle::~LayerHandle() { }
87 :
88 0 : bool GenerationLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
89 0 : auto pass = (GenerationLayer *)_queuePass.get();
90 :
91 0 : if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
92 0 : _outputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
93 : }
94 :
95 0 : if (auto bufferAttachment = q.getAttachment(pass->getDataAttachment())) {
96 0 : _dataBuffer = (const vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
97 : }
98 :
99 0 : return vk::QueuePassHandle::prepare(q, move(cb));
100 : }
101 :
102 0 : Vector<const vk::CommandBuffer *> GenerationLayer::LayerHandle::doPrepareCommands(FrameHandle &handle) {
103 0 : auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
104 0 : auto pass = _data->impl.cast<vk::RenderPass>().get();
105 0 : pass->perform(*this, buf, [&] {
106 0 : auto extent = handle.getFrameConstraints().extent;
107 0 : auto input = static_cast<GenerationDataInput *>(_dataBuffer->getInput());
108 :
109 0 : buf.cmdBindDescriptorSets(pass, 0);
110 0 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&input->data), sizeof(GenerationData)));
111 :
112 0 : auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
113 :
114 0 : buf.cmdBindPipeline(pipeline);
115 0 : buf.cmdDispatch((extent.width - 1) / pipeline->getLocalX() + 1,
116 0 : (extent.height - 1) / pipeline->getLocalY() + 1,
117 0 : (extent.depth - 1) / pipeline->getLocalZ() + 1);
118 0 : }, true);
119 0 : return true;
120 : });
121 0 : return Vector<const vk::CommandBuffer *>{buf};
122 : }
123 :
124 : }
|