Line data Source code
1 : /** 2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev> 3 : 4 : Permission is hereby granted, free of charge, to any person obtaining a copy 5 : of this software and associated documentation files (the "Software"), to deal 6 : in the Software without restriction, including without limitation the rights 7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 : copies of the Software, and to permit persons to whom the Software is 9 : furnished to do so, subject to the following conditions: 10 : 11 : The above copyright notice and this permission notice shall be included in 12 : all copies or substantial portions of the Software. 13 : 14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 : THE SOFTWARE. 21 : **/ 22 : 23 : #ifndef SRC_LAYERS_XLSNNLOSSLAYER_H_ 24 : #define SRC_LAYERS_XLSNNLOSSLAYER_H_ 25 : 26 : #include "XLSnnLayer.h" 27 : 28 : namespace stappler::xenolith::shadernn { 29 : 30 : class LossLayer : public Layer { 31 : public: 32 : enum ParameterIndex { 33 : P_LossWeight = 0, // the weight for the loss function 34 : P_Loss, // the loss value on the last step 35 : P_LossDivider, // the averaging factor for calculating the loss value 36 : P_LossGradientDivider, // the averaging factor for calculating the loss gradient (takes lossWeight into account) 37 : P_MinGradient, 38 : P_MaxGradient, 39 : P_Count 40 : }; 41 : 42 2 : SpanView<float> getParameters() const { return _params; } 43 : BytesView getParametersBuffer() const { return BytesView((const uint8_t *)_params.data(), _params.size() * sizeof(float)); } 44 : void synchronizeParameters(SpanView<float>); 45 : 46 : void setParameter(ParameterIndex, float); 47 : float getParameter(ParameterIndex) const; 48 : 49 : protected: 50 : std::array<float, P_Count> _params; 51 : }; 52 : 53 : class CrossEntropyLossLayer : public LossLayer { 54 : public: 55 0 : virtual ~CrossEntropyLossLayer() = default; 56 : 57 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override; 58 : 59 : virtual void setInputExtent(uint32_t index, Attachment *, Extent3 e) override; 60 : 61 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder, 62 : Map<Layer *, const core::AttachmentData *> inputs, 63 : Map<Attachment *, const core::AttachmentData *> attachments) override; 64 : 65 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override; 66 : 67 : Layer *getInputLabels() const { return _inputLabels; } 68 14400 : uint32_t getBatchSize() const { return _batchSize; } 69 13200 : uint32_t getClassesCount() const { return _classesCount; } 70 : 71 1 : uint32_t getWeightBufferSize() const { return sizeof(float) * _batchSize; } 72 1 : uint32_t getResultBufferSize() const { return sizeof(float) * _batchSize; } 73 1 : uint32_t getLossGradientBufferSize() const { return sizeof(float) * _batchSize * _classesCount; } 74 : 75 : protected: 76 : uint32_t _batchSize = 100; 77 : uint32_t _classesCount = 10; 78 : String _labelsInputName; 79 : Layer *_inputLabels = nullptr; 80 : Layer *_inputNetwork = nullptr; 81 : }; 82 : 83 : } 84 : 85 : #endif /* SRC_LAYERS_XLSNNLOSSLAYER_H_ */