Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #ifndef SRC_LAYERS_XLSNNINPUTLAYER_H_
24 : #define SRC_LAYERS_XLSNNINPUTLAYER_H_
25 :
26 : #include "XLSnnLayer.h"
27 :
28 : namespace stappler::xenolith::shadernn {
29 :
// InputLayer: a Layer subclass whose output extent is derived from an input
// width, height and channel count. Only the destructor and getOutputExtent()
// are defined in this header; init(), prepare() and the attachment factories
// are declared here and defined elsewhere.
30 : class InputLayer : public Layer {
31 : public:
32 0 : virtual ~InputLayer() = default;
33 :
// Override of Layer::init. The meaning of the tag, index and Value arguments
// is determined by the out-of-line definition.
// NOTE(review): presumably fills _inputWidth/_inputHeight/_inputChannels from
// the Value and returns true on success — confirm against the implementation.
34 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override;
35 :
// Returns (_inputWidth, _inputHeight, ceil(_inputChannels / 4)): the depth
// component is the channel count divided by four, rounded up.
// NOTE(review): this looks like four channels packed per output element
// (e.g. one RGBA texel) — confirm.
36 0 : virtual Extent3 getOutputExtent() const override {
37 0 : return Extent3(_inputWidth, _inputHeight, (_inputChannels + 3) / 4);
38 : }
39 :
// Override of Layer::prepare; returns a pointer to pass data built through
// `builder`. Both maps are taken by value, so each call copies them.
40 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder,
41 : Map<Layer *, const core::AttachmentData *> inputs,
42 : Map<Attachment *, const core::AttachmentData *> attachments) override;
43 :
// Attachment factories for this layer's input and output; defined elsewhere.
44 : virtual const core::AttachmentData *makeInputAttachment(core::Queue::Builder &builder) override;
45 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override;
46 :
47 : protected:
// Input dimensions; every field defaults to 1, which yields a 1x1x1 extent
// until the values are changed.
48 : uint32_t _inputWidth = 1;
49 : uint32_t _inputHeight = 1;
50 : uint32_t _inputChannels = 1;
51 : };
52 :
// InputBufferLayer: a Layer subclass that describes its input as a flat run of
// width * height values per object, and additionally exposes a norm and a mean
// value. init(), prepare() and the attachment factories are declared here and
// defined elsewhere.
53 : class InputBufferLayer : public Layer {
54 : public:
55 4 : virtual ~InputBufferLayer() = default;
56 :
// Override of Layer::init.
// NOTE(review): presumably reads the dimensions and norm/mean from the Value —
// confirm against the implementation.
57 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override;
58 :
// Returns (_inputWidth * _inputHeight, 1, _inputObjects): the two spatial
// dimensions are flattened into the first component.
// NOTE(review): the product is formed from uint32_t operands, so it wraps
// where uint32_t is the platform's unsigned int (assumed) — confirm inputs stay in range.
59 2405 : virtual Extent3 getOutputExtent() const override {
60 2405 : return Extent3(_inputWidth * _inputHeight, 1, _inputObjects);
61 : }
62 :
// Override of Layer::prepare; returns a pointer to pass data built through
// `builder`. Both maps are taken by value, so each call copies them.
63 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder,
64 : Map<Layer *, const core::AttachmentData *> inputs,
65 : Map<Attachment *, const core::AttachmentData *> attachments) override;
66 :
// Attachment factories for this layer's input and output; defined elsewhere.
67 : virtual const core::AttachmentData *makeInputAttachment(core::Queue::Builder &builder) override;
68 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override;
69 :
// Plain accessors for the stored norm and mean values.
// NOTE(review): how the two are combined with the input data is not visible
// in this header — see the consumers of getNorm()/getMean().
70 2400 : float getNorm() const { return _norm; }
71 2400 : float getMean() const { return _mean; }
72 :
// Returns _inputWidth * _inputHeight * _inputObjects as a uint32_t.
// NOTE(review): presumably an element count rather than a byte size — confirm.
// The product wraps where uint32_t is the platform's unsigned int (assumed).
73 7200 : uint32_t getBufferSize() const { return _inputWidth * _inputHeight * _inputObjects; }
74 :
75 : protected:
// Input dimensions; every field defaults to 1.
76 : uint32_t _inputWidth = 1;
77 : uint32_t _inputHeight = 1;
78 : uint32_t _inputObjects = 1;
79 :
// Defaults: norm 1.0 and mean 0.0.
80 : float _norm = 1.0f;
81 : float _mean = 0.0f;
82 : };
83 :
// InputCsvIntLayer: a Layer subclass that holds a list of field identifiers and
// a list of per-entry normalization records, and reports an output extent of
// (field count, object count, 1). init(), prepare() and the attachment
// factories are declared here and defined elsewhere.
// NOTE(review): the name suggests integer columns selected from CSV input —
// the parsing itself is not visible in this header.
84 : class InputCsvIntLayer : public Layer {
85 : public:
// Normalization record: two 64-bit unsigned values, 16 bytes on typical
// targets. getNormDataBuffer() exposes an array of these as raw bytes, so
// the field order and sizes are part of that byte layout.
// NOTE(review): presumably one record per entry of _fields, applied as an
// offset followed by a scale — confirm against the implementation.
86 : struct NormData {
87 : uint64_t offset;
88 : uint64_t norm;
89 : };
90 :
91 0 : virtual ~InputCsvIntLayer() = default;
92 :
// Override of Layer::init.
// NOTE(review): presumably fills _fields, _norm and _inputObjects from the
// Value — confirm against the implementation.
93 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override;
94 :
// Returns (_fields.size(), _inputObjects, 1).
// NOTE(review): _fields.size() is converted to whatever component type
// Extent3 uses; confirm that the conversion cannot truncate.
95 0 : virtual Extent3 getOutputExtent() const override {
96 0 : return Extent3(_fields.size(), _inputObjects, 1);
97 : }
98 :
// Typed view over the stored normalization records.
// NOTE(review): assumed to be non-owning, i.e. valid only while this layer is
// alive and _norm is not modified — confirm SpanView semantics.
99 : SpanView<NormData> getNormData() const { return _norm; }
// The same records as raw bytes: _norm.size() * sizeof(NormData) bytes
// starting at _norm.data(). The same lifetime caveat as getNormData() is assumed.
// NOTE(review): uses a C-style pointer cast; reinterpret_cast would make the
// reinterpretation explicit.
100 0 : BytesView getNormDataBuffer() const { return BytesView((const uint8_t *)_norm.data(), _norm.size() * sizeof(NormData)); }
101 :
// View over the stored field identifiers.
102 0 : SpanView<uint32_t> getFields() const { return _fields; }
103 :
// Override of Layer::prepare; returns a pointer to pass data built through
// `builder`. Both maps are taken by value, so each call copies them.
104 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder,
105 : Map<Layer *, const core::AttachmentData *> inputs,
106 : Map<Attachment *, const core::AttachmentData *> attachments) override;
107 :
// Attachment factories for this layer's input and output; defined elsewhere.
108 : virtual const core::AttachmentData *makeInputAttachment(core::Queue::Builder &builder) override;
109 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override;
110 :
111 : protected:
// Object count (defaults to 1), field identifiers, and normalization records.
112 : uint32_t _inputObjects = 1;
113 : Vector<uint32_t> _fields;
114 : Vector<NormData> _norm;
115 : };
116 :
117 : }
118 :
119 : #endif /* SRC_LAYERS_XLSNNINPUTLAYER_H_ */
|