Line data Source code
1 : /** 2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev> 3 : 4 : Permission is hereby granted, free of charge, to any person obtaining a copy 5 : of this software and associated documentation files (the "Software"), to deal 6 : in the Software without restriction, including without limitation the rights 7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 : copies of the Software, and to permit persons to whom the Software is 9 : furnished to do so, subject to the following conditions: 10 : 11 : The above copyright notice and this permission notice shall be included in 12 : all copies or substantial portions of the Software. 13 : 14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 : THE SOFTWARE. 21 : **/ 22 : 23 : #ifndef SRC_LAYERS_XLSNNMATRIXMULLAYER_H_ 24 : #define SRC_LAYERS_XLSNNMATRIXMULLAYER_H_ 25 : 26 : #include "XLSnnLayer.h" 27 : 28 : namespace stappler::xenolith::shadernn { 29 : 30 : class MatrixMulLayer : public Layer { 31 : public: 32 0 : virtual ~MatrixMulLayer() = default; 33 : 34 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override; 35 : 36 0 : virtual bool isTrainable() const { return _model->isTrainable(); } 37 : 38 19214 : virtual Extent3 getOutputExtent() const override { 39 19214 : return Extent3(_kernelSize, 1, _batchSize); 40 : } 41 : 42 : virtual void setInputExtent(uint32_t index, Attachment *, Extent3 e) override; 43 : 44 18006 : Extent3 getWeightSize() const { 45 18006 : auto ret = Layer::getOutputExtent(); 46 18006 : return Extent3(ret.width, _kernelSize, 1); 47 : } 48 : 49 7206 : size_t getWeightBufferSize() const { 50 7206 : auto extent = getWeightSize(); 51 7206 : return extent.width * extent.height * extent.depth * sizeof(float); 52 : } 53 : 54 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder, 55 : Map<Layer *, const core::AttachmentData *> inputs, 56 : Map<Attachment *, const core::AttachmentData *> attachments) override; 57 : 58 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override; 59 : 60 7203 : Activation getActivation() const { return _activation; } 61 13206 : uint32_t getKernelSize() const { return _kernelSize; } 62 : 63 3600 : Layer *getInput() const { return _input; } 64 : 65 : void generateWeights(uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) const; 66 : void generateFreeTerms(uint8_t *buf, uint64_t size, const core::BufferData::DataCallback &cb) const; 67 : 68 : protected: 69 : Activation _activation = Activation::RELU; 70 : uint32_t _kernelSize = 2; 71 : uint32_t _batchSize = 128; 72 : Layer *_input = nullptr; 73 : }; 74 : 75 : } 76 : 77 : #endif /* SRC_LAYERS_XLSNNMATRIXMULLAYER_H_ */