Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnVkMatrixMulLayer.h"
24 : #include "XLSnnVkNmeMath.h"
25 : #include "XLSnnMatrixMulLayer.h"
26 : #include "XLSnnModel.h"
27 :
28 : namespace stappler::xenolith::vk::shadernn {
29 :
30 6 : MatrixMulLayer::~MatrixMulLayer() { }
31 :
32 3 : bool MatrixMulLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front *front,
33 : const AttachmentData *input, const AttachmentData *output) {
34 : using namespace core;
35 :
36 3 : _front = front;
37 :
38 12 : auto weightsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_weights_buffer"),
39 6 : BufferInfo(front->getWeightBufferSize(), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
40 3 : [front = _front] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
41 : /*auto name = toString(front->getName(), ".", front->getInputIndex(), ".weights.bin");
42 : xenolith::shadernn::Model::loadBlob(name.data(), [&] (const uint8_t *blob, size_t s) {
43 : memcpy(buf, blob, size);
44 : });*/
45 :
46 3 : front->generateWeights(buf, size, cb);
47 3 : });
48 :
49 12 : auto freeTermsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_freeTerms_buffer"),
50 6 : BufferInfo(front->getKernelSize() * sizeof(float), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
51 3 : [front = _front] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
52 : /*auto name = toString(front->getName(), ".", front->getInputIndex(), ".terms.bin");
53 : xenolith::shadernn::Model::loadBlob(name.data(), [&] (const uint8_t *blob, size_t s) {
54 : memcpy(buf, blob, size);
55 : });*/
56 :
57 3 : front->generateFreeTerms(buf, size, cb);
58 3 : });
59 :
60 3 : _nbuffers = 4;
61 3 : _weightsBufferIndex = 0;
62 3 : _freeTermBufferIndex = 1;
63 3 : _inputBufferIndex = 2;
64 3 : _outputBufferIndex = 3;
65 :
66 3 : const AttachmentData *weightsAttachment = nullptr;
67 :
68 3 : weightsAttachment = queueBuilder.addAttachemnt(toString(builder.getName(), "_weights"),
69 3 : [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
70 6 : return Rc<vk::BufferAttachment>::create(builder, Vector<const BufferData *>{
71 3 : weightsBuffer,
72 3 : freeTermsBuffer
73 6 : });
74 : });
75 :
76 3 : builder.addAttachment(input, AttachmentDependencyInfo::make(PipelineStage::ComputeShader, AccessType::ShaderRead));
77 3 : builder.addAttachment(output);
78 3 : auto passWeights = builder.addAttachment(weightsAttachment);
79 :
80 3 : auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
81 3 : layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
82 6 : setBuilder.addDescriptorArray(passWeights, _nbuffers, DescriptorType::StorageBuffer);
83 3 : });
84 6 : });
85 :
86 3 : builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
87 :
88 3 : auto matMul = subpassBuilder.addComputePipeline(toString(builder.getName(), "_matMul_pipeline"), layout,
89 9 : SpecializationInfo(
90 6 : queueBuilder.addProgramByRef(toString(builder.getName(), "_matMul_shader"),
91 3 : getShader(LayerShader::MultiplyMatrixByTransposedMatrix, Precision::Unknown)),
92 6 : Vector<SpecializationConstant>{
93 3 : SpecializationConstant(_nbuffers), // nbuffers
94 : SpecializationConstant(_outputBufferIndex), // output
95 : SpecializationConstant(_inputBufferIndex), // input
96 : SpecializationConstant(_weightsBufferIndex), // weight
97 3 : SpecializationConstant(_front->getInputIndex())
98 : }));
99 :
100 3 : auto matMulBorders = subpassBuilder.addComputePipeline(toString(builder.getName(), "_matMulBorders_pipeline"), layout,
101 9 : SpecializationInfo(
102 6 : queueBuilder.addProgramByRef(toString(builder.getName(), "_matMulBorders_shader"),
103 3 : getShader(LayerShader::MultiplyMatrixByTransposedMatrixBorder, Precision::Unknown)),
104 9 : Vector<SpecializationConstant>{
105 : SpecializationConstant(_nbuffers), // nbuffers
106 : SpecializationConstant(_outputBufferIndex), // output
107 : SpecializationConstant(_inputBufferIndex), // input
108 : SpecializationConstant(_weightsBufferIndex), // weight
109 3 : SpecializationConstant(_front->getInputIndex())
110 : }));
111 :
112 3 : auto addVec = subpassBuilder.addComputePipeline(toString(builder.getName(), "_addVec_pipeline"), layout,
113 9 : SpecializationInfo(
114 6 : queueBuilder.addProgramByRef(toString(builder.getName(), "_addVec_shader"),
115 3 : getShader(LayerShader::AddVectorToMatrixRows, Precision::Unknown)),
116 6 : Vector<SpecializationConstant>{
117 : SpecializationConstant(_nbuffers), // nbuffers
118 : SpecializationConstant(_outputBufferIndex), // output
119 : SpecializationConstant(_outputBufferIndex), // output
120 : SpecializationConstant(_freeTermBufferIndex) // terms
121 : }));
122 :
123 3 : auto relu = subpassBuilder.addComputePipeline(toString(builder.getName(), "_relu_pipeline"), layout,
124 9 : SpecializationInfo(
125 6 : queueBuilder.addProgramByRef(toString(builder.getName(), "_relu_shader"),
126 3 : getShader(LayerShader::VectorReLU, Precision::Unknown)),
127 6 : Vector<SpecializationConstant>{
128 : SpecializationConstant(_nbuffers), // nbuffers
129 : SpecializationConstant(_outputBufferIndex), // output
130 : SpecializationConstant(_outputBufferIndex), // output
131 : }));
132 :
133 3 : auto relu4 = subpassBuilder.addComputePipeline(toString(builder.getName(), "_relu4_pipeline"), layout,
134 9 : SpecializationInfo(
135 6 : queueBuilder.addProgramByRef(toString(builder.getName(), "_relu4_shader"),
136 3 : getShader(LayerShader::VectorReLU4, Precision::Unknown)),
137 6 : Vector<SpecializationConstant>{
138 : SpecializationConstant(_nbuffers), // nbuffers
139 : SpecializationConstant(_outputBufferIndex), // output
140 : SpecializationConstant(_outputBufferIndex), // output
141 : }));
142 :
143 3 : subpassBuilder.setPrepareCallback([this] (const SubpassData &subpass, FrameQueue &q) {
144 : // log::debug("MatrixMulLayer", getName(), ": prepare");
145 3600 : auto layer = (MatrixMulLayer *)subpass.pass->pass.get();
146 :
147 3600 : vk::BufferAttachmentHandle *inputBuffer = nullptr;
148 3600 : vk::BufferAttachmentHandle *outputBuffer = nullptr;
149 3600 : vk::BufferAttachmentHandle *weightsBuffer = nullptr;
150 :
151 3600 : if (auto bufferAttachment = q.getAttachment(layer->getInputAttachment())) {
152 3600 : inputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
153 : }
154 :
155 3600 : if (auto bufferAttachment = q.getAttachment(layer->getOutputAttachment())) {
156 3600 : outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
157 : }
158 :
159 3600 : if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
160 3600 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
161 : }
162 :
163 3600 : auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
164 3600 : auto &pool = handle->getMemPool(nullptr);
165 :
166 3600 : auto extent = layer->getFront()->getOutputExtent();
167 :
168 3600 : auto input = inputBuffer->getBuffers().front().buffer;
169 : auto output = pool->spawn(AllocationUsage::DeviceLocal,
170 3600 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
171 3600 : size_t(extent.depth * layer->getFront()->getKernelSize() * sizeof(float))
172 7200 : ));
173 :
174 : auto feedback = pool->spawn(AllocationUsage::DeviceLocal,
175 3600 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
176 3600 : size_t(output->getSize())
177 7200 : ));
178 :
179 3600 : weightsBuffer->addBufferView(input);
180 3600 : weightsBuffer->addBufferView(output);
181 :
182 3600 : outputBuffer->addBufferView(output);
183 3600 : outputBuffer->addBufferView(feedback);
184 3600 : });
185 :
186 6 : subpassBuilder.setCommandsCallback(
187 3 : [this, outputBufferIndex = _outputBufferIndex, matMul, matMulBorders, addVec, relu4, relu]
188 15600 : (const SubpassData &subpass, FrameQueue &q, core::CommandBuffer &b) {
189 : // log::debug("MatrixMulLayer", getName(), ": commands");
190 3600 : auto &buf = static_cast<CommandBuffer &>(b);
191 3600 : auto pass = static_cast<RenderPass *>(subpass.pass->impl.get());
192 3600 : auto layer = static_cast<MatrixMulLayer *>(subpass.pass->pass.get());
193 :
194 3600 : vk::BufferAttachmentHandle *weightsBuffer = nullptr;
195 3600 : if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
196 3600 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
197 : }
198 :
199 3600 : vk::BufferAttachmentHandle *outputBuffer = nullptr;
200 3600 : if (auto bufferAttachment = q.getAttachment(layer->getOutputAttachment())) {
201 3600 : outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
202 : }
203 :
204 3600 : auto output = outputBuffer->getBuffers()[0].buffer;
205 3600 : auto feedback = outputBuffer->getBuffers()[1].buffer;
206 :
207 3600 : auto kernelSize = layer->getFront()->getWeightSize();
208 :
209 3600 : const int secondHeight = kernelSize.height;
210 3600 : const int secondWidth = kernelSize.width;
211 :
212 3600 : auto input = layer->getFront()->getInput();
213 :
214 3600 : const int firstHeight = input->getOutputExtent().depth;
215 3600 : const int firstWidth = input->getOutputExtent().width;
216 3600 : const int resultWidth = layer->getFront()->getOutputExtent().width;
217 :
218 : /*std::cout << "secondHeight: " << secondHeight << "\n";
219 : std::cout << "secondWidth: " << secondWidth << "\n";
220 : std::cout << "firstHeight: " << firstHeight << "\n";
221 : std::cout << "firstWidth: " << firstWidth << "\n";
222 : std::cout << "resultWidth: " << resultWidth << "\n";
223 : std::cout << "\n";*/
224 :
225 3600 : buf.cmdBindDescriptorSets(pass, 0);
226 :
227 3600 : auto flags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
228 :
229 : BufferMemoryBarrier barriers[4] = {
230 3600 : BufferMemoryBarrier(weightsBuffer->getBuffers()[0].buffer, flags, flags),
231 3600 : BufferMemoryBarrier(weightsBuffer->getBuffers()[1].buffer, flags, flags),
232 3600 : BufferMemoryBarrier(weightsBuffer->getBuffers()[2].buffer, flags, flags),
233 3600 : BufferMemoryBarrier(weightsBuffer->getBuffers()[3].buffer, flags, flags)
234 18000 : };
235 :
236 3600 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 4));
237 :
238 3600 : MultiplyMatrixByTransposedMatrix(
239 3600 : buf, static_cast<ComputePipeline *>(matMul->pipeline.get()), static_cast<ComputePipeline *>(matMulBorders->pipeline.get()),
240 : /*first inputData, */firstHeight, firstWidth, firstWidth,
241 : /*second weightData, */secondHeight, secondWidth,
242 : /*resultoutputData, */resultWidth, /*unused*/0 );
243 :
244 0 : BufferMemoryBarrier barrier(output, VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead),
245 3600 : VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead));
246 3600 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
247 :
248 3600 : AddVectorToMatrixRows(buf, static_cast<ComputePipeline *>(addVec->pipeline.get()), /*batchSize*/1, firstHeight, resultWidth);
249 :
250 3600 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 4));
251 :
252 : // save original output
253 3600 : buf.cmdCopyBuffer(output, feedback);
254 :
255 3600 : if (layer->getFront()->getActivation() == Activation::RELU) {
256 2400 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
257 2400 : VectorReLU(buf, static_cast<ComputePipeline *>(relu4->pipeline.get()), static_cast<ComputePipeline *>(relu->pipeline.get()),
258 2400 : output->getSize() / sizeof(float), 0.0f);
259 : }
260 :
261 : BufferMemoryBarrier barrier1[1] = {
262 : BufferMemoryBarrier(feedback, flags, flags),
263 3600 : };
264 :
265 3600 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barrier1, 1));
266 3600 : });
267 3 : });
268 :
269 3 : builder.addCompleteCallback([this] (const QueuePassData &, FrameQueue &q, bool success) {
270 : // log::debug("MatrixMulLayer", getName(), ": submitted");
271 : /*vk::BufferAttachmentHandle *weightsBuffer = nullptr;
272 : vk::BufferAttachmentHandle *outputBuffer = nullptr;
273 :
274 : if (auto bufferAttachment = q.getAttachment(getWeightsAttachment())) {
275 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
276 : }
277 :
278 : if (auto bufferAttachment = q.getAttachment(getOutputAttachment())) {
279 : outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
280 : }
281 :
282 : auto sec = Time::now().toSeconds();
283 :
284 : q.getFrame()->getLoop()->captureBuffer([this, sec] (const BufferInfo &info, BytesView view) {
285 : xenolith::shadernn::Model::saveBlob(
286 : filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".weights.bin")).data(),
287 : view.data(), view.size());
288 : }, weightsBuffer->getBuffers()[_weightsBufferIndex].buffer.get());
289 :
290 : q.getFrame()->getLoop()->captureBuffer([this, sec] (const BufferInfo &info, BytesView view) {
291 : xenolith::shadernn::Model::saveBlob(
292 : filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".terms.bin")).data(),
293 : view.data(), view.size());
294 : }, weightsBuffer->getBuffers()[_freeTermBufferIndex].buffer.get());*/
295 :
296 : /*outputBuffer->getBuffers()[1].buffer->map([this, sec] (uint8_t *data, VkDeviceSize size) {
297 : std::cout << getName() << " ";
298 : base16::encode(std::cout, BytesView(data, 64));
299 : std::cout << "\n";
300 : xenolith::shadernn::Model::saveBlob(
301 : filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".output.bin")).data(),
302 : data, size);
303 : });*/
304 :
305 : /*q.getFrame()->getLoop()->captureBuffer([this, sec] (const BufferInfo &info, BytesView view) {
306 : xenolith::shadernn::Model::saveBlob(
307 : filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".input.bin")).data(),
308 : view.data(), view.size());
309 : }, weightsBuffer->getBuffers()[_inputBufferIndex].buffer.get());
310 :
311 : q.getFrame()->getLoop()->captureBuffer([this, weightsBuffer] (const BufferInfo &info, BytesView view) {
312 : xenolith::shadernn::Model::saveBlob(
313 : filesystem::currentDir<Interface>(toString(getName(),".", _front->getInputIndex(), ".output.bin")).data(),
314 : view.data(), view.size());
315 : }, weightsBuffer->getBuffers()[_outputBufferIndex].buffer.get());*/
316 3600 : });
317 :
318 3 : _inputAttachment = input;
319 3 : _outputAttachment = output;
320 3 : _weightsAttachment = weightsAttachment;
321 :
322 3600 : _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
323 3600 : return Rc<vk::QueuePassHandle>::create(pass, q);
324 3 : };
325 :
326 6 : return QueuePass::init(builder);
327 : }
328 :
329 3 : void MatrixMulLayer::initPropagationSubpass(core::Queue::Builder &builder, core::QueuePassBuilder &queueBuilder,
330 : core::SubpassBuilder &subpass, const core::PipelineLayoutData *layout) {
331 :
332 3 : auto backwardNeeded = isBackwardNeeded();
333 :
334 3 : _fullPropagationBuffers = _staticPropagationBuffers;
335 :
336 3 : _propWeightsIndex = _fullPropagationBuffers ++;
337 3 : _propTermsIndex = _fullPropagationBuffers ++;
338 3 : _propOriginalOutput = _fullPropagationBuffers ++;
339 3 : _propOriginalInput = _fullPropagationBuffers ++;
340 :
341 3 : _propOutputDiff = _fullPropagationBuffers ++;
342 3 : _propWeightsDiff = _fullPropagationBuffers ++;
343 3 : _propTermsDiff = _fullPropagationBuffers ++;
344 3 : _propFeedback = _fullPropagationBuffers ++;
345 3 : _propTargetIndex = _fullPropagationBuffers ++;
346 :
347 3 : const core::ComputePipelineData *matMul = nullptr;
348 3 : const core::ComputePipelineData *matMulBorders = nullptr;
349 3 : const core::ComputePipelineData *reluDiff = nullptr;
350 :
351 3 : subpass.setPrepareCallback([this, backwardNeeded, sourceWeights = _weightsBufferIndex, sourceTerms = _freeTermBufferIndex]
352 43200 : (const core::SubpassData &subpass, FrameQueue &q) {
353 3600 : auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
354 3600 : auto &pool = handle->getMemPool(nullptr);
355 :
356 3600 : vk::BufferAttachmentHandle *weightsBuffer = nullptr;
357 3600 : vk::BufferAttachmentHandle *outputBuffer = nullptr;
358 3600 : vk::BufferAttachmentHandle *inputBuffer = nullptr;
359 3600 : vk::BufferAttachmentHandle *propagationBuffer = nullptr;
360 3600 : vk::BufferAttachmentHandle *externalPropagationSource = nullptr;
361 :
362 3600 : if (auto bufferAttachment = q.getAttachment(getWeightsAttachment())) {
363 3600 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
364 : }
365 :
366 3600 : if (auto bufferAttachment = q.getAttachment(getOutputAttachment())) {
367 3600 : outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
368 : }
369 :
370 3600 : if (auto bufferAttachment = q.getAttachment(getInputAttachment())) {
371 3600 : inputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
372 : }
373 :
374 3600 : if (auto bufferAttachment = q.getAttachment(getPropagationAttachment())) {
375 3600 : propagationBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
376 : }
377 :
378 3600 : if (auto bufferAttachment = q.getAttachment(getExternalPropagationDataSource())) {
379 3600 : externalPropagationSource = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
380 : }
381 :
382 3600 : propagationBuffer->addBufferView(weightsBuffer->getBuffers()[sourceWeights].buffer);
383 3600 : propagationBuffer->addBufferView(weightsBuffer->getBuffers()[sourceTerms].buffer);
384 3600 : propagationBuffer->addBufferView(outputBuffer->getBuffers().back().buffer); // use feedback, direct output transformed with activation
385 3600 : propagationBuffer->addBufferView(inputBuffer->getBuffers().front().buffer);
386 :
387 : // output from prev layer
388 3600 : propagationBuffer->addBufferView(externalPropagationSource->getBuffers()[getExternalPropagationBufferIdx()].buffer);
389 :
390 : auto weightsDiff = pool->spawn(AllocationUsage::DeviceLocal,
391 3600 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::TransferDst | core::BufferUsage::StorageBuffer, PassType::Compute,
392 3600 : size_t(_front->getWeightBufferSize())
393 7200 : ));
394 3600 : propagationBuffer->addBufferView(weightsDiff);
395 :
396 : auto termsDiff = pool->spawn(AllocationUsage::DeviceLocal,
397 3600 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::TransferDst | core::BufferUsage::StorageBuffer, PassType::Compute,
398 3600 : size_t(_front->getKernelSize() * sizeof(float))
399 7200 : ));
400 3600 : propagationBuffer->addBufferView(termsDiff);
401 :
402 3600 : auto weightExtent = _front->getWeightSize();
403 3600 : auto outputExtent = _front->getOutputExtent();
404 :
405 3600 : const int resultBufferSize = outputExtent.depth * weightExtent.width;
406 :
407 : auto feedback = pool->spawn(AllocationUsage::DeviceLocal,
408 3600 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::TransferDst | core::BufferUsage::StorageBuffer, PassType::Compute,
409 3600 : size_t(resultBufferSize * sizeof(float))
410 7200 : ));
411 3600 : propagationBuffer->addBufferView(feedback);
412 :
413 : auto inputDiff = pool->spawn(AllocationUsage::DeviceLocal,
414 3600 : BufferInfo(core::BufferUsage::TransferSrc | core::BufferUsage::StorageBuffer, PassType::Compute,
415 3600 : size_t(resultBufferSize * sizeof(float))
416 7200 : ));
417 :
418 3600 : propagationBuffer->addBufferView(inputDiff);
419 3600 : });
420 :
421 3 : if (backwardNeeded) {
422 2 : matMul = subpass.addComputePipeline(toString(getName(), "_BackwardOnce_mul"), layout,
423 6 : core::SpecializationInfo(
424 4 : builder.addProgramByRef(toString(getName(), "_BackwardOnce_mul"),
425 2 : getShader(LayerShader::MultiplyMatrixByMatrix, Precision::Unknown)),
426 6 : Vector<core::SpecializationConstant>{
427 2 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
428 : core::SpecializationConstant(_propTargetIndex), // output
429 : core::SpecializationConstant(_propOutputDiff), // input
430 : core::SpecializationConstant(_propWeightsIndex), // weight
431 2 : core::SpecializationConstant(_front->getInputIndex()),
432 : }));
433 :
434 2 : matMulBorders = subpass.addComputePipeline(toString(getName(), "_BackwardOnce_mulBorders"), layout,
435 6 : core::SpecializationInfo(
436 4 : builder.addProgramByRef(toString(getName(), "_BackwardOnce_mulBorders"),
437 2 : getShader(LayerShader::MultiplyMatrixByMatrixBorder, Precision::Unknown)),
438 4 : Vector<core::SpecializationConstant>{
439 2 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
440 : core::SpecializationConstant(_propTargetIndex), // output
441 : core::SpecializationConstant(_propOutputDiff), // input
442 : core::SpecializationConstant(_propWeightsIndex) // weight
443 : }));
444 : }
445 :
446 3 : if (_front->getActivation() == Activation::RELU) {
447 2 : reluDiff = subpass.addComputePipeline(toString(getName(), "_BackwardOnce_reluDiff"), layout,
448 6 : core::SpecializationInfo(
449 4 : builder.addProgramByRef(toString(getName(), "_BackwardOnce_reluDiff"),
450 2 : getShader(LayerShader::VectorReLUDiff, Precision::Unknown)),
451 4 : Vector<core::SpecializationConstant>{
452 2 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
453 : core::SpecializationConstant(_propOutputDiff), // input diff
454 : core::SpecializationConstant(_propOriginalOutput), // original output
455 : core::SpecializationConstant(_propOutputDiff) // output diff
456 : }));
457 : }
458 :
459 3 : auto learnMatMul = subpass.addComputePipeline(toString(getName(), "_LearnOnce_MatMul"), layout,
460 9 : core::SpecializationInfo(
461 6 : builder.addProgramByRef(toString(getName(), "_LearnOnce_MatMul"),
462 3 : getShader(LayerShader::MultiplyTransposedMatrixByMatrix, Precision::Unknown)),
463 6 : Vector<core::SpecializationConstant>{
464 3 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
465 : core::SpecializationConstant(_propWeightsDiff),
466 : core::SpecializationConstant(_propOutputDiff),
467 : core::SpecializationConstant(_propOriginalInput)
468 : }));
469 :
470 3 : auto learnMatMulBorder = subpass.addComputePipeline(toString(getName(), "_LearnOnce_MatMulBorder"), layout,
471 9 : core::SpecializationInfo(
472 6 : builder.addProgramByRef(toString(getName(), "_LearnOnce_MatMulBorder"),
473 3 : getShader(LayerShader::MultiplyTransposedMatrixByMatrixBorder, Precision::Unknown)),
474 6 : Vector<core::SpecializationConstant>{
475 3 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
476 : core::SpecializationConstant(_propWeightsDiff),
477 : core::SpecializationConstant(_propOutputDiff),
478 : core::SpecializationConstant(_propOriginalInput)
479 : }));
480 :
481 3 : auto learnSum = subpass.addComputePipeline(toString(getName(), "_LearnOnce_Sum"), layout,
482 9 : core::SpecializationInfo(
483 6 : builder.addProgramByRef(toString(getName(), "_LearnOnce_Sum"),
484 3 : getShader(LayerShader::SumMatrixRows, Precision::Unknown)),
485 6 : Vector<core::SpecializationConstant>{
486 3 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
487 : core::SpecializationConstant(_propTermsDiff),
488 : core::SpecializationConstant(_propOutputDiff),
489 : }));
490 :
491 : struct TrainPipelines {
492 : const core::ComputePipelineData *decayHistory;
493 : const core::ComputePipelineData *multHistory;
494 : const core::ComputePipelineData *add4;
495 : const core::ComputePipelineData *add1;
496 : };
497 :
498 6 : auto initTrainPipelines = [&] (uint32_t staticparam, uint32_t diff, uint32_t target) {
499 : TrainPipelines ret;
500 :
501 6 : ret.decayHistory = subpass.addComputePipeline(toString(getName(), "_trainDecayHistory", staticparam), layout,
502 18 : core::SpecializationInfo(
503 12 : builder.addProgramByRef(toString(getName(), "_trainDecayHistory", staticparam),
504 6 : getShader(LayerShader::VectorEltwiseMultiply, Precision::Unknown)),
505 12 : Vector<core::SpecializationConstant>{
506 6 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
507 : core::SpecializationConstant(staticparam),
508 : core::SpecializationConstant(staticparam),
509 : core::SpecializationConstant(_staticParams),
510 : core::SpecializationConstant(TV_MomentDecayRateVar),
511 : }));
512 :
513 6 : ret.multHistory = subpass.addComputePipeline(toString(getName(), "_trainHistoryAdd", staticparam), layout,
514 18 : core::SpecializationInfo(
515 12 : builder.addProgramByRef(toString(getName(), "_trainHistoryAdd", staticparam),
516 6 : getShader(LayerShader::VectorMultiplyAndAdd, Precision::Unknown)),
517 12 : Vector<core::SpecializationConstant>{
518 6 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
519 : core::SpecializationConstant(staticparam),
520 : core::SpecializationConstant(staticparam),
521 : core::SpecializationConstant(diff),
522 : core::SpecializationConstant(_staticParams),
523 : core::SpecializationConstant(TV_RateVar),
524 : }));
525 :
526 6 : ret.add4 = subpass.addComputePipeline(toString(getName(), "_trainAdd4_", staticparam), layout,
527 18 : core::SpecializationInfo(
528 12 : builder.addProgramByRef(toString(getName(), "_trainAdd4_", staticparam),
529 6 : getShader(LayerShader::VectorAddFloat4, Precision::Unknown)),
530 12 : Vector<core::SpecializationConstant>{
531 6 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
532 : core::SpecializationConstant(target),
533 : core::SpecializationConstant(target),
534 : core::SpecializationConstant(staticparam)
535 : }));
536 :
537 6 : ret.add1 = subpass.addComputePipeline(toString(getName(), "_trainAdd1_", staticparam), layout,
538 18 : core::SpecializationInfo(
539 12 : builder.addProgramByRef(toString(getName(), "_trainAdd1_", staticparam),
540 6 : getShader(LayerShader::VectorAddFloat1, Precision::Unknown)),
541 12 : Vector<core::SpecializationConstant>{
542 6 : core::SpecializationConstant(getPropagationSubpassBufferCount()), // nbuffers
543 : core::SpecializationConstant(target),
544 : core::SpecializationConstant(target),
545 : core::SpecializationConstant(staticparam)
546 : }));
547 :
548 6 : return ret;
549 3 : };
550 :
551 3 : auto trainWeights = initTrainPipelines(_staticWeightsHistoryIndex, _propWeightsDiff, _propWeightsIndex);
552 3 : auto trainTerms = initTrainPipelines(_staticTermsHistoryIndex, _propTermsDiff, _propTermsIndex);
553 :
554 3 : subpass.setCommandsCallback([this, backwardNeeded, layoutIndex = layout->index, matMul, matMulBorders,
555 : reluDiff, learnMatMul, learnMatMulBorder, learnSum,
556 : trainWeights, trainTerms]
557 76800 : (const core::SubpassData &subpass, core::FrameQueue &q, core::CommandBuffer &b) {
558 3600 : auto &buf = static_cast<CommandBuffer &>(b);
559 3600 : auto pass = static_cast<RenderPass *>(subpass.pass->impl.get());
560 3600 : auto front = getFront();
561 :
562 3600 : auto weightExtent = front->getWeightSize();
563 3600 : auto outputExtent = front->getOutputExtent();
564 :
565 10800 : auto makeFullBarrier = [&] (Buffer *b) {
566 10800 : auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
567 10800 : BufferMemoryBarrier barrier(b, BufferAccessFlags, BufferAccessFlags);
568 10800 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
569 10800 : };
570 :
571 7200 : auto makeFullBarrier2 = [&] (Buffer *b1, Buffer *b2) {
572 7200 : auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
573 : BufferMemoryBarrier barriers[] = {
574 : BufferMemoryBarrier(b1, BufferAccessFlags, BufferAccessFlags),
575 : BufferMemoryBarrier(b2, BufferAccessFlags, BufferAccessFlags)
576 7200 : };
577 7200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
578 7200 : };
579 :
580 3600 : auto makeFullBarrier4 = [&] (Buffer *b1, Buffer *b2, Buffer *b3, Buffer *b4) {
581 3600 : auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
582 : BufferMemoryBarrier barriers[] = {
583 : BufferMemoryBarrier(b1, BufferAccessFlags, BufferAccessFlags),
584 : BufferMemoryBarrier(b2, BufferAccessFlags, BufferAccessFlags),
585 : BufferMemoryBarrier(b3, BufferAccessFlags, BufferAccessFlags),
586 : BufferMemoryBarrier(b4, BufferAccessFlags, BufferAccessFlags)
587 3600 : };
588 3600 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 4));
589 3600 : };
590 :
591 3600 : vk::BufferAttachmentHandle *propagationBuffer = nullptr;
592 :
593 3600 : if (auto bufferAttachment = q.getAttachment(getPropagationAttachment())) {
594 3600 : propagationBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
595 : }
596 :
597 3600 : buf.cmdBindDescriptorSets(pass, layoutIndex);
598 :
599 3600 : makeFullBarrier(propagationBuffer->getBuffers()[_propOutputDiff].buffer);
600 :
601 3600 : if (front->getActivation() == Activation::RELU) {
602 2400 : const int inputBufferSize = outputExtent.depth * front->getKernelSize();
603 2400 : VectorReLUDiff(buf, static_cast<ComputePipeline *>(reluDiff->pipeline.get()), inputBufferSize, 0.0f);
604 2400 : makeFullBarrier(propagationBuffer->getBuffers()[_propOutputDiff].buffer);
605 : }
606 :
607 3600 : if (backwardNeeded) {
608 2400 : const int secondWidth = weightExtent.width;
609 2400 : const int firstHeight = outputExtent.depth;
610 2400 : const int firstWidth = outputExtent.width;
611 2400 : const int resultBufferSize = firstHeight * secondWidth;
612 :
613 2400 : MultiplyMatrixByMatrix(buf, static_cast<ComputePipeline *>(matMul->pipeline.get()), static_cast<ComputePipeline *>(matMulBorders->pipeline.get()),
614 : 1, firstHeight, firstWidth, secondWidth, resultBufferSize);
615 :
616 2400 : makeFullBarrier(propagationBuffer->getBuffers()[_propTargetIndex].buffer);
617 2400 : buf.cmdCopyBuffer(propagationBuffer->getBuffers()[_propTargetIndex].buffer, propagationBuffer->getBuffers()[_propFeedback].buffer);
618 2400 : makeFullBarrier(propagationBuffer->getBuffers()[_propFeedback].buffer);
619 : }
620 :
621 3600 : MultiplyTransposedMatrixByMatrix(buf, static_cast<ComputePipeline *>(learnMatMul->pipeline.get()),
622 3600 : static_cast<ComputePipeline *>(learnMatMulBorder->pipeline.get()),
623 3600 : outputExtent.depth, weightExtent.height, weightExtent.height,
624 3600 : weightExtent.width, weightExtent.width, weightExtent.width,
625 3600 : weightExtent.width * weightExtent.height);
626 :
627 3600 : SumMatrixRows(buf, static_cast<ComputePipeline *>(learnSum->pipeline.get()),
628 3600 : 1, outputExtent.depth, weightExtent.height);
629 :
630 3600 : auto weightsSize = front->getWeightBufferSize() / sizeof(float);
631 3600 : auto termsSize = front->getKernelSize();
632 :
633 3600 : VectorMultiply(buf, static_cast<ComputePipeline *>(trainWeights.decayHistory->pipeline.get()), weightsSize);
634 3600 : VectorMultiply(buf, static_cast<ComputePipeline *>(trainTerms.decayHistory->pipeline.get()), termsSize);
635 :
636 7200 : makeFullBarrier4(propagationBuffer->getBuffers()[_propWeightsDiff].buffer, propagationBuffer->getBuffers()[_propTermsDiff].buffer,
637 7200 : propagationBuffer->getBuffers()[_staticWeightsHistoryIndex].buffer, propagationBuffer->getBuffers()[_staticTermsHistoryIndex].buffer);
638 3600 : VectorMultiplyAndAdd(buf, static_cast<ComputePipeline *>(trainWeights.multHistory->pipeline.get()), weightsSize);
639 3600 : VectorMultiplyAndAdd(buf, static_cast<ComputePipeline *>(trainTerms.multHistory->pipeline.get()), termsSize);
640 :
641 3600 : makeFullBarrier2(propagationBuffer->getBuffers()[_staticWeightsHistoryIndex].buffer,
642 3600 : propagationBuffer->getBuffers()[_staticTermsHistoryIndex].buffer);
643 :
644 3600 : VectorAdd(buf, static_cast<ComputePipeline *>(trainWeights.add4->pipeline.get()),
645 3600 : static_cast<ComputePipeline *>(trainWeights.add1->pipeline.get()), weightsSize);
646 3600 : VectorAdd(buf, static_cast<ComputePipeline *>(trainTerms.add4->pipeline.get()),
647 3600 : static_cast<ComputePipeline *>(trainTerms.add1->pipeline.get()), termsSize);
648 :
649 3600 : makeFullBarrier2(propagationBuffer->getBuffers()[_propWeightsIndex].buffer,
650 3600 : propagationBuffer->getBuffers()[_propTermsIndex].buffer);
651 3600 : });
652 :
653 3 : queueBuilder.addCompleteCallback([this] (const core::QueuePassData &, FrameQueue &q, bool success) {
654 : /*vk::BufferAttachmentHandle *propagationBuffer = nullptr;
655 : if (auto bufferAttachment = q.getAttachment(getPropagationAttachment())) {
656 : propagationBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
657 : }
658 :
659 : q.getFrame()->getLoop()->captureBuffer([front = _front] (const BufferInfo &info, BytesView view) {
660 : xenolith::shadernn::Model::saveBlob(
661 : filesystem::currentDir<Interface>(toString(front->getName(),".", front->getInputIndex(), ".weightsDiff.bin")).data(),
662 : view.data(), view.size());
663 : std::cout << "Save: " << toString(front->getName(),".", front->getInputIndex(), ".weightsDiff.bin") << "\n";
664 : }, propagationBuffer->getBuffers()[_propWeightsDiff].buffer.get());
665 :
666 : q.getFrame()->getLoop()->captureBuffer([front = _front] (const BufferInfo &info, BytesView view) {
667 : xenolith::shadernn::Model::saveBlob(
668 : filesystem::currentDir<Interface>(toString(front->getName(),".", front->getInputIndex(), ".termsDiff.bin")).data(),
669 : view.data(), view.size());
670 : std::cout << "Save: " << toString(front->getName(),".", front->getInputIndex(), ".termsDiff.bin") << "\n";
671 : }, propagationBuffer->getBuffers()[_propTermsDiff].buffer.get());
672 :
673 : q.getFrame()->getLoop()->captureBuffer([front = _front] (const BufferInfo &info, BytesView view) {
674 : xenolith::shadernn::Model::saveBlob(
675 : filesystem::currentDir<Interface>(toString(front->getName(),".", front->getInputIndex(), ".feedback.bin")).data(),
676 : view.data(), view.size());
677 : std::cout << "Save: " << toString(front->getName(),".", front->getInputIndex(), ".output.bin") << "\n";
678 : }, propagationBuffer->getBuffers()[_propFeedback].buffer.get());*/
679 3600 : });
680 :
681 3 : _targetPropagationIdx = _propTargetIndex;
682 3 : }
683 :
684 3 : Vector<const core::BufferData *> MatrixMulLayer::getTrainableGradients(Queue::Builder &queueBuilder) const {
685 9 : auto weightsGradientBuffer = queueBuilder.addBuffer(toString(getName(), "_weightsGradient_buffer"),
686 6 : BufferInfo(_front->getWeightBufferSize(), core::BufferUsage::StorageBuffer, PassType::Compute),
687 3 : [] (uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) {
688 3 : FillFloatBuffer(buf, size, 0.0f);
689 3 : });
690 :
691 9 : auto freeTermsGradientBuffer = queueBuilder.addBuffer(toString(getName(), "_freeTermsGradient_buffer"),
692 6 : BufferInfo(_front->getKernelSize() * sizeof(float), core::BufferUsage::StorageBuffer, PassType::Compute),
693 3 : [] (uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) {
694 3 : FillFloatBuffer(buf, size, 0.0f);
695 3 : });
696 :
697 : return Vector<const core::BufferData *>{
698 : weightsGradientBuffer,
699 : freeTermsGradientBuffer
700 3 : };
701 : }
702 :
703 : }
|