Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLVkPipeline.h"
24 : #include "XLCoreAttachment.h"
25 : #include "XLCoreFrameQueue.h"
26 : #include "XLCoreFrameRequest.h"
27 : #include "XLSnnVkActivationLayer.h"
28 : #include "XLSnnVkShaders.h"
29 :
30 : namespace stappler::xenolith::vk::shadernn {
31 :
32 0 : ActivationLayer::~ActivationLayer() { }
33 :
34 0 : bool ActivationLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, const AttachmentData *input, const AttachmentData *output) {
35 : using namespace core;
36 :
37 0 : auto dataBuffer = queueBuilder.addAttachemnt("ActivationLayerBuffer", [] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
38 0 : builder.defineAsInput();
39 0 : auto a = Rc<GenericAttachment>::create(builder);
40 0 : a->setValidateInputCallback([] (const Attachment &, const Rc<AttachmentInputData> &data) {
41 0 : return dynamic_cast<ActivationDataInput *>(data.get()) != nullptr;
42 : });
43 0 : a->setFrameHandleCallback([] (Attachment &a, const FrameQueue &q) {
44 0 : auto h = Rc<AttachmentHandle>::create(a, q);
45 0 : h->setInputCallback([] (AttachmentHandle &handle, FrameQueue &queue, AttachmentInputData *input, Function<void(bool)> &&cb) {
46 0 : cb(true);
47 0 : });
48 0 : return h;
49 0 : });
50 0 : return a;
51 0 : });
52 :
53 0 : auto passInput = builder.addAttachment(input, [] (AttachmentPassBuilder &builder) {
54 0 : builder.setDependency(AttachmentDependencyInfo{
55 : PipelineStage::ComputeShader, AccessType::ShaderRead,
56 : PipelineStage::ComputeShader, AccessType::ShaderRead,
57 : FrameRenderPassState::Submitted,
58 : });
59 0 : });
60 :
61 0 : auto passOutput = builder.addAttachment(output, [] (AttachmentPassBuilder &builder) {
62 0 : builder.setDependency(AttachmentDependencyInfo{
63 : PipelineStage::ComputeShader, AccessType::ShaderWrite,
64 : PipelineStage::ComputeShader, AccessType::ShaderWrite,
65 : FrameRenderPassState::Submitted,
66 : });
67 0 : });
68 :
69 0 : builder.addAttachment(dataBuffer);
70 :
71 0 : auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
72 0 : layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
73 0 : setBuilder.addDescriptor(passOutput, DescriptorType::StorageImage, AttachmentLayout::General);
74 0 : setBuilder.addDescriptor(passInput, DescriptorType::StorageImage, AttachmentLayout::General);
75 0 : });
76 0 : });
77 :
78 0 : auto precision = getAttachmentPrecision(output);
79 :
80 0 : builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
81 0 : subpassBuilder.addComputePipeline("ActivationLayerPipeline", layout,
82 0 : queueBuilder.addProgramByRef("ActivationLayerProgram", getShader(LayerShader::Activation, precision)));
83 0 : });
84 :
85 0 : _inputAttachment = input;
86 0 : _outputAttachment = output;
87 0 : _dataAttachment = dataBuffer;
88 :
89 0 : _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
90 0 : return Rc<LayerHandle>::create(pass, q);
91 0 : };
92 :
93 0 : return QueuePass::init(builder);
94 : }
95 :
96 0 : ActivationLayer::LayerHandle::~LayerHandle() { }
97 :
98 0 : bool ActivationLayer::LayerHandle::prepare(FrameQueue &q, Function<void(bool)> &&cb) {
99 0 : auto pass = (ActivationLayer *)_queuePass.get();
100 :
101 0 : if (auto imageAttachment = q.getAttachment(pass->getInputAttachment())) {
102 0 : _inputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
103 : }
104 :
105 0 : if (auto imageAttachment = q.getAttachment(pass->getOutputAttachment())) {
106 0 : _outputImage = (const vk::ImageAttachmentHandle *)imageAttachment->handle.get();
107 : }
108 :
109 0 : if (auto bufferAttachment = q.getAttachment(pass->getDataAttachment())) {
110 0 : _dataBuffer = (const vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
111 : }
112 :
113 0 : return vk::QueuePassHandle::prepare(q, move(cb));
114 : }
115 :
116 0 : Vector<const vk::CommandBuffer *> ActivationLayer::LayerHandle::doPrepareCommands(FrameHandle &) {
117 0 : auto buf = _pool->recordBuffer(*_device, [&] (vk::CommandBuffer &buf) {
118 0 : auto pass = _data->impl.cast<vk::RenderPass>().get();
119 0 : pass->perform(*this, buf, [&] {
120 0 : auto data = static_cast<ActivationDataInput *>(_dataBuffer->getInput());
121 :
122 0 : buf.cmdBindDescriptorSets(pass, 0);
123 0 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&data->data), sizeof(ActivationData)));
124 :
125 0 : auto pipeline = static_cast<vk::ComputePipeline *>((*_data->subpasses[0]->computePipelines.begin())->pipeline.get());
126 :
127 0 : buf.cmdBindPipeline(pipeline);
128 0 : buf.cmdDispatch((data->data.inputSize.x - 1) / pipeline->getLocalX() + 1,
129 0 : (data->data.inputSize.y - 1) / pipeline->getLocalY() + 1,
130 0 : (data->data.inputSize.z - 1) / pipeline->getLocalZ() + 1);
131 0 : }, true);
132 0 : return true;
133 : });
134 0 : return Vector<const vk::CommandBuffer *>{buf};
135 : }
136 :
137 : }
|