Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnLayer.h"
24 : #include "XLSnnModelProcessor.h"
25 :
26 : namespace stappler::xenolith::shadernn {
27 :
28 6 : bool Layer::init(Model *model, StringView tag, size_t idx, const Value &data) {
29 6 : _model = model;
30 6 : _tag = tag.str<Interface>();
31 6 : _name = data.getString("name");
32 6 : _inputIndex = idx;
33 6 : _numInputPlanes = uint32_t(data.getInteger("inputPlanes"));
34 6 : _numOutputPlanes = uint32_t(data.getInteger("outputPlanes"));
35 :
36 6 : auto numInputs = data.getInteger("numInputs");
37 6 : _inputs.reserve(numInputs);
38 :
39 6 : uint32_t i = 0;
40 11 : for (auto &it : data.getArray("inputId")) {
41 5 : _inputs.emplace_back(LayerInputInfo({i, static_cast<uint32_t>(it.getInteger())}));
42 5 : ++ i;
43 : }
44 :
45 6 : i = 0;
46 11 : for (auto &it : data.getArray("inbounds")) {
47 5 : if (i < _inputs.size()) {
48 5 : _inputs[i].name = it.getString();
49 : }
50 5 : ++ i;
51 : }
52 :
53 6 : return true;
54 : }
55 :
56 5 : void Layer::setInputExtent(uint32_t index, Attachment *a, Extent3 e) {
57 5 : if (index < _inputs.size()) {
58 5 : _inputs[index].extent = e;
59 5 : _inputs[index].attachment = a;
60 : }
61 5 : }
62 :
63 5 : bool Layer::isInputDefined() const {
64 10 : for (auto &it : _inputs) {
65 6 : if (it.extent == Extent3()) {
66 1 : return false;
67 : }
68 : }
69 :
70 4 : return true;
71 : }
72 :
73 18008 : Extent3 Layer::getOutputExtent() const {
74 18008 : Extent3 ret;
75 18008 : LayerTransformInfo accumulatedTransform( { 0, { { 0.0f, 0.0f, 0.0f, 0.0f } } });
76 18008 : auto t = getOutputTransform();
77 36018 : for (auto &dim : _inputs) {
78 18010 : if (!t.isFixed) {
79 18010 : accumulatedTransform.scaleWidth = std::max(accumulatedTransform.scaleWidth, t.scaleWidth * dim.extent.width);
80 18010 : accumulatedTransform.translateWidth = std::max(accumulatedTransform.translateWidth, t.translateWidth);
81 18010 : accumulatedTransform.scaleHeight = std::max(accumulatedTransform.scaleHeight, t.scaleHeight * dim.extent.height);
82 18010 : accumulatedTransform.translateHeight = std::max(accumulatedTransform.translateHeight, t.translateHeight);
83 18010 : ret.width = accumulatedTransform.scaleWidth + accumulatedTransform.translateWidth;
84 18010 : ret.height = accumulatedTransform.scaleHeight + accumulatedTransform.translateHeight;
85 18010 : ret.depth = std::max(ret.depth, dim.extent.depth);
86 : } else {
87 0 : accumulatedTransform.fixedWidth = std::max(accumulatedTransform.fixedWidth, t.fixedWidth);
88 0 : accumulatedTransform.fixedHeight = std::max(accumulatedTransform.fixedHeight, t.fixedHeight);
89 0 : accumulatedTransform.fixedDepth = std::max(accumulatedTransform.fixedDepth, t.fixedDepth);
90 0 : accumulatedTransform.fixedBatch = std::max(accumulatedTransform.fixedBatch, t.fixedBatch);
91 0 : ret.width = accumulatedTransform.fixedWidth + accumulatedTransform.translateWidth;
92 0 : ret.height = accumulatedTransform.fixedHeight + accumulatedTransform.translateHeight;
93 0 : ret.depth = std::max(ret.depth, dim.extent.depth);
94 : }
95 : }
96 18008 : ret.depth = (_numOutputPlanes + 3) / 4;
97 18008 : return ret;
98 : }
99 :
100 4800 : Extent3 Layer::getOutputExtent(const ModelSpecialization &spec) const {
101 4800 : Extent3 ret;
102 4800 : LayerTransformInfo accumulatedTransform( { 0, { { 0.0f, 0.0f, 0.0f, 0.0f } } });
103 4800 : auto t = getOutputTransform();
104 10800 : for (auto &d : _inputs) {
105 6000 : auto iit = spec.attachments.find(d.attachment);
106 6000 : if (iit != spec.attachments.end()) {
107 6000 : auto extent = iit->second;
108 6000 : if (!t.isFixed) {
109 6000 : accumulatedTransform.scaleWidth = std::max(accumulatedTransform.scaleWidth, t.scaleWidth * extent.width);
110 6000 : accumulatedTransform.translateWidth = std::max(accumulatedTransform.translateWidth, t.translateWidth);
111 6000 : accumulatedTransform.scaleHeight = std::max(accumulatedTransform.scaleHeight, t.scaleHeight * extent.height);
112 6000 : accumulatedTransform.translateHeight = std::max(accumulatedTransform.translateHeight, t.translateHeight);
113 6000 : ret.width = accumulatedTransform.scaleWidth + accumulatedTransform.translateWidth;
114 6000 : ret.height = accumulatedTransform.scaleHeight + accumulatedTransform.translateHeight;
115 6000 : ret.depth = std::max(ret.depth, extent.depth);
116 : } else {
117 0 : accumulatedTransform.fixedWidth = std::max(accumulatedTransform.fixedWidth, t.fixedWidth);
118 0 : accumulatedTransform.fixedHeight = std::max(accumulatedTransform.fixedHeight, t.fixedHeight);
119 0 : accumulatedTransform.fixedDepth = std::max(accumulatedTransform.fixedDepth, t.fixedDepth);
120 0 : accumulatedTransform.fixedBatch = std::max(accumulatedTransform.fixedBatch, t.fixedBatch);
121 0 : ret.width = accumulatedTransform.fixedWidth + accumulatedTransform.translateWidth;
122 0 : ret.height = accumulatedTransform.fixedHeight + accumulatedTransform.translateHeight;
123 0 : ret.depth = std::max(ret.depth, extent.depth);
124 : }
125 : } else {
126 0 : log::error("snn::Layer", "Extent is not defined for layer : ", _name);
127 : }
128 : }
129 4800 : ret.depth = (_numOutputPlanes + 3) / 4;
130 4800 : return ret;
131 : }
132 :
133 : }
|