Line data Source code
/**
Copyright (c) 2023 Stappler LLC <admin@stappler.dev>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
21 : **/ 22 : 23 : #ifndef SRC_BACKEND_VK_XLSNNVKTRAINABLELAYER_H_ 24 : #define SRC_BACKEND_VK_XLSNNVKTRAINABLELAYER_H_ 25 : 26 : #include "XLVkQueuePass.h" 27 : #include "XLVkAttachment.h" 28 : 29 : namespace stappler::xenolith::vk::shadernn { 30 : 31 : class TrainableLayer : public vk::QueuePass { 32 : public: 33 : enum VariableIndex { 34 : TV_MomentDecayRateVar = 0, 35 : TV_OpMomentDecayRateVar, 36 : TV_OpRegL2MomentDecayRateVar, 37 : TV_RateVar, 38 : TV_L1Threshold, 39 : TV_L1Mult, 40 : TV_Count 41 : }; 42 : 43 : virtual ~TrainableLayer(); 44 : 45 : virtual void initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &, 46 : const core::AttachmentData *source, uint32_t bufferIndex); 47 : 48 : virtual void initPropagationSubpass(Queue::Builder &queueBuilder, core::QueuePassBuilder &, 49 : core::SubpassBuilder &, const core::PipelineLayoutData *); 50 : 51 0 : virtual uint32_t getPropagationSubpassBufferCount() const { return _fullPropagationBuffers; } 52 : 53 : virtual Vector<const core::BufferData *> getTrainableGradients(Queue::Builder &queueBuilder) const; 54 : 55 : virtual bool isBackwardNeeded() const; 56 : 57 7200 : const AttachmentData *getInputAttachment() const { return _inputAttachment; } 58 10800 : const AttachmentData *getOutputAttachment() const { return _outputAttachment; } 59 10800 : const AttachmentData *getWeightsAttachment() const { return _weightsAttachment; } 60 7200 : const AttachmentData *getPropagationAttachment() const { return _propagationAttachment; } 61 : uint32_t getTargetPropagationBufferIdx() const { return _targetPropagationIdx; } 62 : 63 3600 : const AttachmentData *getExternalPropagationDataSource() const { return _externalPropagationDataSource; } 64 3600 : uint32_t getExternalPropagationBufferIdx() const { return _externalPropagationBufferIdx; } 65 : 66 : protected: 67 : using QueuePass::init; 68 : 69 : const AttachmentData *_inputAttachment = nullptr; 70 : const AttachmentData *_outputAttachment = nullptr; 71 : const 
AttachmentData *_weightsAttachment = nullptr; 72 : 73 : const AttachmentData *_propagationAttachment = nullptr; 74 : uint32_t _targetPropagationIdx = 2; 75 : 76 : const AttachmentData *_externalPropagationDataSource = nullptr; 77 : uint32_t _externalPropagationBufferIdx = 0; 78 : 79 : uint32_t _staticPropagationBuffers = 1; 80 : uint32_t _fullPropagationBuffers = 2; 81 : 82 : float _momentDecayRate = 0.9f; 83 : float _learningRate = 0.01f; 84 : float _regularizationL2 = 0.f; 85 : float _regularizationL1 = 0.f; 86 : float _maxGradientNorm = -1.f; 87 : }; 88 : 89 : } 90 : 91 : #endif /* SRC_BACKEND_VK_XLSNNVKTRAINABLELAYER_H_ */