Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnMatrixMulLayer.h"
24 : #include "XLSnnVkMatrixMulLayer.h"
25 :
26 : namespace stappler::xenolith::shadernn {
27 :
28 3 : bool MatrixMulLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
29 3 : if (!Layer::init(m, tag, idx, data)) {
30 0 : return false;
31 : }
32 :
33 3 : _activation = getActivationValue(data.getString("activation"));
34 3 : _kernelSize = data.getInteger("kernel_size");
35 :
36 3 : return true;
37 : }
38 :
39 3 : void MatrixMulLayer::setInputExtent(uint32_t index, Attachment *a, Extent3 e) {
40 3 : Layer::setInputExtent(index, a, e);
41 3 : _input = a->getOutputBy();
42 3 : _batchSize = e.depth;
43 3 : }
44 :
45 3 : const core::QueuePassData *MatrixMulLayer::prepare(core::Queue::Builder &builder,
46 : Map<Layer *, const core::AttachmentData *> inputs,
47 : Map<Attachment *, const core::AttachmentData *> attachments) {
48 3 : auto inputIt = attachments.find(_inputs.front().attachment);
49 3 : auto outputIt = attachments.find(getOutput());
50 :
51 3 : if (inputIt == attachments.end() || outputIt == attachments.end()) {
52 0 : log::error("snn::InputLayer", "No attachments specified");
53 0 : return nullptr;
54 : }
55 :
56 6 : return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
57 3 : [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
58 6 : return Rc<vk::shadernn::MatrixMulLayer>::create(builder, passBuilder, this, inputIt->second, outputIt->second);
59 3 : });
60 : }
61 :
62 3 : const core::AttachmentData *MatrixMulLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
63 6 : return builder.addAttachemnt(toString(getName(), "_output"),
64 3 : [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
65 3 : if (isGlobalOutput) {
66 0 : attachmentBuilder.defineAsOutput();
67 : }
68 3 : auto ext = getOutputExtent();
69 6 : return Rc<vk::BufferAttachment>::create(attachmentBuilder,
70 3 : core::BufferInfo(size_t(ext.width * ext.height * ext.depth * sizeof(float)),
71 6 : core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
72 6 : );
73 6 : });
74 : }
75 :
76 3 : void MatrixMulLayer::generateWeights(uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) const {
77 3 : auto &rand = _model->getRand();
78 3 : auto ext = _input->getOutputExtent();
79 3 : auto inputCount = ext.width * ext.height / 2;
80 :
81 3 : size /= sizeof(float);
82 3 : auto target = (float *)buf;
83 :
84 3 : double deviation = std::sqrt(1. / std::max(inputCount, uint32_t(1)));
85 :
86 1332227 : while (size) {
87 1332224 : *(target ++) = rand.normal(0, deviation);
88 1332224 : -- size;
89 : }
90 3 : }
91 :
92 3 : void MatrixMulLayer::generateFreeTerms(uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) const {
93 3 : size /= sizeof(float);
94 3 : auto target = (float *)buf;
95 :
96 1549 : while (size) {
97 1546 : *(target ++) = 0.0f;
98 1546 : -- size;
99 : }
100 3 : }
101 :
102 :
103 : }
|