Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnModelTest.h"
24 : #include "XLSnnModel.h"
25 : #include "XLSnnLayer.h"
26 : #include "XLSnnAttachment.h"
27 : #include "XLSnnVkInputLayer.h"
28 : #include "XLVkAttachment.h"
29 : #include "XLCoreFrameRequest.h"
30 : #include "SPFilesystem.h"
31 : #include "SPBitmap.h"
32 : #include "SPValid.h"
33 :
34 : #include <byteswap.h>
35 :
36 : namespace stappler::xenolith::shadernn {
37 :
38 1 : void MnistTrainData::loadVectors(StringView ipath) {
39 1 : auto path = filepath::merge<Interface>(ipath, "train-labels.idx1-ubyte");
40 1 : auto vectors = ::fopen(path.data(), "r");
41 1 : if (vectors) {
42 1 : ::fseek(vectors, 0, SEEK_END);
43 1 : auto fsize = ::ftell(vectors);
44 1 : ::fseek(vectors, 0, SEEK_SET);
45 :
46 1 : ::fread(&vectorsHeader, sizeof(VectorsHeader), 1, vectors);
47 :
48 1 : vectorsHeader.magic = bswap_32(vectorsHeader.magic);
49 1 : vectorsHeader.items = bswap_32(vectorsHeader.items);
50 :
51 1 : auto dataSize = fsize - sizeof(ImagesHeader);
52 :
53 1 : uint8_t *buf = (uint8_t *)::malloc(dataSize);
54 1 : ::fread(buf, dataSize, 1, vectors);
55 :
56 1 : vectorsData = (float *)malloc(dataSize * sizeof(float) * 10);
57 1 : ::memset(vectorsData, 0, dataSize * sizeof(float));
58 1 : auto ptr = vectorsData;
59 59993 : for (size_t i = 0; i < dataSize; ++ i) {
60 659912 : for (size_t j = 0; j < 10; ++ j) {
61 599920 : ptr[j] = 0.0f;
62 : }
63 59992 : ptr[buf[i]] = 1.0f;
64 59992 : ptr += 10;
65 : }
66 :
67 1 : ::free(buf);
68 1 : ::fclose(vectors);
69 : }
70 1 : }
71 :
72 1 : void MnistTrainData::loadImages(StringView ipath) {
73 1 : auto path = filepath::merge<Interface>(ipath, "train-images.idx3-ubyte");
74 1 : auto images = ::fopen(path.data(), "r");
75 1 : if (images) {
76 1 : ::fseek(images, 0, SEEK_END);
77 1 : auto fsize = ::ftell(images);
78 1 : ::fseek(images, 0, SEEK_SET);
79 :
80 1 : ::fread(&imagesHeader, sizeof(ImagesHeader), 1, images);
81 :
82 1 : imagesHeader.magic = bswap_32(imagesHeader.magic);
83 1 : imagesHeader.images = bswap_32(imagesHeader.images);
84 1 : imagesHeader.rows = bswap_32(imagesHeader.rows);
85 1 : imagesHeader.columns = bswap_32(imagesHeader.columns);
86 :
87 1 : auto dataSize = fsize - sizeof(ImagesHeader);
88 :
89 :
90 1 : uint8_t *buf = (uint8_t *)::malloc(dataSize);
91 :
92 1 : ::fread(buf, dataSize, 1, images);
93 :
94 1 : imagesData = (float *)malloc(dataSize * sizeof(float));
95 1 : auto ptr = imagesData;
96 47040001 : for (size_t i = 0; i < dataSize; ++ i) {
97 47040000 : *ptr ++ = float(buf[i]) / 255.0f;
98 : }
99 :
100 1 : ::free(buf);
101 1 : ::fclose(images);
102 : }
103 1 : }
104 :
105 1 : void MnistTrainData::loadIndexes() {
106 1 : indexes.resize(imagesHeader.images);
107 60001 : for (size_t i = 0; i < imagesHeader.images; ++ i) {
108 60000 : indexes[i] = i;
109 : }
110 1 : }
111 :
112 1200 : void MnistTrainData::readImages(uint8_t *iptr, uint64_t size, size_t offset) {
113 1200 : auto count = size / (imagesHeader.rows * imagesHeader.columns * sizeof(float));
114 :
115 1200 : float *ptr = (float *)iptr;
116 :
117 121200 : for (size_t i = 0; i < count; ++ i) {
118 120000 : auto idx = indexes[offset + i];
119 120000 : auto blockSize = imagesHeader.rows * imagesHeader.columns;
120 120000 : ::memcpy(ptr + i * blockSize, imagesData + idx * blockSize, blockSize * sizeof(float));
121 : }
122 1200 : }
123 :
124 1200 : void MnistTrainData::readLabels(uint8_t *iptr, uint64_t size, size_t offset) {
125 1200 : auto count = size / (10 * sizeof(float));
126 :
127 1200 : float *ptr = (float *)iptr;
128 :
129 121200 : for (size_t i = 0; i < count; ++ i) {
130 120000 : auto idx = indexes[offset + i];
131 120000 : auto blockSize = 10;
132 120000 : ::memcpy(ptr + i * blockSize, vectorsData + idx * blockSize, blockSize * sizeof(float));
133 : }
134 1200 : }
135 :
136 0 : bool MnistTrainData::validateImages(const uint8_t *iptr, uint64_t size, size_t offset) {
137 0 : auto count = size / (imagesHeader.rows * imagesHeader.columns * sizeof(float));
138 :
139 0 : const float *ptr = (const float *)iptr;
140 :
141 0 : for (size_t i = 0; i < count; ++ i) {
142 0 : auto idx = indexes[offset + i];
143 0 : auto blockSize = imagesHeader.rows * imagesHeader.columns;
144 :
145 0 : auto sourcePtr = ptr + i * blockSize;
146 0 : auto targetPtr = imagesData + idx * blockSize;
147 :
148 0 : if (::memcmp(sourcePtr, targetPtr, blockSize * sizeof(float)) != 0) {
149 0 : for (size_t j = 0; j < imagesHeader.rows; ++ j) {
150 0 : for (size_t k = 0; k < imagesHeader.columns; ++ k) {
151 0 : std::cout << " " << *(sourcePtr + (j * imagesHeader.rows) + k);
152 : }
153 0 : std::cout << "\n";
154 : }
155 0 : std::cout << "\n";
156 0 : for (size_t j = 0; j < imagesHeader.rows; ++ j) {
157 0 : for (size_t k = 0; k < imagesHeader.columns; ++ k) {
158 0 : std::cout << " " << *(targetPtr + (j * imagesHeader.rows) + k);
159 : }
160 0 : std::cout << "\n";
161 : }
162 0 : std::cout << "\n";
163 0 : return false;
164 : }
165 : }
166 0 : return true;
167 : }
168 :
169 : struct StdRandom {
170 : using result_type = unsigned int;
171 :
172 : static constexpr auto min() { return std::numeric_limits<unsigned int>::min(); };
173 : static constexpr auto max() { return std::numeric_limits<unsigned int>::max(); };
174 :
175 71432 : unsigned int operator() () {
176 71432 : return rnd->next();
177 : }
178 :
179 : Random *rnd;
180 : };
181 :
182 2 : void MnistTrainData::shuffle(Random &rnd) {
183 2 : std::shuffle(indexes.begin(), indexes.end(), StdRandom{&rnd});
184 2 : }
185 :
186 1 : MnistTrainData::MnistTrainData(StringView path) {
187 1 : loadVectors(path);
188 1 : loadImages(path);
189 1 : loadIndexes();
190 1 : }
191 :
192 2 : MnistTrainData::~MnistTrainData() {
193 1 : if (imagesData) {
194 1 : ::free(imagesData);
195 1 : imagesData = nullptr;
196 : }
197 1 : if (vectorsData) {
198 1 : ::free(vectorsData);
199 1 : vectorsData = nullptr;
200 : }
201 2 : }
202 :
203 0 : Rc<CsvData> ModelQueue::readCsv(StringView data) {
204 0 : Rc<CsvData> ret = Rc<CsvData>::alloc();
205 :
206 0 : auto readQuoted = [] (StringView &r) {
207 0 : StringView tmp(r);
208 :
209 0 : while (r.is("\"\"")) {
210 0 : r += 2;
211 : }
212 :
213 0 : while (!r.empty() && !r.is('"')) {
214 0 : r.skipUntil<StringView::Chars<'"', '\\'>>();
215 0 : if (r.is('\\')) {
216 0 : r += 2;
217 : }
218 0 : while (r.is("\"\"")) {
219 0 : r += 2;
220 : }
221 : }
222 :
223 0 : tmp = StringView(tmp.data(), r.data() - tmp.data());
224 :
225 0 : if (r.is('"')) {
226 0 : ++ r;
227 : }
228 :
229 0 : return tmp;
230 : };
231 :
232 0 : auto readHeader = [&] (StringView &r) {
233 0 : while (!r.empty() && !r.is('\n') && !r.is("\r")) {
234 0 : r.skipChars<StringView::WhiteSpace>();
235 0 : StringView tmp(r);
236 0 : if (r.is('"')) {
237 0 : ++ r;
238 0 : ret->fields.emplace_back(readQuoted(r).str<Interface>());
239 0 : r.skipUntil<StringView::Chars<',', '\n', '\r'>>();
240 : } else {
241 0 : auto tmp = r.readUntil<StringView::Chars<',', '\n', '\r'>>();
242 0 : tmp.trimChars<StringView::WhiteSpace>();
243 0 : ret->fields.emplace_back(readQuoted(r).str<Interface>());
244 : }
245 0 : if (r.is(',')) {
246 0 : ++ r;
247 : }
248 : }
249 0 : if (r.is('\n') || r.is('\r')) {
250 0 : r.skipChars<StringView::WhiteSpace>();
251 : }
252 0 : };
253 :
254 0 : auto validateFloat = [] (StringView str) {
255 0 : if (str.empty()) {
256 0 : return false;
257 : }
258 :
259 0 : StringView r(str);
260 0 : if (r.is('-')) { ++ r; }
261 0 : r.skipChars<chars::Range<char, '0', '9'>, StringView::Chars<'.'>>();
262 0 : if (!r.empty()) {
263 0 : return false;
264 : }
265 :
266 0 : return true;
267 : };
268 :
269 0 : auto readLine = [&] (StringView &r) {
270 0 : Value ret;
271 0 : while (!r.empty() && !r.is('\n') && !r.is("\r")) {
272 0 : r.skipChars<StringView::WhiteSpace>();
273 0 : StringView tmp(r);
274 0 : if (r.is('"')) {
275 0 : ++ r;
276 0 : auto data = readQuoted(r);
277 0 : if (valid::validateNumber(data)) {
278 0 : ret.addInteger(data.readInteger(10).get(0));
279 0 : } else if (validateFloat(data)) {
280 0 : ret.addInteger(data.readInteger(10).get(0));
281 : } else {
282 0 : ret.addString(data.str<Interface>());
283 : }
284 0 : r.skipUntil<StringView::Chars<',', '\n', '\r'>>();
285 : } else {
286 0 : auto tmp = r.readUntil<StringView::Chars<',', '\n', '\r'>>();
287 0 : tmp.trimChars<StringView::WhiteSpace>();
288 0 : if (valid::validateNumber(tmp)) {
289 0 : ret.addInteger(tmp.readInteger(10).get(0));
290 0 : } else if (validateFloat(data)) {
291 0 : ret.addInteger(data.readInteger(10).get(0));
292 : } else {
293 0 : ret.addString(tmp.str<Interface>());
294 : }
295 : }
296 0 : if (r.is(',')) {
297 0 : ++ r;
298 : }
299 : }
300 0 : if (r.is('\n') || r.is('\r')) {
301 0 : r.skipChars<StringView::WhiteSpace>();
302 : }
303 0 : return ret;
304 0 : };
305 :
306 0 : while (!data.empty()) {
307 0 : if (ret->fields.empty()) {
308 0 : readHeader(data);
309 : } else {
310 0 : if (auto val = readLine(data)) {
311 0 : if (!val.getValue(0).isInteger()) {
312 0 : std::cout << val << "\n";
313 : }
314 0 : ret->data.emplace_back(move(val));
315 0 : }
316 : }
317 : }
318 :
319 0 : return ret;
320 0 : }
321 :
322 2 : ModelQueue::~ModelQueue() { }
323 :
324 1 : bool ModelQueue::init(StringView modelPath, ModelFlags flags, StringView input) {
325 1 : _processor = Rc<ModelProcessor>::create();
326 1 : _model = _processor->load(FilePath(modelPath), flags);
327 :
328 1 : if (!_model) {
329 0 : return false;
330 : }
331 :
332 1 : core::Queue::Builder builder(filepath::name(modelPath));
333 :
334 1 : Extent3 frameExtent(1, 1, 1);
335 1 : if (input.starts_with("mnist:")) {
336 1 : _trainData = Rc<MnistTrainData>::alloc(input.sub(6));
337 0 : } else if (input.starts_with("csv:")) {
338 0 : auto data = filesystem::readIntoMemory<Interface>(input.sub(4));
339 0 : if (!data.empty()) {
340 0 : _csvData = readCsv(StringView((const char *)data.data(), data.size()));
341 : } else {
342 0 : return false;
343 : }
344 0 : } else {
345 0 : if (!bitmap::getImageSize(input, frameExtent.width, frameExtent.height)) {
346 0 : log::error("InputQueue", "fail to read image: ", input);
347 0 : return false;
348 : }
349 : }
350 :
351 1 : _image = input.str<Interface>();
352 :
353 1 : Map<Layer *, const core::AttachmentData *> inputs;
354 1 : Map<Attachment *, const core::AttachmentData *> attachments;
355 :
356 1 : auto modelInputs = _model->getInputs();
357 3 : for (auto &it : modelInputs) {
358 2 : inputs.emplace(it, it->makeInputAttachment(builder));
359 : }
360 :
361 1 : size_t i = 0;
362 7 : for (auto &it : _model->getSortedLayers()) {
363 6 : auto output = it->getOutput();
364 6 : auto attachment = it->makeOutputAttachment(builder, _model->getSortedLayers().size() == i + 1);
365 6 : attachments.emplace(output, attachment);
366 6 : if (_model->getSortedLayers().size() == i + 1) {
367 1 : _outputAttachment = attachment;
368 : }
369 6 : ++ i;
370 : }
371 :
372 7 : for (auto &it : _model->getSortedLayers()) {
373 6 : auto ret = it->prepare(builder, inputs, attachments);
374 6 : if (!ret) {
375 0 : return false;
376 : }
377 6 : if (it->isInput()) {
378 2 : _inputAttachments.emplace_back(static_cast<vk::shadernn::InputLayer *>(ret->pass.get())->getDataAttachment());
379 : }
380 : }
381 :
382 1 : return Queue::init(move(builder));
383 1 : }
384 :
385 1200 : void ModelQueue::run(Application *app) {
386 1200 : _app = app;
387 :
388 1200 : Extent3 frameExtent(1, 1, 1);
389 1200 : if (!_trainData && !_csvData) {
390 0 : if (!bitmap::getImageSize(StringView(_image), frameExtent.width, frameExtent.height)) {
391 0 : log::error("InputQueue", "fail to read image: ", _image);
392 0 : return;
393 : }
394 1200 : } else if (_csvData) {
395 0 : frameExtent.width = _csvData->fields.size();
396 0 : frameExtent.height = _csvData->data.size();
397 : }
398 :
399 1200 : auto req = Rc<core::FrameRequest>::create(Rc<core::Queue>(this), core::FrameContraints{frameExtent});
400 :
401 1200 : ModelSpecialization spec = _processor->specializeModel(_model, frameExtent);
402 :
403 8400 : for (auto &it : spec.attachments) {
404 7200 : if (auto a = getAttachment(it.first->getName())) {
405 0 : if (auto img = dynamic_cast<vk::ImageAttachment *>(a->attachment.get())) {
406 0 : core::ImageInfoData info = img->getImageInfo();
407 0 : info.extent = it.second;
408 0 : log::debug("ModelQueue", "Specialize attachment ", it.first->getName(), " for extent ", info.extent);
409 0 : req->addImageSpecialization(img, move(info));
410 : }
411 : }
412 : }
413 :
414 3600 : for (auto &it : _inputAttachments) {
415 2400 : if (it->type == core::AttachmentType::Image) {
416 0 : auto inputData = Rc<vk::shadernn::InputDataInput>::alloc();
417 0 : inputData->norm = vk::shadernn::NormData{
418 : Vec4(-0.5f, -0.5f, -0.5f, -0.5f), // to -0.5 - 0.5
419 : Vec4(2.0f, 2.0f, 2.0f, 2.0f) // to -1.0 - 1.0
420 : };
421 0 : inputData->image.extent = frameExtent;
422 0 : inputData->image.stdCallback = [path = _image] (uint8_t *ptr, uint64_t size, const core::ImageData::DataCallback &dcb) {
423 0 : core::Resource::loadImageFileData(ptr, size, path, core::ImageFormat::R8G8B8A8_UNORM, dcb);
424 0 : };
425 :
426 0 : req->addInput(it, move(inputData));
427 2400 : } else if (_trainData) {
428 2400 : if (it->key == "input_samples_buffer") {
429 1200 : auto inputData = Rc<vk::shadernn::InputBufferDataInput>::alloc();
430 4800 : inputData->buffer.stdCallback = [data = _trainData, offset = _loadOffset] (uint8_t *ptr, uint64_t size, const core::BufferData::DataCallback &cb) {
431 1200 : data->readImages(ptr, size, offset);
432 1200 : };
433 1200 : req->addInput(it, move(inputData));
434 2400 : } else if (it->key == "input_labels_buffer") {
435 1200 : auto inputData = Rc<vk::shadernn::InputBufferDataInput>::alloc();
436 4800 : inputData->buffer.stdCallback = [data = _trainData, offset = _loadOffset] (uint8_t *ptr, uint64_t size, const core::BufferData::DataCallback &cb) {
437 1200 : data->readLabels(ptr, size, offset);
438 1200 : };
439 1200 : req->addInput(it, move(inputData));
440 1200 : }
441 0 : } else if (_csvData) {
442 0 : auto inputData = Rc<vk::shadernn::InputCsvInput>::alloc();
443 0 : inputData->data = _csvData->data;
444 0 : req->addInput(it, move(inputData));
445 0 : }
446 : }
447 :
448 1200 : req->setOutput(getOutputAttachment(), [this, app, trainData = _trainData, csvData = _csvData] (core::FrameAttachmentData &data, bool success, Ref *) {
449 1200 : if (data.image) {
450 0 : app->getGlLoop()->captureImage([app] (core::ImageInfoData info, BytesView view) {
451 0 : if (!view.empty()) {
452 0 : Vector<Bitmap> planes;
453 0 : for (size_t i = 0; i < info.extent.depth; ++ i) {
454 0 : auto &b = planes.emplace_back(Bitmap());
455 0 : b.alloc(info.extent.width, info.extent.height, bitmap::PixelFormat::RGBA8888, bitmap::AlphaFormat::Premultiplied);
456 : }
457 :
458 0 : Vector<uint8_t *> ptrs;
459 0 : for (auto &it : planes) {
460 0 : ptrs.emplace_back(it.dataPtr());
461 : }
462 :
463 0 : std::cout << view.size() << " ";
464 :
465 0 : for (size_t i = 0; i < info.extent.width; ++ i) {
466 0 : for (size_t j = 0; j < info.extent.height; ++ j) {
467 0 : for (size_t k = 0; k < info.extent.depth; ++ k) {
468 0 : for (size_t f = 0; f < 4; ++ f) {
469 0 : *ptrs[k] = uint8_t((view.readFloat32() + 1.0f) * 127.5f);
470 : if (f == 3) {
471 : //*ptrs[k] = 255;
472 : }
473 0 : ptrs[k] ++;
474 : }
475 : }
476 : }
477 : }
478 :
479 0 : std::cout << view.size() << "\n";
480 :
481 0 : for (size_t i = 0; i < info.extent.depth; ++ i) {
482 0 : planes[i].save(toString(i, "_", Time::now().toMicros(), ".png"));
483 : }
484 0 : }
485 0 : app->end();
486 0 : }, data.image->getImage(), core::AttachmentLayout::General);
487 1200 : } else if (auto buf = dynamic_cast<vk::BufferAttachmentHandle *>(data.handle.get())) {
488 1200 : if (trainData) {
489 1200 : auto target = buf->getBuffers().back().buffer;
490 1200 : app->getGlLoop()->captureBuffer([this, app] (const core::BufferInfo &info, BytesView view) {
491 : //front->synchronizeParameters(SpanView<float>((float *)view.data(), view.size() / sizeof(float)));
492 1200 : onComplete(app, view);
493 1200 : }, target);
494 :
495 1200 : } else if (csvData) {
496 0 : auto target = buf->getBuffers().front().buffer;
497 0 : app->getGlLoop()->captureBuffer([app, trainData] (core::BufferInfo info, BytesView view) {
498 0 : std::cout << view.size() / (sizeof(float) * 4) << "\n";
499 0 : std::cout << "0";
500 :
501 0 : size_t row = 1;
502 0 : size_t i = 0;
503 0 : while (!view.empty()) {
504 0 : switch (i) {
505 0 : case 0: {
506 0 : auto u = view.readUnsigned32();
507 0 : if (u > uint32_t(maxOf<int32_t>())) {
508 0 : std::cout << ", " << static_cast<int32_t>(u);
509 : } else {
510 0 : std::cout << ", " << u;
511 : }
512 0 : break;
513 : }
514 0 : case 1:
515 0 : std::cout << ", " << view.readUnsigned32();
516 0 : break;
517 0 : case 2:
518 : case 3:
519 0 : std::cout << ", " << view.readFloat32();
520 0 : break;
521 0 : default:
522 0 : view.readUnsigned32();
523 0 : break;
524 : }
525 0 : ++ i;
526 0 : if (i > 3) {
527 0 : i = 0;
528 0 : std::cout << "\n" << row ++;
529 : }
530 : }
531 0 : std::cout << "\n";
532 0 : }, target);
533 0 : }
534 : }
535 1200 : return true;
536 : });
537 :
538 1200 : app->getGlLoop()->runRenderQueue(move(req), 0);
539 1200 : }
540 :
541 1200 : void ModelQueue::onComplete(Application *app, BytesView data) {
542 : //auto iter = _loadOffset / 100;
543 1200 : _loadOffset += 100;
544 :
545 1200 : data.readFloat32();
546 1200 : auto loss = data.readFloat32();
547 :
548 1200 : if (_loadOffset < 60000) {
549 : // 5 std::cout << iter << " Loss: " << loss << "\n";
550 1198 : _epochLoss += loss;
551 1198 : app->performOnMainThread([this, app] {
552 1198 : run(_app);
553 1198 : });
554 : } else {
555 2 : std::cout << _epoch << ": avg loss: " << _epochLoss / 600.0f << "\n";
556 :
557 2 : _trainData->shuffle(_model->getRand());
558 2 : _epochLoss = 0;
559 2 : _loadOffset = 0;
560 2 : ++ _epoch;
561 :
562 2 : if (_epoch < _endEpoch) {
563 1 : app->performOnMainThread([this, app] {
564 1 : run(app);
565 1 : });
566 : } else {
567 1 : release(0);
568 1 : app->end();
569 : }
570 : }
571 1200 : }
572 :
573 : }
|