LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/layers - XLSnnLossLayer.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 6 7 85.7 %
Date: 2024-05-06 04:51:23 Functions: 6 8 75.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #ifndef SRC_LAYERS_XLSNNLOSSLAYER_H_
      24             : #define SRC_LAYERS_XLSNNLOSSLAYER_H_
      25             : 
      26             : #include "XLSnnLayer.h"
      27             : 
      28             : namespace stappler::xenolith::shadernn {
      29             : 
      30             : class LossLayer : public Layer {
      31             : public:
      32             :         enum ParameterIndex {
      33             :                 P_LossWeight = 0, // the weight for the loss function
      34             :                 P_Loss, // the loss value on the last step
      35             :                 P_LossDivider, // the averaging factor for calculating the loss value
      36             :                 P_LossGradientDivider, // the averaging factor for calculating the loss gradient (takes lossWeight into account)
      37             :                 P_MinGradient,
      38             :                 P_MaxGradient,
      39             :                 P_Count
      40             :         };
      41             : 
      42           2 :         SpanView<float> getParameters() const { return _params; }
      43             :         BytesView getParametersBuffer() const { return BytesView((const uint8_t *)_params.data(), _params.size() * sizeof(float)); }
      44             :         void synchronizeParameters(SpanView<float>);
      45             : 
      46             :         void setParameter(ParameterIndex, float);
      47             :         float getParameter(ParameterIndex) const;
      48             : 
      49             : protected:
      50             :         std::array<float, P_Count> _params;
      51             : };
      52             : 
      53             : class CrossEntropyLossLayer : public LossLayer {
      54             : public:
      55           0 :         virtual ~CrossEntropyLossLayer() = default;
      56             : 
      57             :         virtual bool init(Model *, StringView tag, size_t idx, const Value&) override;
      58             : 
      59             :         virtual void setInputExtent(uint32_t index, Attachment *, Extent3 e) override;
      60             : 
      61             :         virtual const core::QueuePassData *prepare(core::Queue::Builder &builder,
      62             :                         Map<Layer *, const core::AttachmentData *> inputs,
      63             :                         Map<Attachment *, const core::AttachmentData *> attachments) override;
      64             : 
      65             :         virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override;
      66             : 
      67             :         Layer *getInputLabels() const { return _inputLabels; }
      68       14400 :         uint32_t getBatchSize() const { return _batchSize; }
      69       13200 :         uint32_t getClassesCount() const { return _classesCount; }
      70             : 
      71           1 :         uint32_t getWeightBufferSize() const { return sizeof(float) * _batchSize; }
      72           1 :         uint32_t getResultBufferSize() const { return sizeof(float) * _batchSize; }
      73           1 :         uint32_t getLossGradientBufferSize() const { return sizeof(float) * _batchSize * _classesCount; }
      74             : 
      75             : protected:
      76             :         uint32_t _batchSize = 100;
      77             :         uint32_t _classesCount = 10;
      78             :         String _labelsInputName;
      79             :         Layer *_inputLabels = nullptr;
      80             :         Layer *_inputNetwork = nullptr;
      81             : };
      82             : 
      83             : }
      84             : 
      85             : #endif /* SRC_LAYERS_XLSNNLOSSLAYER_H_ */

Generated by: LCOV version 1.14