Line data Source code
/**
Copyright (c) 2023 Stappler LLC <admin@stappler.dev>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
21 : **/ 22 : 23 : #ifndef SRC_BACKEND_VK_XLSNNVKMATRIXMULLAYER_H_ 24 : #define SRC_BACKEND_VK_XLSNNVKMATRIXMULLAYER_H_ 25 : 26 : #include "XLVkRenderPass.h" 27 : #include "XLSnnVkTrainableLayer.h" 28 : #include "XLSnnMatrixMulLayer.h" 29 : 30 : namespace stappler::xenolith::vk::shadernn { 31 : 32 : class MatrixMulLayer : public TrainableLayer { 33 : public: 34 : using Front = xenolith::shadernn::MatrixMulLayer; 35 : 36 : virtual ~MatrixMulLayer(); 37 : 38 : virtual bool init(Queue::Builder &queueBuilder, QueuePassBuilder &, Front *, 39 : const AttachmentData *input, const AttachmentData *output); 40 : 41 : virtual void initPropagationSubpass(core::Queue::Builder &builder, core::QueuePassBuilder &, core::SubpassBuilder &subpass, 42 : const core::PipelineLayoutData *layout) override; 43 : 44 : virtual Vector<const core::BufferData *> getTrainableGradients(Queue::Builder &queueBuilder) const; 45 : 46 25200 : const Front *getFront() const { return _front; } 47 : 48 42 : virtual uint32_t getPropagationSubpassBufferCount() const override { return 12; } 49 : protected: 50 : using QueuePass::init; 51 : 52 : Rc<Front> _front; 53 : 54 : uint32_t _nbuffers = 4; 55 : uint32_t _weightsBufferIndex = 0; 56 : uint32_t _freeTermBufferIndex = 1; 57 : uint32_t _inputBufferIndex = 2; 58 : uint32_t _outputBufferIndex = 3; 59 : 60 : uint32_t _staticParams = 0; 61 : uint32_t _staticWeightsHistoryIndex = 1; 62 : uint32_t _staticTermsHistoryIndex = 2; 63 : 64 : uint32_t _propWeightsIndex = 0; 65 : uint32_t _propTermsIndex = 0; 66 : uint32_t _propOriginalOutput = 0; 67 : uint32_t _propOriginalInput = 0; 68 : uint32_t _propOutputDiff = 0; 69 : uint32_t _propTargetIndex = 0; 70 : uint32_t _propWeightsDiff = 0; 71 : uint32_t _propTermsDiff = 0; 72 : uint32_t _propFeedback = 0; 73 : }; 74 : 75 : } 76 : 77 : #endif /* SRC_BACKEND_VK_XLSNNVKMATRIXMULLAYER_H_ */