Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnLossLayer.h"
24 : #include "XLSnnAttachment.h"
25 : #include "XLSnnVkLossLayer.h"
26 :
27 : namespace stappler::xenolith::shadernn {
28 :
// Magnitude bound for loss gradients; the loss layers below seed their clamp range as
// [-MaxGradient, MaxGradient].
static constexpr float MaxGradient = 1.0e6f;
30 :
31 6 : void LossLayer::setParameter(ParameterIndex idx, float val) {
32 6 : if (idx < P_Count) {
33 6 : _params[idx] = val;
34 : }
35 6 : }
36 :
37 2 : float LossLayer::getParameter(ParameterIndex idx) const {
38 2 : if (idx < P_Count) {
39 2 : return _params[idx];
40 : }
41 0 : return 0.0f;
42 : }
43 :
44 0 : void LossLayer::synchronizeParameters(SpanView<float> data) {
45 0 : memcpy(_params.data(), data.data(), _params.size() * sizeof(float));
46 0 : }
47 :
48 1 : bool CrossEntropyLossLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
49 1 : if (!LossLayer::init(m, tag, idx, data)) {
50 0 : return false;
51 : }
52 :
53 1 : _labelsInputName = data.getString("labels");
54 1 : _batchSize = data.getInteger("batch_size");
55 1 : _classesCount = data.getInteger("classes_count");
56 :
57 1 : setParameter(P_LossWeight, 1.0f);
58 1 : setParameter(P_Loss, 0.0f);
59 1 : setParameter(P_LossDivider, 1.f / float(_batchSize));
60 1 : setParameter(P_LossGradientDivider, getParameter(P_LossDivider) * getParameter(P_LossWeight));
61 1 : setParameter(P_MinGradient, -MaxGradient);
62 1 : setParameter(P_MaxGradient, MaxGradient);
63 :
64 1 : return true;
65 : }
66 :
67 2 : void CrossEntropyLossLayer::setInputExtent(uint32_t index, Attachment *a, Extent3 e) {
68 2 : Layer::setInputExtent(index, a, e);
69 2 : if (a->getOutputBy()->getName() == _labelsInputName) {
70 1 : _inputLabels = a->getOutputBy();
71 : } else {
72 1 : _inputNetwork = a->getOutputBy();
73 : }
74 2 : }
75 :
76 1 : const core::QueuePassData *CrossEntropyLossLayer::prepare(core::Queue::Builder &builder,
77 : Map<Layer *, const core::AttachmentData *> inputs,
78 : Map<Attachment *, const core::AttachmentData *> attachments) {
79 1 : auto inputLabelsIt = attachments.find(_inputs[0].attachment);
80 1 : auto inputNetworkIt = attachments.find(_inputs[1].attachment);
81 1 : auto outputIt = attachments.find(getOutput());
82 :
83 1 : if (inputLabelsIt == attachments.end() || inputNetworkIt == attachments.end() || outputIt == attachments.end()) {
84 0 : log::error("snn::InputLayer", "No attachments specified");
85 0 : return nullptr;
86 : }
87 :
88 2 : return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
89 1 : [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
90 2 : return Rc<vk::shadernn::CrossEntropyLossLayer>::create(builder, passBuilder, this, inputLabelsIt->second, inputNetworkIt->second, outputIt->second);
91 1 : });
92 : }
93 :
94 1 : const core::AttachmentData *CrossEntropyLossLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
95 2 : return builder.addAttachemnt(toString(getName(), "_output"),
96 1 : [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
97 1 : if (isGlobalOutput) {
98 1 : attachmentBuilder.defineAsOutput(core::FrameRenderPassState::Complete);
99 : }
100 2 : return Rc<vk::BufferAttachment>::create(attachmentBuilder,
101 1 : core::BufferInfo(size_t(1 * 4 * sizeof(float)),
102 2 : core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
103 2 : );
104 2 : });
105 : }
106 :
107 : }
|