Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnVkLossLayer.h"
24 :
25 : namespace stappler::xenolith::vk::shadernn {
26 :
27 : static auto BufferAccessFlags = VkAccessFlags(core::AccessType::ShaderWrite | core::AccessType::ShaderRead);
28 :
29 9 : static StringView getPipelineOpName(CrossEntropyLossLayer::PipelineOpIndex idx) {
30 9 : switch (idx) {
31 1 : case CrossEntropyLossLayer::PipelineOpIndex::MatrixSoftmaxByRows: return "MatrixSoftmaxByRows"; break;
32 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorNegLog: return "VectorNegLog"; break;
33 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorEltwiseMultiply: return "VectorEltwiseMultiply"; break;
34 1 : case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsToResult: return "SumMatrixColumnsToResult"; break;
35 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorSub: return "VectorSub"; break;
36 1 : case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsLabels: return "SumMatrixColumns"; break;
37 1 : case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrix: return "MultiplyDiagMatrixByMatrix"; break;
38 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorDotProduct: return "VectorDotProduct"; break;
39 1 : case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrixForInput: return "MultiplyDiagMatrixByMatrixForInput"; break;
40 : }
41 0 : return StringView();
42 : }
43 :
44 9 : static LayerShader getPipelineOpShader(CrossEntropyLossLayer::PipelineOpIndex idx) {
45 9 : switch (idx) {
46 1 : case CrossEntropyLossLayer::PipelineOpIndex::MatrixSoftmaxByRows: return LayerShader::MatrixSoftmaxByRows; break;
47 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorNegLog: return LayerShader::VectorLog; break;
48 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorEltwiseMultiply: return LayerShader::VectorEltwiseMultiply; break;
49 1 : case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsToResult: return LayerShader::SumMatrixColumns; break;
50 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorSub: return LayerShader::VectorSub; break;
51 1 : case CrossEntropyLossLayer::PipelineOpIndex::SumMatrixColumnsLabels: return LayerShader::SumMatrixColumns; break;
52 1 : case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrix: return LayerShader::MultiplyDiagMatrixByMatrix; break;
53 1 : case CrossEntropyLossLayer::PipelineOpIndex::VectorDotProduct: return LayerShader::VectorDotProduct; break;
54 1 : case CrossEntropyLossLayer::PipelineOpIndex::MultiplyDiagMatrixByMatrixForInput: return LayerShader::MultiplyDiagMatrixByMatrix; break;
55 : }
56 0 : return LayerShader::Gen;
57 : }
58 :
59 2 : CrossEntropyLossLayer::~CrossEntropyLossLayer() { }
60 :
/**
 * Configures the queue pass of the loss layer.
 *
 * Steps, in order:
 *  - registers four buffers on the queue (parameters, weights, result, loss gradient)
 *    and groups them into one buffer attachment;
 *  - adds the label/network input attachments, the output attachment and the
 *    grouped attachment to the pass;
 *  - creates a pipeline layout with one descriptor set holding a storage-buffer
 *    descriptor array of DescriptorArraySize entries;
 *  - adds a subpass that registers one compute pipeline per PipelineOpIndex
 *    (the last four only when the model is trainable) and installs the
 *    prepare/commands callbacks;
 *  - installs the completion callback and the frame-handle factory;
 *  - calls initPropagation() when the model is trainable.
 *
 * @param queueBuilder  queue builder used to register buffers, the attachment and shader programs
 * @param builder       builder of this pass
 * @param front         frontend layer object; stored in _front and queried for sizes and parameters
 * @param inputLabels   attachment stored as the labels input
 * @param inputNetwork  attachment stored as the network input
 * @param output        attachment stored as the output
 * @return the result of QueuePass::init(builder)
 */
61 1 : bool CrossEntropyLossLayer::init(Queue::Builder &queueBuilder, QueuePassBuilder &builder, Front * front,
62 : const AttachmentData *inputLabels, const AttachmentData *inputNetwork, const AttachmentData *output) {
63 : using namespace core;
64 :
65 1 : _front = front;
66 :
// Parameters buffer: one float per entry of front->getParameters(); its
// initial content is copied from those parameters.
67 3 : auto paramsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_params_buffer"),
68 2 : BufferInfo(front->getParameters().size() * sizeof(float), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
69 1 : [front = front] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
70 1 : memcpy(buf, front->getParameters().data(), size);
71 1 : });
72 :
// Weights buffer: initial content is filled with 1.0f.
73 3 : auto weightsBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_weights_buffer"),
74 2 : BufferInfo(front->getWeightBufferSize(), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
75 1 : [] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
76 1 : FillFloatBuffer(buf, size, 1.0f);
77 1 : });
78 :
// Result buffer: initial content is filled with 0.0f.
79 3 : auto resultBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_result_buffer"),
80 2 : BufferInfo(front->getResultBufferSize(), BufferUsage::StorageBuffer | BufferUsage::TransferSrc, PassType::Compute),
81 1 : [] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
82 1 : FillFloatBuffer(buf, size, 0.0f);
83 1 : });
84 :
// Loss gradient buffer: filled with 0.0f; unlike the buffers above it is
// created without the TransferSrc usage flag.
85 3 : auto lossGradientBuffer = queueBuilder.addBuffer(toString(builder.getName(), "_lossGradient_buffer"),
86 2 : BufferInfo(front->getLossGradientBufferSize(), BufferUsage::StorageBuffer, PassType::Compute),
87 1 : [] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &cb) {
88 1 : FillFloatBuffer(buf, size, 0.0f);
89 1 : });
90 :
// Group the four buffers into one buffer attachment. The buffers are passed
// in the order params, weights, result, loss gradient.
// NOTE(review): presumably this order must match ParamsIdx, WeightsIdx,
// LossValueIdx and LossGradientIdx - confirm against the header.
91 1 : auto weightsAttachment = queueBuilder.addAttachemnt(toString(builder.getName(), "_weights"),
92 1 : [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> {
93 2 : return Rc<vk::BufferAttachment>::create(builder, Vector<const BufferData *>{
94 1 : paramsBuffer,
95 1 : weightsBuffer,
96 1 : resultBuffer,
97 1 : lossGradientBuffer,
98 2 : });
99 : });
100 :
// Both inputs are declared as read by the compute shader stage; the output
// and the grouped attachment are added without explicit dependency info.
101 1 : builder.addAttachment(inputLabels, AttachmentDependencyInfo::make(PipelineStage::ComputeShader, AccessType::ShaderRead));
102 1 : builder.addAttachment(inputNetwork, AttachmentDependencyInfo::make(PipelineStage::ComputeShader, AccessType::ShaderRead));
103 1 : builder.addAttachment(output);
104 1 : auto passWeights = builder.addAttachment(weightsAttachment);
105 :
// Pipeline layout: a single descriptor set with one storage-buffer descriptor
// array of DescriptorArraySize entries bound to the grouped attachment.
106 1 : auto layout = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layoutBuilder) {
107 1 : layoutBuilder.addSet([&] (DescriptorSetBuilder &setBuilder) {
108 2 : setBuilder.addDescriptorArray(passWeights, DescriptorArraySize, DescriptorType::StorageBuffer);
109 1 : });
110 2 : });
111 :
112 1 : builder.addSubpass([&] (SubpassBuilder &subpassBuilder) {
// Helper: registers a compute pipeline for an operation with one output and
// one input buffer index (passed as specialization constants) and stores the
// operation with its recording callback in _pipelineOps under idx.
113 4 : auto addPipeline2 = [&] (PipelineOpIndex idx, uint32_t output, uint32_t input, PipelineOpFn &&fn) {
114 4 : auto name = getPipelineOpName(idx);
115 4 : auto shader = getPipelineOpShader(idx);
116 :
117 4 : auto data = subpassBuilder.addComputePipeline(toString(builder.getName(), "_", name), layout,
118 12 : SpecializationInfo(
119 8 : queueBuilder.addProgramByRef(toString(builder.getName(), "_", name, "_shader"),
120 4 : getShader(shader, Precision::Unknown)),
121 8 : Vector<SpecializationConstant>{
122 : SpecializationConstant(DescriptorArraySize), // nbuffers
123 : SpecializationConstant(output), // output
124 : SpecializationConstant(input), // input
125 : }));
126 :
127 4 : _pipelineOps.emplace(idx, PipelineOp(
128 4 : idx, data, move(fn)
129 : ));
130 4 : };
131 :
// Helper: same as addPipeline2 but for operations with two input buffer
// indices; `extra` constants are appended after the four standard ones.
// Returns the created pipeline data.
132 5 : auto addPipeline3 = [&] (PipelineOpIndex idx, uint32_t output, uint32_t inputA, uint32_t inputB, PipelineOpFn &&fn,
133 : Vector<SpecializationConstant> &&extra = Vector<SpecializationConstant>()) {
134 5 : auto name = getPipelineOpName(idx);
135 5 : auto shader = getPipelineOpShader(idx);
136 :
137 : auto constants = Vector<SpecializationConstant>{
138 : SpecializationConstant(DescriptorArraySize), // nbuffers
139 : SpecializationConstant(output), // output
140 : SpecializationConstant(inputA), // input A
141 : SpecializationConstant(inputB), // input B
142 5 : };
143 :
144 13 : for (auto &it : extra) {
145 8 : constants.emplace_back(it);
146 : }
147 :
148 5 : auto data = subpassBuilder.addComputePipeline(toString(builder.getName(), "_", name), layout,
149 15 : SpecializationInfo(
150 10 : queueBuilder.addProgramByRef(toString(builder.getName(), "_", name, "_shader"),
151 5 : getShader(shader, Precision::Unknown)),
152 : constants));
153 :
154 5 : _pipelineOps.emplace(idx, PipelineOp(
155 5 : idx, data, move(fn)
156 : ));
157 :
158 5 : return data;
159 5 : };
160 :
// NOTE(review): the operations below are executed by runAll() in the
// iteration order of _pipelineOps, not necessarily in registration order -
// presumably the PipelineOpIndex enum order; confirm against the container type.

// Softmax: output index ActivationIdx, input index InputNetworkIdx.
// No barrier is recorded before this dispatch.
161 1 : addPipeline2(PipelineOpIndex::MatrixSoftmaxByRows, ActivationIdx, InputNetworkIdx,
162 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
163 1200 : MatrixSoftmaxByRows( buf, pipeline, front->getBatchSize(), front->getClassesCount() );
164 1200 : });
165 :
// Negative log: output index ActivationEltwiseMulIdx, input index ActivationIdx.
// A compute-to-compute barrier on the input buffer precedes the dispatch.
166 1 : addPipeline2(PipelineOpIndex::VectorNegLog, ActivationEltwiseMulIdx, ActivationIdx,
167 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
168 :
169 1200 : BufferMemoryBarrier barrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags);
170 :
171 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
172 :
173 1200 : VectorNegLog( buf, pipeline, front->getBatchSize() * front->getClassesCount() );
174 1200 : });
175 :
// Element-wise multiply: the buffer at ActivationEltwiseMulIdx is both the
// output and the second input; the first input is InputLabelsIdx.
176 1 : addPipeline3(PipelineOpIndex::VectorEltwiseMultiply, ActivationEltwiseMulIdx, InputLabelsIdx, ActivationEltwiseMulIdx,
177 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
178 :
179 1200 : BufferMemoryBarrier barrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags);
180 :
181 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
182 :
183 1200 : VectorEltwiseMultiply( buf, pipeline, front->getBatchSize() * front->getClassesCount() );
184 1200 : });
185 :
// Column sum: output index LossValueIdx, input index ActivationEltwiseMulIdx.
186 1 : addPipeline2(PipelineOpIndex::SumMatrixColumnsToResult, LossValueIdx, ActivationEltwiseMulIdx,
187 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
188 :
189 1200 : BufferMemoryBarrier barrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags);
190 :
191 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
192 :
193 1200 : SumMatrixColumns( buf, pipeline, front->getBatchSize(), front->getClassesCount() );
194 1200 : });
195 :
// Dot product: output index ParamsIdx, inputs WeightsIdx and LossValueIdx.
// Three extra specialization constants are appended: Front::P_Loss, 1 and
// Front::P_LossDivider (their meaning is defined by the shader, not visible here).
196 1 : addPipeline3(PipelineOpIndex::VectorDotProduct, ParamsIdx, WeightsIdx, LossValueIdx,
197 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
198 :
199 : BufferMemoryBarrier barriers[2] = {
200 0 : BufferMemoryBarrier(buffers[WeightsIdx].buffer, BufferAccessFlags, BufferAccessFlags),
201 0 : BufferMemoryBarrier(buffers[LossValueIdx].buffer, BufferAccessFlags, BufferAccessFlags),
202 1200 : };
203 :
204 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
205 :
206 1200 : VectorDotProduct( buf, pipeline, front->getBatchSize());
207 1202 : }, Vector<SpecializationConstant>{
208 : SpecializationConstant(Front::P_Loss),
209 : SpecializationConstant(1),
210 : SpecializationConstant(Front::P_LossDivider),
211 : });
212 :
// The remaining four pipelines are registered only for trainable models.
213 1 : if (front->getModel()->isTrainable()) {
// Subtraction: output index ActivationEltwiseMulIdx, inputs ActivationIdx and InputLabelsIdx.
214 1 : addPipeline3(PipelineOpIndex::VectorSub, ActivationEltwiseMulIdx, ActivationIdx, InputLabelsIdx,
215 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
216 :
217 : BufferMemoryBarrier barriers[2] = {
218 0 : BufferMemoryBarrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags),
219 0 : BufferMemoryBarrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags),
220 1200 : };
221 :
222 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
223 :
224 1200 : VectorSub( buf, pipeline, front->getBatchSize() * front->getClassesCount() );
225 1200 : });
226 :
// Column sum of the labels: output index ActivationIdx, input index InputLabelsIdx.
227 1 : addPipeline2(PipelineOpIndex::SumMatrixColumnsLabels, ActivationIdx, InputLabelsIdx,
228 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
229 :
230 1200 : BufferMemoryBarrier barrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags);
231 :
232 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
233 :
234 1200 : SumMatrixColumns( buf, pipeline, front->getBatchSize(), front->getClassesCount() );
235 1200 : });
236 :
// Diagonal-by-matrix multiply: output index LossGradientIdx, inputs
// ActivationIdx and ActivationEltwiseMulIdx.
237 1 : addPipeline3(PipelineOpIndex::MultiplyDiagMatrixByMatrix, LossGradientIdx, ActivationIdx, ActivationEltwiseMulIdx,
238 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
239 :
240 : BufferMemoryBarrier barriers[2] = {
241 0 : BufferMemoryBarrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags),
242 0 : BufferMemoryBarrier(buffers[ActivationEltwiseMulIdx].buffer, BufferAccessFlags, BufferAccessFlags),
243 1200 : };
244 :
245 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(barriers, 2));
246 :
247 1200 : MultiplyDiagMatrixByMatrix( buf, pipeline, front->getBatchSize(), front->getClassesCount(), front->getBatchSize() * front->getClassesCount() );
248 1200 : });
249 :
// Same shader, second configuration: output index ActivationIdx, inputs
// WeightsIdx and LossGradientIdx, with five extra specialization constants
// (labelled inline below).
// NOTE(review): only the ActivationIdx buffer gets a barrier here although
// LossGradientIdx is written by the previous pipeline - confirm this is intended.
250 1 : addPipeline3(PipelineOpIndex::MultiplyDiagMatrixByMatrixForInput, ActivationIdx, WeightsIdx, LossGradientIdx,
251 1200 : [] (Front *front, CommandBuffer &buf, ComputePipeline *pipeline, SpanView<BufferView> buffers) {
252 :
253 1200 : BufferMemoryBarrier barrier(buffers[ActivationIdx].buffer, BufferAccessFlags, BufferAccessFlags);
254 :
255 1200 : buf.cmdPipelineBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, makeSpanView(&barrier, 1));
256 :
257 1200 : MultiplyDiagMatrixByMatrix( buf, pipeline, front->getBatchSize(), front->getClassesCount(), front->getBatchSize() * front->getClassesCount() );
258 1202 : }, Vector<SpecializationConstant>{
259 : SpecializationConstant(1), // MODIFIERS_ENABLED
260 : SpecializationConstant(ParamsIdx), // PARAMETERS_INDEX
261 : SpecializationConstant(Front::P_LossGradientDivider), // MULTIPLIER_PARAMETER_OFFSET
262 : SpecializationConstant(Front::P_MinGradient), // MIN_PARAMETER_OFFSET
263 : SpecializationConstant(Front::P_MaxGradient), // MAX_PARAMETER_OFFSET
264 : });
265 : }
266 :
// Prepare callback: looks up the attachment handles for this frame, spawns
// two device-local scratch buffers and registers buffer views on the output
// and grouped attachment handles.
267 1 : subpassBuilder.setPrepareCallback([] (const core::SubpassData &subpass, core::FrameQueue &q) {
268 1200 : auto layer = (CrossEntropyLossLayer *)subpass.pass->pass.get();
269 :
270 1200 : vk::BufferAttachmentHandle *inputNetworkBuffer = nullptr;
271 1200 : vk::BufferAttachmentHandle *inputLabelsBuffer = nullptr;
272 1200 : vk::BufferAttachmentHandle *outputBuffer = nullptr;
273 1200 : vk::BufferAttachmentHandle *weightsBuffer = nullptr;
274 :
// NOTE(review): each handle stays nullptr when its attachment is not found,
// but all four are dereferenced unconditionally further down.
275 1200 : if (auto bufferAttachment = q.getAttachment(layer->getInputNetworkAttachment())) {
276 1200 : inputNetworkBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
277 : }
278 :
279 1200 : if (auto bufferAttachment = q.getAttachment(layer->getInputLabelsAttachment())) {
280 1200 : inputLabelsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
281 : }
282 :
283 1200 : if (auto bufferAttachment = q.getAttachment(layer->getOutputAttachment())) {
284 1200 : outputBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
285 : }
286 :
287 1200 : if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
288 1200 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
289 : }
290 :
291 1200 : auto front = layer->getFront();
292 :
293 1200 : auto handle = static_cast<DeviceFrameHandle *>(q.getFrame().get());
294 1200 : auto &pool = handle->getMemPool(nullptr);
295 :
296 1200 : auto batchSize = front->getBatchSize();
297 1200 : auto vectorSize = front->getClassesCount();
298 1200 : auto totalSize = batchSize * vectorSize;
299 :
// Two scratch buffers, each holding batchSize * classesCount floats.
300 : auto activationBuffer = pool->spawn(AllocationUsage::DeviceLocal,
301 1200 : BufferInfo(size_t(totalSize * sizeof(float)), BufferUsage::StorageBuffer | BufferUsage::TransferSrc));
302 : auto activationEltwiseMulBuffer = pool->spawn(AllocationUsage::DeviceLocal,
303 1200 : BufferInfo(size_t(totalSize * sizeof(float)), BufferUsage::StorageBuffer | BufferUsage::TransferSrc));
304 :
// The output handle receives the buffers at hard-coded positions 2 and 0.
// NOTE(review): presumably the result and parameters buffers, given the
// order in which the grouped attachment was created - consider LossValueIdx/ParamsIdx.
305 1200 : outputBuffer->addBufferView(weightsBuffer->getBuffers()[2].buffer);
306 1200 : outputBuffer->addBufferView(weightsBuffer->getBuffers()[0].buffer);
307 :
// Views are added in the order network input, labels input, activation,
// activation-eltwise-mul.
// NOTE(review): presumably these land at InputNetworkIdx, InputLabelsIdx,
// ActivationIdx and ActivationEltwiseMulIdx - confirm against the header.
308 1200 : weightsBuffer->addBufferView(inputNetworkBuffer->getBuffers().front().buffer);
309 1200 : weightsBuffer->addBufferView(inputLabelsBuffer->getBuffers().front().buffer);
310 1200 : weightsBuffer->addBufferView(activationBuffer);
311 1200 : weightsBuffer->addBufferView(activationEltwiseMulBuffer);
312 1200 : });
313 :
// Commands callback: binds descriptor set 0 and records every registered
// operation through runAll(), passing the grouped attachment's buffer views.
314 1 : subpassBuilder.setCommandsCallback([] (const SubpassData &subpass, FrameQueue &q, core::CommandBuffer &b) {
315 1200 : auto &buf = static_cast<CommandBuffer &>(b);
316 1200 : auto pass = static_cast<RenderPass *>(subpass.pass->impl.get());
317 1200 : auto layer = (CrossEntropyLossLayer *)subpass.pass->pass.get();
318 :
// NOTE(review): weightsBuffer is dereferenced below even if the lookup fails.
319 1200 : vk::BufferAttachmentHandle *weightsBuffer = nullptr;
320 1200 : if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
321 1200 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
322 : }
323 :
324 1200 : buf.cmdBindDescriptorSets(pass, 0);
325 :
326 1200 : layer->runAll(buf, weightsBuffer->getBuffers());
327 1200 : });
328 1 : });
329 :
// Completion callback. With the debug capture code commented out it performs
// no work besides looking up the grouped attachment handle.
// NOTE(review): `params`, `this` and `front` are unused while the capture
// code is disabled; weightsBuffer is dereferenced without a null check.
330 1 : builder.addCompleteCallback([this, front = _front] (const QueuePassData &pass, FrameQueue &q, bool success) {
331 1200 : auto layer = (CrossEntropyLossLayer *)pass.pass.get();
332 :
333 1200 : vk::BufferAttachmentHandle *weightsBuffer = nullptr;
334 1200 : if (auto bufferAttachment = q.getAttachment(layer->getWeightsAttachment())) {
335 1200 : weightsBuffer = (vk::BufferAttachmentHandle *)bufferAttachment->handle.get();
336 : }
337 :
338 1200 : auto params = weightsBuffer->getBuffers()[ParamsIdx].buffer;
339 :
// Disabled debug code: dumps the labels, activation, activation-mul and
// loss value buffers to files in the current directory.
340 : if (success) {
341 : /*q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
342 : auto name = toString(front->getName(),".", front->getInputIndex(), ".label.bin");
343 : xenolith::shadernn::Model::saveBlob(
344 : filesystem::currentDir<Interface>(name).data(),
345 : view.data(), view.size());
346 : std::cout << "Save " << name << "\n";
347 : }, weightsBuffer->getBuffers()[InputLabelsIdx].buffer.get());
348 :
349 : q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
350 : auto name = toString(front->getName(),".", front->getInputIndex(), ".activation.bin");
351 : xenolith::shadernn::Model::saveBlob(
352 : filesystem::currentDir<Interface>(name).data(),
353 : view.data(), view.size());
354 : std::cout << "Save " << name << "\n";
355 : }, weightsBuffer->getBuffers()[ActivationIdx].buffer.get());
356 :
357 : q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
358 : auto name = toString(front->getName(),".", front->getInputIndex(), ".activation.mul.bin");
359 : xenolith::shadernn::Model::saveBlob(
360 : filesystem::currentDir<Interface>(name).data(),
361 : view.data(), view.size());
362 : std::cout << "Save " << name << "\n";
363 : }, weightsBuffer->getBuffers()[ActivationEltwiseMulIdx].buffer.get());
364 :
365 : q.getFrame()->getLoop()->captureBuffer([front] (const BufferInfo &info, BytesView view) {
366 : auto name = toString(front->getName(),".", front->getInputIndex(), ".value.bin");
367 : xenolith::shadernn::Model::saveBlob(
368 : filesystem::currentDir<Interface>(name).data(),
369 : view.data(), view.size());
370 : std::cout << "Save " << name << "\n";
371 : }, weightsBuffer->getBuffers()[LossValueIdx].buffer.get());
372 : */
373 : }
374 1200 : });
375 :
// Remember the attachments; the callbacks above read them through the
// layer's getters at frame time.
376 1 : _inputLabelsAttachment = inputLabels;
377 1 : _inputNetworkAttachment = inputNetwork;
378 1 : _weightAttachment = weightsAttachment;
379 1 : _outputAttachment = output;
380 :
// Every frame gets a plain vk::QueuePassHandle for this pass.
381 1200 : _frameHandleCallback = [] (core::QueuePass &pass, const FrameQueue &q) {
382 1200 : return Rc<vk::QueuePassHandle>::create(pass, q);
383 1 : };
384 :
// Must run after _inputNetworkAttachment and _weightAttachment are set,
// because initPropagation() reads both.
385 1 : if (_front->getModel()->isTrainable()) {
386 1 : initPropagation(queueBuilder, builder);
387 : }
388 :
389 2 : return QueuePass::init(builder);
390 : }
391 :
392 1 : void CrossEntropyLossLayer::initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &builder) {
393 1 : const core::QueuePassData *pass = _inputNetworkAttachment->passes.front()->pass;
394 1 : if (auto trainableLayer = dynamic_cast<TrainableLayer *>(pass->pass.get())) {
395 1 : trainableLayer->initPropagation(queueBuilder, builder, _weightAttachment, ActivationIdx);
396 : }
397 1 : }
398 :
399 1200 : void CrossEntropyLossLayer::runAll(CommandBuffer &buf, SpanView<BufferView> buffers) {
400 12000 : for (auto &it : _pipelineOps) {
401 10800 : auto pipeline = static_cast<ComputePipeline *>(it.second.pipeline->pipeline.get());
402 :
403 10800 : it.second.command(_front, buf, pipeline, buffers);
404 : }
405 1200 : }
406 :
407 : }
|