Line data Source code
1 : /** 2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev> 3 : 4 : Permission is hereby granted, free of charge, to any person obtaining a copy 5 : of this software and associated documentation files (the "Software"), to deal 6 : in the Software without restriction, including without limitation the rights 7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 : copies of the Software, and to permit persons to whom the Software is 9 : furnished to do so, subject to the following conditions: 10 : 11 : The above copyright notice and this permission notice shall be included in 12 : all copies or substantial portions of the Software. 13 : 14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 : THE SOFTWARE. 21 : **/ 22 : 23 : #ifndef SRC_BACKEND_VK_XLSNNVKLOSSLAYER_H_ 24 : #define SRC_BACKEND_VK_XLSNNVKLOSSLAYER_H_ 25 : 26 : #include "XLVkRenderPass.h" 27 : #include "XLVkQueuePass.h" 28 : #include "XLVkAttachment.h" 29 : #include "XLSnnLossLayer.h" 30 : 31 : namespace stappler::xenolith::vk::shadernn { 32 : 33 : class CrossEntropyLossLayer : public vk::QueuePass { 34 : public: 35 : using Front = xenolith::shadernn::CrossEntropyLossLayer; 36 : 37 : using BufferView = BufferAttachmentHandle::BufferView; 38 : 39 : using PipelineOpFn = Function<void(Front *, CommandBuffer &, ComputePipeline *, SpanView<BufferView>)>; 40 : 41 : enum class PipelineOpIndex { 42 : MatrixSoftmaxByRows, // softmax to activation 43 : VectorNegLog, // transform activation 44 : VectorEltwiseMultiply, // multiply activation on labels 45 : SumMatrixColumnsToResult, // sum activation columns to result 46 : VectorSub, // compute diff fro activation to label 47 : SumMatrixColumnsLabels, // sum labels 48 : MultiplyDiagMatrixByMatrix, // calc gradient 49 : VectorDotProduct, // loss value function 50 : MultiplyDiagMatrixByMatrixForInput // prepare error propagation 51 : }; 52 : 53 : struct PipelineOp { 54 : PipelineOpIndex idx; 55 : const core::ComputePipelineData *pipeline; 56 : PipelineOpFn command; 57 : 58 9 : PipelineOp(PipelineOpIndex i, const core::ComputePipelineData *p, PipelineOpFn &&f) 59 9 : : idx(i), pipeline(p), command(move(f)) { }; 60 : }; 61 : 62 : static constexpr uint32_t DescriptorArraySize = 8; 63 : static constexpr uint32_t ParamsIdx = 0; 64 : static constexpr uint32_t WeightsIdx = 1; 65 : static constexpr uint32_t LossValueIdx = 2; 66 : static constexpr uint32_t LossGradientIdx = 3; 67 : static constexpr uint32_t InputNetworkIdx = 4; 68 : static constexpr uint32_t InputLabelsIdx = 5; 69 : static constexpr uint32_t ActivationIdx = 6; 70 : static constexpr uint32_t ActivationEltwiseMulIdx = 7; 71 : 72 : virtual ~CrossEntropyLossLayer(); 73 : 74 : virtual bool init(Queue::Builder &queueBuilder, QueuePassBuilder &, Front *, 75 : const AttachmentData *inputLabels, const AttachmentData *inputNetwork, const AttachmentData *output); 76 : 77 : void initPropagation(Queue::Builder &queueBuilder, QueuePassBuilder &); 78 : 79 : void runAll(CommandBuffer &, SpanView<BufferView>); 80 : 81 1200 : const AttachmentData *getInputLabelsAttachment() const { return _inputLabelsAttachment; } 82 1200 : const AttachmentData *getInputNetworkAttachment() const { return _inputNetworkAttachment; } 83 3600 : const AttachmentData *getWeightsAttachment() const { return _weightAttachment; } 84 1200 : const AttachmentData *getOutputAttachment() const { return _outputAttachment; } 85 : 86 1200 : const Front *getFront() const { return _front; } 87 : 88 : protected: 89 : using QueuePass::init; 90 : 91 : const AttachmentData *_inputLabelsAttachment = nullptr; 92 : const AttachmentData *_inputNetworkAttachment = nullptr; 93 : const AttachmentData *_weightAttachment = nullptr; 94 : const AttachmentData *_outputAttachment = nullptr; 95 : 96 : Rc<Front> _front; 97 : Map<PipelineOpIndex, PipelineOp> _pipelineOps; 98 : }; 99 : 100 : } 101 : 102 : #endif /* SRC_BACKEND_VK_XLSNNVKLOSSLAYER_H_ */