Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnModel.h"
24 : #include "XLSnnLossLayer.h"
25 : #include "XLSnnAttachment.h"
26 :
27 : namespace stappler::xenolith::shadernn {
28 :
29 0 : void Model::saveBlob(const char *path, const uint8_t *buf, size_t size) {
30 0 : ::unlink(path);
31 :
32 0 : auto f = fopen(path, "w");
33 0 : ::fwrite(&size, sizeof(size), 1, f);
34 0 : ::fwrite(buf, size, 1, f);
35 0 : ::fclose(f);
36 0 : }
37 :
38 0 : void Model::loadBlob(const char *path, const std::function<void(const uint8_t *, size_t)> &cb) {
39 0 : auto f = fopen(path, "r");
40 0 : if (!f) {
41 0 : return;
42 : }
43 :
44 0 : size_t size = 0;
45 0 : ::fread(&size, sizeof(size), 1, f);
46 :
47 0 : auto buf = new uint8_t[size];
48 0 : if (size > 0) {
49 0 : ::fread(buf, size, 1, f);
50 :
51 0 : cb(buf, size);
52 : }
53 :
54 0 : ::fclose(f);
55 0 : delete[] buf;
56 : }
57 :
58 0 : bool Model::compareBlob(const uint8_t *a, size_t na, const uint8_t *b, size_t nb, float v) {
59 0 : return compareBlob((const float *)a, na / sizeof(float), (const float *)b, nb / sizeof(float), v);
60 : }
61 :
62 0 : bool Model::compareBlob(const float *a, size_t na, const float *b, size_t nb, float v) {
63 0 : if (na != nb) {
64 0 : return false;
65 : }
66 :
67 0 : while (na > 0) {
68 0 : if (std::abs(*a - *b) > v) {
69 0 : return false;
70 : }
71 0 : ++ a; ++ b; -- na;
72 : }
73 0 : return true;
74 : }
75 :
76 2 : Model::~Model() { }
77 :
78 1 : bool Model::init(ModelFlags f, const Value &val, uint32_t numLayers, StringView dataFilePath) {
79 1 : _flags = f;
80 1 : _numLayers = numLayers;
81 :
82 1 : if (!dataFilePath.empty()) {
83 0 : _dataFile = filesystem::openForReading(dataFilePath);
84 0 : if (!_dataFile) {
85 0 : log::error("snn::Model", "Fail to open model data file: ", dataFilePath);
86 0 : return false;
87 : }
88 : }
89 :
90 1 : if (val.isBool("trainable")) {
91 1 : if (val.getBool("trainable")) {
92 1 : _flags |= ModelFlags::Trainable;
93 : }
94 : }
95 :
96 1 : auto range = val.getString("inputRange");
97 1 : if (range == "[0,1]") {
98 0 : _flags |= ModelFlags::Range01;
99 : }
100 :
101 1 : if (auto &block = val.getValue("block_0")) {
102 0 : for (auto &it : block.asDict()) {
103 0 : if (it.first == "Input Height") {
104 0 : _inputHeight = it.second.getInteger();
105 0 : } else if (it.first == "Input Width") {
106 0 : _inputWidth = it.second.getInteger();
107 : }
108 : }
109 : }
110 :
111 1 : if (auto &node = val.getValue("node")) {
112 0 : for (auto &it : node.asDict()) {
113 0 : if (it.first == "inputChannels") {
114 0 : _inputChannels = it.second.getInteger();
115 0 : } else if (it.first == "upscale") {
116 0 : _upscale = it.second.getInteger();
117 0 : } else if (it.first == "useSubpixel") {
118 0 : _useSubPixel = it.second.getBool();
119 : }
120 : }
121 : }
122 :
123 1 : return true;
124 1 : }
125 :
126 6 : void Model::addLayer(Rc<Layer> &&l) {
127 6 : _layers.emplace(l->getInputIndex(), move(l));
128 6 : }
129 :
130 1 : bool Model::link() {
131 : // find inputs
132 1 : Vector<Layer *> linkedLayers;
133 7 : for (auto &it : _layers) {
134 6 : if (it.second->isInput()) {
135 2 : auto attachment = Rc<Attachment>::create(_attachments.size(), it.second->getOutputExtent(), it.second);
136 2 : it.second->setOutput(attachment);
137 2 : linkedLayers.emplace_back(it.second.get());
138 2 : linkInput(linkedLayers, it.second, _attachments.emplace_back(move(attachment)));
139 2 : }
140 : }
141 :
142 1 : if (linkedLayers.size() == _layers.size()) {
143 1 : _sortedLayers = move(linkedLayers);
144 1 : return true;
145 : }
146 :
147 0 : log::error("snn::Model", "Fail to link model: potential loop in execution tree");
148 0 : return false;
149 1 : }
150 :
151 0 : float Model::readFloatData() {
152 0 : float value = 0.0f;
153 0 : if (_dataFile.read((uint8_t*) &value, sizeof(float)) == sizeof(float)) {
154 0 : if (isHalfPrecision()) {
155 0 : value = convertToMediumPrecision(value);
156 : }
157 : }
158 0 : return value;
159 : }
160 :
161 0 : float Model::getLastLoss() const {
162 0 : if (auto l = dynamic_cast<LossLayer *>(_sortedLayers.back())) {
163 0 : return l->getParameter(LossLayer::P_Loss);
164 : }
165 0 : return 0.0f;
166 : }
167 :
168 1201 : Vector<Layer *> Model::getInputs() const {
169 1201 : Vector<Layer *> ret;
170 8407 : for (auto &it : _sortedLayers) {
171 7206 : if (it->isInput()) {
172 2402 : ret.emplace_back(it);
173 : }
174 : }
175 1201 : return ret;
176 0 : }
177 :
178 6 : void Model::linkInput(Vector<Layer *> &layers, Layer *inputLayer, Attachment *attachment) {
179 : // find usage
180 42 : for (auto &it : _layers) {
181 66 : for (auto &input : it.second->getInputs()) {
182 30 : if (input.layer == inputLayer->getInputIndex()) {
183 5 : attachment->addInputBy(it.second);
184 5 : it.second->setInputExtent(input.index, attachment, inputLayer->getOutputExtent());
185 5 : if (it.second->isInputDefined()) {
186 4 : if (std::find(layers.begin(), layers.end(), it.second.get()) == layers.end()) {
187 4 : auto attachment = Rc<Attachment>::create(_attachments.size(), it.second->getOutputExtent(), it.second);
188 4 : it.second->setOutput(attachment);
189 4 : layers.emplace_back(it.second.get());
190 4 : std::cout << "Layer " << it.second->getName() << " (" << it.second->getTag() << ") " << it.second->getOutputExtent() << "\n";
191 4 : linkInput(layers, it.second, _attachments.emplace_back(move(attachment)));
192 4 : }
193 : }
194 : }
195 : }
196 : }
197 6 : }
198 :
199 3 : Activation getActivationValue(StringView istr) {
200 3 : auto str = string::toupper<Interface>(istr);
201 :
202 3 : if (str == "RELU") {
203 2 : return Activation::RELU;
204 1 : } else if (str == "RELU6") {
205 0 : return Activation::RELU6;
206 1 : } else if (str == "TANH") {
207 0 : return Activation::TANH;
208 1 : } else if (str == "SIGMOID") {
209 0 : return Activation::SIGMOID;
210 1 : } else if (str == "LEAKYRELU") {
211 0 : return Activation::LEAKYRELU;
212 1 : } else if (str == "SILU") {
213 0 : return Activation::SILU;
214 1 : } else if (str == "LINEAR") {
215 1 : return Activation::None;
216 : }
217 0 : log::error("snn::Model", "Unknown activation: ", istr);
218 0 : return Activation::None;
219 3 : }
220 :
221 0 : float convertToMediumPrecision(float in) {
222 : union tmp {
223 : unsigned int unsint;
224 : float flt;
225 : };
226 :
227 : tmp _16to32, _32to16;
228 :
229 0 : _32to16.flt = in;
230 :
231 0 : unsigned short sign = (_32to16.unsint & 0x80000000) >> 31;
232 0 : unsigned short exponent = (_32to16.unsint & 0x7F800000) >> 23;
233 0 : unsigned int mantissa = _32to16.unsint & 0x7FFFFF;
234 :
235 0 : short newexp = exponent + (-127 + 15);
236 :
237 : unsigned int newMantissa;
238 0 : if (newexp >= 31) {
239 0 : newexp = 31;
240 0 : newMantissa = 0x00;
241 0 : } else if (newexp <= 0) {
242 0 : newexp = 0;
243 0 : newMantissa = 0;
244 : } else {
245 0 : newMantissa = mantissa >> 13;
246 : }
247 :
248 0 : if (newexp == 0) {
249 0 : if (newMantissa == 0) {
250 0 : _16to32.unsint = sign << 31;
251 : } else {
252 0 : newexp = 0;
253 0 : while ((newMantissa & 0x200) == 0) {
254 0 : newMantissa <<= 1;
255 0 : newexp++;
256 : }
257 0 : newMantissa <<= 1;
258 0 : newMantissa &= 0x3FF;
259 0 : _16to32.unsint = (sign << 31) | ((newexp + (-15 + 127)) << 23) | (newMantissa << 13);
260 : }
261 0 : } else if (newexp == 31) {
262 0 : _16to32.unsint = (sign << 31) | (0xFF << 23) | (newMantissa << 13);
263 : } else {
264 0 : _16to32.unsint = (sign << 31) | ((newexp + (-15 + 127)) << 23) | (newMantissa << 13);
265 : }
266 :
267 0 : return _16to32.flt;
268 : }
269 :
270 0 : float convertToHighPrecision(uint16_t in) {
271 : union tmp {
272 : unsigned int unsint;
273 : float flt;
274 : };
275 :
276 0 : unsigned short sign = (in & 0x8000) >> 15;
277 0 : unsigned short exponent = (in & 0x7C00) >> 10;
278 0 : unsigned int mantissa = (in & 0x3FF);
279 :
280 : tmp _16to32;
281 :
282 0 : short newexp = exponent + 127 - 15;
283 0 : unsigned int newMantissa = mantissa;
284 0 : if (newexp <= 0) {
285 0 : newexp = 0;
286 0 : newMantissa = 0;
287 : }
288 :
289 0 : _16to32.unsint = (sign << 31) | (newexp << 23) | (newMantissa << 13);
290 0 : return _16to32.flt;
291 : }
292 :
293 0 : void convertToMediumPrecision(Vector<float> &in) {
294 0 : for (auto &val : in) {
295 0 : val = convertToMediumPrecision(val);
296 : }
297 0 : }
298 :
299 0 : void convertToMediumPrecision(Vector<double> &in) {
300 0 : for (auto &val : in) {
301 0 : val = convertToMediumPrecision(val);
302 : }
303 0 : }
304 :
305 0 : void getByteRepresentation(float in, Vector<unsigned char> &byteRep, bool fp16) {
306 0 : if (fp16) {
307 : union tmp {
308 : unsigned int unsint;
309 : float flt;
310 : };
311 : tmp _32to16;
312 0 : _32to16.flt = in;
313 : unsigned short fp16Val;
314 :
315 0 : unsigned short sign = (_32to16.unsint & 0x80000000) >> 31;
316 0 : unsigned short exponent = (_32to16.unsint & 0x7F800000) >> 23;
317 0 : unsigned int mantissa = (_32to16.unsint & 0x7FFFFF);
318 :
319 0 : short newexp = exponent + (-127 + 15);
320 :
321 : unsigned int newMantissa;
322 0 : if (newexp >= 31) {
323 0 : newexp = 31;
324 0 : newMantissa = 0x00;
325 0 : } else if (newexp <= 0) {
326 0 : newexp = 0;
327 0 : newMantissa = 0;
328 : } else {
329 0 : newMantissa = mantissa >> 13;
330 : }
331 :
332 0 : fp16Val = sign << 15;
333 0 : fp16Val = fp16Val | (newexp << 10);
334 0 : fp16Val = fp16Val | (newMantissa);
335 :
336 0 : byteRep.push_back(fp16Val & 0xFF);
337 0 : byteRep.push_back(fp16Val >> 8 & 0xFF);
338 : } else {
339 : uint32_t fp32Val;
340 0 : ::memcpy(&fp32Val, &in, 4);
341 0 : for (int i = 0; i < 4; i++) {
342 0 : byteRep.push_back((fp32Val >> 8 * i) & 0xFF);
343 : }
344 : }
345 0 : }
346 :
347 : }
|