Line data Source code
1 : /** 2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev> 3 : 4 : Permission is hereby granted, free of charge, to any person obtaining a copy 5 : of this software and associated documentation files (the "Software"), to deal 6 : in the Software without restriction, including without limitation the rights 7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 : copies of the Software, and to permit persons to whom the Software is 9 : furnished to do so, subject to the following conditions: 10 : 11 : The above copyright notice and this permission notice shall be included in 12 : all copies or substantial portions of the Software. 13 : 14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 : THE SOFTWARE. 21 : **/ 22 : 23 : #ifndef SRC_LAYERS_XLSNNSTATPERCENTLAYER_H_ 24 : #define SRC_LAYERS_XLSNNSTATPERCENTLAYER_H_ 25 : 26 : #include "XLSnnLayer.h" 27 : 28 : namespace stappler::xenolith::shadernn { 29 : 30 : class StatPercentLayer : public Layer { 31 : public: 32 0 : virtual ~StatPercentLayer() = default; 33 : 34 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override; 35 : 36 0 : virtual Extent3 getOutputExtent() const override { 37 0 : return Extent3(4, _classCount, 1); 38 : } 39 : 40 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder, 41 : Map<Layer *, const core::AttachmentData *> inputs, 42 : Map<Attachment *, const core::AttachmentData *> attachments) override; 43 : 44 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override; 45 : 46 0 : uint32_t getFieldClass() const { return _fieldClass; } 47 0 : uint32_t getFieldSource() const { return _fieldSource; } 48 0 : uint32_t getFieldTarget() const { return _fieldTarget; } 49 0 : uint32_t getClassMin() const { return _classMin; } 50 0 : uint32_t getClassCount() const { return _classCount; } 51 : 52 : protected: 53 : uint32_t _fieldClass = 0; 54 : uint32_t _fieldSource = 1; 55 : uint32_t _fieldTarget = 2; 56 : uint32_t _classMin = 0; 57 : uint32_t _classCount = 100; 58 : }; 59 : 60 : class StatAnalysisLayer : public Layer { 61 : public: 62 0 : virtual ~StatAnalysisLayer() = default; 63 : 64 : virtual bool init(Model *, StringView tag, size_t idx, const Value&) override; 65 : 66 : virtual void setInputExtent(uint32_t index, Attachment *, Extent3 e) override; 67 : 68 0 : virtual Extent3 getOutputExtent() const override { 69 0 : return Extent3(4, 1, 1); 70 : } 71 : 72 : virtual const core::QueuePassData *prepare(core::Queue::Builder &builder, 73 : Map<Layer *, const core::AttachmentData *> inputs, 74 : Map<Attachment *, const core::AttachmentData *> attachments) override; 75 : 76 : virtual const core::AttachmentData *makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) override; 77 : 78 0 : uint32_t getFieldClass() const { return _percent->getFieldClass(); } 79 0 : uint32_t getFieldSource() const { return _percent->getFieldSource(); } 80 0 : uint32_t getFieldTarget() const { return _percent->getFieldTarget(); } 81 0 : uint32_t getClassMin() const { return _percent->getClassMin(); } 82 0 : uint32_t getClassCount() const { return _percent->getClassCount(); } 83 0 : float getThreshold() const { return _threshold; } 84 : 85 : StatPercentLayer *getPercentLayer() const { return _percent; } 86 : 87 : protected: 88 : StatPercentLayer *_percent = nullptr; 89 : float _threshold = 1; 90 : }; 91 : 92 : } 93 : 94 : #endif /* SRC_LAYERS_XLSNNSTATPERCENTLAYER_H_ */