Line data Source code
/**
Copyright (c) 2023 Stappler LLC <admin@stappler.dev>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
21 : **/ 22 : 23 : #include "XLSnnVkTrainableLayer.h" 24 : 25 : namespace stappler::xenolith::vk::shadernn { 26 : 27 3 : TrainableLayer::~TrainableLayer() { } 28 : 29 3 : void TrainableLayer::initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &builder, const core::AttachmentData *source, uint32_t idx) { 30 : using namespace core; 31 : 32 6 : auto paramsBuf = queueBuilder.addBuffer(toString(getName(), "_trainingParams"), BufferInfo(BufferUsage::StorageBuffer, size_t(TV_Count * sizeof(float))), 33 3 : [this] (uint8_t *buf, uint64_t size, const BufferData::DataCallback &) { 34 3 : auto target = (float *)buf; 35 : 36 3 : float rate = _learningRate; 37 3 : float regL1 = _regularizationL1; 38 3 : float regL2 = _regularizationL2; 39 : 40 3 : target[TV_MomentDecayRateVar] = _momentDecayRate; 41 3 : target[TV_OpMomentDecayRateVar] = 1 - _momentDecayRate; 42 3 : target[TV_OpRegL2MomentDecayRateVar] = -rate * regL2; 43 3 : target[TV_RateVar] = (-rate); 44 3 : target[TV_L1Threshold] = regL1; 45 3 : target[TV_L1Mult] = -rate; 46 3 : }); 47 : 48 3 : auto trainable = getTrainableGradients(queueBuilder); 49 3 : trainable.emplace(trainable.begin(), paramsBuf); 50 : 51 3 : _staticPropagationBuffers = trainable.size(); 52 : 53 3 : _propagationAttachment = queueBuilder.addAttachemnt(toString(getName(), "_BackwardOnce_data"), 54 3 : [&] (AttachmentBuilder &builder) -> Rc<core::Attachment> { 55 6 : return Rc<vk::BufferAttachment>::create(builder, move(trainable)); 56 : }); 57 : 58 3 : _externalPropagationDataSource = source; 59 3 : _externalPropagationBufferIdx = idx; 60 : 61 3 : auto propagationPassAttachment = builder.addAttachment(_propagationAttachment); 62 : 63 3 : auto l = builder.addDescriptorLayout([&] (PipelineLayoutBuilder &layout) { 64 3 : layout.addSet([&] (DescriptorSetBuilder &set) { 65 6 : set.addDescriptorArray(propagationPassAttachment, getPropagationSubpassBufferCount(), DescriptorType::StorageBuffer); 66 3 : }); 67 6 : }); 68 : 69 3 : 
builder.addSubpass([&] (SubpassBuilder &subpass) { 70 3 : initPropagationSubpass(queueBuilder, builder, subpass, l); 71 3 : }); 72 : 73 3 : const core::QueuePassData *pass = _inputAttachment->passes.front()->pass; 74 3 : if (_inputAttachment->passes.front()->pass->pass != this) { 75 3 : auto trainable = dynamic_cast<TrainableLayer *>(pass->pass.get()); 76 3 : if (trainable) { 77 2 : trainable->initPropagation(queueBuilder, builder, _propagationAttachment, _targetPropagationIdx); 78 : } 79 : } 80 3 : } 81 : 82 0 : void TrainableLayer::initPropagationSubpass(Queue::Builder &queueBuilder, core::QueuePassBuilder &, 83 : core::SubpassBuilder &, const core::PipelineLayoutData *) { 84 : 85 0 : } 86 : 87 0 : Vector<const core::BufferData *> TrainableLayer::getTrainableGradients(Queue::Builder &queueBuilder) const { 88 0 : return Vector<const core::BufferData *>(); 89 : } 90 : 91 3 : bool TrainableLayer::isBackwardNeeded() const { 92 3 : const core::QueuePassData *pass = _inputAttachment->passes.front()->pass; 93 3 : if (_inputAttachment->passes.front()->pass->pass != this && dynamic_cast<TrainableLayer *>(pass->pass.get())) { 94 2 : return true; 95 : } 96 1 : return false; 97 : } 98 : 99 : }