Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnStatPercentLayer.h"
24 : #include "XLSnnVkStatPercentLayer.h"
25 :
26 : namespace stappler::xenolith::shadernn {
27 :
28 0 : bool StatPercentLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
29 0 : if (!Layer::init(m, tag, idx, data)) {
30 0 : return false;
31 : }
32 :
33 0 : _fieldClass = data.getInteger("fieldClass");
34 0 : _fieldSource = data.getInteger("fieldSource");
35 0 : _fieldTarget = data.getInteger("fieldTarget");
36 0 : _classMin = data.getInteger("classMin");
37 0 : _classCount = data.getInteger("classCount");
38 :
39 0 : return true;
40 : }
41 :
42 0 : const core::QueuePassData *StatPercentLayer::prepare(core::Queue::Builder &builder,
43 : Map<Layer *, const core::AttachmentData *> inputs,
44 : Map<Attachment *, const core::AttachmentData *> attachments) {
45 0 : auto inputIt = attachments.find(_inputs.front().attachment);
46 0 : auto outputIt = attachments.find(getOutput());
47 :
48 0 : if (inputIt == attachments.end() || outputIt == attachments.end()) {
49 0 : log::error("snn::InputLayer", "No attachments specified");
50 0 : return nullptr;
51 : }
52 :
53 0 : return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
54 0 : [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
55 0 : return Rc<vk::shadernn::StatPercentLayer>::create(builder, passBuilder, this, inputIt->second, outputIt->second);
56 0 : });
57 : }
58 :
59 0 : const core::AttachmentData *StatPercentLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
60 0 : return builder.addAttachemnt(toString(getName(), "_output"),
61 0 : [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
62 0 : if (isGlobalOutput) {
63 0 : attachmentBuilder.defineAsOutput();
64 : }
65 0 : return Rc<vk::BufferAttachment>::create(attachmentBuilder,
66 0 : core::BufferInfo(size_t(_classCount * 4 * sizeof(float)),
67 0 : core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
68 0 : );
69 0 : });
70 : }
71 :
72 0 : bool StatAnalysisLayer::init(Model *m, StringView tag, size_t idx, const Value &data) {
73 0 : if (!Layer::init(m, tag, idx, data)) {
74 0 : return false;
75 : }
76 :
77 0 : _threshold = float(data.getDouble("threshold"));
78 :
79 0 : return true;
80 : }
81 :
82 0 : void StatAnalysisLayer::setInputExtent(uint32_t index, Attachment *a, Extent3 e) {
83 0 : Layer::setInputExtent(index, a, e);
84 :
85 0 : auto outBy = a->getOutputBy();
86 0 : if (auto l = dynamic_cast<StatPercentLayer *>(outBy)) {
87 0 : _percent = l;
88 : }
89 0 : }
90 :
91 0 : const core::QueuePassData *StatAnalysisLayer::prepare(core::Queue::Builder &builder,
92 : Map<Layer *, const core::AttachmentData *> inputs,
93 : Map<Attachment *, const core::AttachmentData *> attachments) {
94 0 : auto inputDataIt = attachments.find(_inputs[0].attachment);
95 0 : auto inputClasseIt = attachments.find(_inputs[1].attachment);
96 0 : auto outputIt = attachments.find(getOutput());
97 :
98 0 : if (inputDataIt == attachments.end() || inputClasseIt == attachments.end() || outputIt == attachments.end()) {
99 0 : log::error("snn::InputLayer", "No attachments specified");
100 0 : return nullptr;
101 : }
102 :
103 0 : return builder.addPass(getName(), core::PassType::Compute, core::RenderOrdering(_inputIndex),
104 0 : [&] (core::QueuePassBuilder &passBuilder) -> Rc<core::QueuePass> {
105 0 : return Rc<vk::shadernn::StatAnalysisLayer>::create(builder, passBuilder, this, inputDataIt->second, inputClasseIt->second, outputIt->second);
106 0 : });
107 : }
108 :
109 0 : const core::AttachmentData *StatAnalysisLayer::makeOutputAttachment(core::Queue::Builder &builder, bool isGlobalOutput) {
110 0 : return builder.addAttachemnt(toString(getName(), "_output"),
111 0 : [&] (core::AttachmentBuilder &attachmentBuilder) -> Rc<core::Attachment> {
112 0 : if (isGlobalOutput) {
113 0 : attachmentBuilder.defineAsOutput();
114 : }
115 0 : return Rc<vk::BufferAttachment>::create(attachmentBuilder,
116 0 : core::BufferInfo(size_t(1 * 4 * sizeof(float)),
117 0 : core::BufferUsage::StorageBuffer | core::BufferUsage::TransferDst, core::PassType::Compute)
118 0 : );
119 0 : });
120 : }
121 :
122 : }
|