Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #ifndef SRC_PROCESSOR_XLSNNMODEL_H_
24 : #define SRC_PROCESSOR_XLSNNMODEL_H_
25 :
26 : #include "XLSnnRandom.h"
27 :
28 : namespace stappler::xenolith::shadernn {
29 :
/// Rounds `x` up to the nearest multiple of `y`.
///
/// Preconditions: `y > 0`. Intended for non-negative integral values; the
/// intermediate `x + y - 1` can overflow when `x` is within `y - 1` of the
/// maximum value representable by `T`.
///
/// Declared `constexpr` (which implies `inline` for functions) instead of
/// `static`: a `static` template in a header gets internal linkage, so every
/// translation unit including this header would carry its own private copy.
template <typename T>
constexpr auto ROUND_UP(T x, T y) -> T {
	return (x + y - 1) / y * y;
}
32 :
/// Integer division of `x` by `y`, rounded up (ceil(x / y) for non-negative
/// integral inputs).
///
/// Preconditions: `y > 0`. The intermediate `x + y - 1` can overflow when `x`
/// is within `y - 1` of the maximum value representable by `T`.
///
/// `constexpr` already makes this implicitly `inline`; the former `static`
/// only gave each including translation unit a separate internal copy.
template <typename T>
constexpr auto UP_DIV(T x, T y) -> T {
	return (x + y - 1) / y;
}
35 :
/// Activation function identifiers; `getActivationValue()` maps a name to one
/// of these enumerators. Values are spelled out explicitly and are identical
/// to the implicit sequential numbering starting at 0.
enum class Activation : uint32_t {
	None = 0,
	RELU = 1,
	RELU6 = 2,
	TANH = 3,
	SIGMOID = 4,
	LEAKYRELU = 5,
	SILU = 6,
};
45 :
/// Bit flags passed to `Model::init()`; each named flag occupies one bit.
/// NOTE(review): `Range01` presumably selects a [0, 1] value range — confirm.
enum class ModelFlags : int {
	None = 0x0,
	HalfPrecision = 0x1, // bit 0
	Range01 = 0x2,       // bit 1
	Trainable = 0x4,     // bit 2
};
52 :
// Project macro that lets ModelFlags be used as a bit mask; Model's flag
// checks rely on `operator&` and comparison against ModelFlags::None.
// NOTE(review): the exact set of operators it defines lives in the macro — confirm.
SP_DEFINE_ENUM_AS_MASK(ModelFlags)
54 :
55 : class Layer;
56 : class Attachment;
57 :
58 : class Model : public Ref {
59 : public:
60 : static void saveBlob(const char *, const uint8_t *, size_t);
61 : static void loadBlob(const char *, const std::function<void(const uint8_t *, size_t)> &);
62 : static bool compareBlob(const uint8_t *, size_t, const uint8_t *, size_t, float = std::numeric_limits<float>::epsilon());
63 : static bool compareBlob(const float *, size_t, const float *, size_t, float = std::numeric_limits<float>::epsilon());
64 :
65 :
66 : virtual ~Model();
67 :
68 : virtual bool init(ModelFlags, const Value &val, uint32_t numLayers, StringView dataFilePath);
69 :
70 : virtual void addLayer(Rc<Layer> &&);
71 :
72 : virtual bool link();
73 :
74 0 : bool isHalfPrecision() const { return (_flags & ModelFlags::HalfPrecision) != ModelFlags::None; }
75 2 : bool isTrainable() const { return (_flags & ModelFlags::Trainable) != ModelFlags::None; }
76 0 : bool usesDataFile() const { return _dataFile ? true : false; }
77 :
78 : float readFloatData();
79 :
80 1214 : const Vector<Layer *> &getSortedLayers() const { return _sortedLayers; }
81 :
82 : Vector<Layer *> getInputs() const;
83 :
84 : float getLastLoss() const;
85 :
86 5 : Random &getRand() { return _rand; }
87 :
88 : protected:
89 : void linkInput(Vector<Layer *> &, Layer *, Attachment *);
90 :
91 : ModelFlags _flags;
92 : uint32_t _numLayers;
93 :
94 : int32_t _inputWidth = -1;
95 : int32_t _inputHeight = -1;
96 : int32_t _inputChannels = -1;
97 : int32_t _upscale = -1;
98 : bool _useSubPixel = false;
99 :
100 : filesystem::File _dataFile;
101 : Map<uint32_t, Rc<Layer>> _layers;
102 : Vector<Layer *> _sortedLayers;
103 : Vector<Rc<Attachment>> _attachments;
104 :
105 : Random _rand = Random( 451 );
106 : };
107 :
108 : Activation getActivationValue(StringView);
109 :
110 : float convertToMediumPrecision(float in);
111 : float convertToHighPrecision(uint16_t in);
112 : void convertToMediumPrecision(Vector<float> &in);
113 : void convertToMediumPrecision(Vector<double> &in);
114 : void getByteRepresentation(float in, Vector<unsigned char> &byteRep, bool fp16);
115 :
116 : }
117 :
118 : #endif /* SRC_PROCESSOR_XLSNNMODEL_H_ */
|