Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLCommon.h"
24 : #include "XLSnnVkShaders.h"
25 :
26 : namespace stappler::xenolith::vk::shadernn {
27 :
28 : #include "gen_f32.comp.h"
29 : #include "norm_f32.comp.h"
30 : #include "vk_activation_f32.comp.h"
31 : #include "vk_add_f32.comp.h"
32 : #include "vk_avgpool2d_f32.comp.h"
33 : #include "vk_batchnorm_f32.comp.h"
34 : #include "vk_concat_f32.comp.h"
35 : #include "vk_conv2d_1x1_f32.comp.h"
36 : #include "vk_conv2d_f32.comp.h"
37 : #include "vk_dense_f32.comp.h"
38 : #include "vk_depthwise_f32.comp.h"
39 : #include "vk_flatten_f32.comp.h"
40 : #include "vk_instancenorm_f32.comp.h"
41 : #include "vk_maxpool2d_f32.comp.h"
42 : #include "vk_pad_f32.comp.h"
43 : #include "vk_resize_f32.comp.h"
44 : #include "vk_subpixel_f32.comp.h"
45 : #include "vk_unary_f32.comp.h"
46 : #include "vk_upsampling2d_bilinear_f32.comp.h"
47 : #include "vk_upsampling2d_nearest_f32.comp.h"
48 :
49 : #include "gen_f16.comp.h"
50 : #include "norm_f16.comp.h"
51 : #include "vk_activation_f16.comp.h"
52 : #include "vk_add_f16.comp.h"
53 : #include "vk_avgpool2d_f16.comp.h"
54 : #include "vk_batchnorm_f16.comp.h"
55 : #include "vk_concat_f16.comp.h"
56 : #include "vk_conv2d_1x1_f16.comp.h"
57 : #include "vk_conv2d_f16.comp.h"
58 : #include "vk_dense_f16.comp.h"
59 : #include "vk_depthwise_f16.comp.h"
60 : #include "vk_flatten_f16.comp.h"
61 : #include "vk_instancenorm_f16.comp.h"
62 : #include "vk_maxpool2d_f16.comp.h"
63 : #include "vk_pad_f16.comp.h"
64 : #include "vk_resize_f16.comp.h"
65 : #include "vk_subpixel_f16.comp.h"
66 : #include "vk_unary_f16.comp.h"
67 : #include "vk_upsampling2d_bilinear_f16.comp.h"
68 : #include "vk_upsampling2d_nearest_f16.comp.h"
69 :
70 : #include "AddVectorToMatrixRows.comp.h"
71 : #include "BufferNorm.comp.h"
72 : #include "MultiplyMatrixByMatrix.comp.h"
73 : #include "MultiplyMatrixByMatrixBorders.comp.h"
74 : #include "MultiplyMatrixByTransposedMatrix.comp.h"
75 : #include "MultiplyMatrixByTransposedMatrixBorders.comp.h"
76 : #include "MultiplyTransposedMatrixByMatrix.comp.h"
77 : #include "MultiplyTransposedMatrixByMatrixBorders.comp.h"
78 : #include "MatrixSoftmaxByRows.comp.h"
79 : #include "VectorAddFloat1.comp.h"
80 : #include "VectorAddFloat4.comp.h"
81 : #include "VectorReLU.comp.h"
82 : #include "VectorReLU4.comp.h"
83 : #include "VectorReLUDiff.comp.h"
84 : #include "VectorLog.comp.h"
85 : #include "VectorDotProduct.comp.h"
86 : #include "VectorMultiplyFloat.comp.h"
87 : #include "VectorMultiplyAndAdd.comp.h"
88 : #include "VectorSubFloat.comp.h"
89 : #include "SumMatrixColumns.comp.h"
90 : #include "SumMatrixRows.comp.h"
91 : #include "MultiplyDiagMatrixByMatrix.comp.h"
92 :
93 : #include "StatNorm.comp.h"
94 : #include "StatClassMap.comp.h"
95 : #include "StatClassPercent.comp.h"
96 : #include "StatAnalysis.comp.h"
97 :
// Precompiled SPIR-V shader binaries, exposed as spans of 32-bit words.
// Each `*_comp` symbol is a byte array generated from the corresponding
// .comp GLSL source (included above); `*_comp_len` is its size in bytes,
// so lengths are divided by sizeof(uint32_t) to get the word count.

// f32 neural-network layer shaders
SpanView<uint32_t> GenF32Comp(reinterpret_cast<const uint32_t *>(gen_f32_comp), gen_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> NormF32Comp(reinterpret_cast<const uint32_t *>(norm_f32_comp), norm_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> ActivationF32Comp(reinterpret_cast<const uint32_t *>(vk_activation_f32_comp), vk_activation_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> AddF32Comp(reinterpret_cast<const uint32_t *>(vk_add_f32_comp), vk_add_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Avgpool2dF32Comp(reinterpret_cast<const uint32_t *>(vk_avgpool2d_f32_comp), vk_avgpool2d_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> BatchnormF32Comp(reinterpret_cast<const uint32_t *>(vk_batchnorm_f32_comp), vk_batchnorm_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> ConcatF32Comp(reinterpret_cast<const uint32_t *>(vk_concat_f32_comp), vk_concat_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Conv2d1x1F32Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_1x1_f32_comp), vk_conv2d_1x1_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Conv2dF32Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_f32_comp), vk_conv2d_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> DenseF32Comp(reinterpret_cast<const uint32_t *>(vk_dense_f32_comp), vk_dense_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> DepthwiseF32Comp(reinterpret_cast<const uint32_t *>(vk_depthwise_f32_comp), vk_depthwise_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> FlattenF32Comp(reinterpret_cast<const uint32_t *>(vk_flatten_f32_comp), vk_flatten_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> InstancenormF32Comp(reinterpret_cast<const uint32_t *>(vk_instancenorm_f32_comp), vk_instancenorm_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Maxpool2dF32Comp(reinterpret_cast<const uint32_t *>(vk_maxpool2d_f32_comp), vk_maxpool2d_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> PadF32Comp(reinterpret_cast<const uint32_t *>(vk_pad_f32_comp), vk_pad_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> ResizeF32Comp(reinterpret_cast<const uint32_t *>(vk_resize_f32_comp), vk_resize_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> SubpixelF32Comp(reinterpret_cast<const uint32_t *>(vk_subpixel_f32_comp), vk_subpixel_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> UnaryF32Comp(reinterpret_cast<const uint32_t *>(vk_unary_f32_comp), vk_unary_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Upsampling2dBilinearF32Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_bilinear_f32_comp), vk_upsampling2d_bilinear_f32_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Upsampling2dNearestF32Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_nearest_f32_comp), vk_upsampling2d_nearest_f32_comp_len / sizeof(uint32_t));

// f16 (half-precision) variants of the same layer shaders
SpanView<uint32_t> GenF16Comp(reinterpret_cast<const uint32_t *>(gen_f16_comp), gen_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> NormF16Comp(reinterpret_cast<const uint32_t *>(norm_f16_comp), norm_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> ActivationF16Comp(reinterpret_cast<const uint32_t *>(vk_activation_f16_comp), vk_activation_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> AddF16Comp(reinterpret_cast<const uint32_t *>(vk_add_f16_comp), vk_add_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Avgpool2dF16Comp(reinterpret_cast<const uint32_t *>(vk_avgpool2d_f16_comp), vk_avgpool2d_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> BatchnormF16Comp(reinterpret_cast<const uint32_t *>(vk_batchnorm_f16_comp), vk_batchnorm_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> ConcatF16Comp(reinterpret_cast<const uint32_t *>(vk_concat_f16_comp), vk_concat_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Conv2d1x1F16Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_1x1_f16_comp), vk_conv2d_1x1_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Conv2dF16Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_f16_comp), vk_conv2d_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> DenseF16Comp(reinterpret_cast<const uint32_t *>(vk_dense_f16_comp), vk_dense_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> DepthwiseF16Comp(reinterpret_cast<const uint32_t *>(vk_depthwise_f16_comp), vk_depthwise_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> FlattenF16Comp(reinterpret_cast<const uint32_t *>(vk_flatten_f16_comp), vk_flatten_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> InstancenormF16Comp(reinterpret_cast<const uint32_t *>(vk_instancenorm_f16_comp), vk_instancenorm_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Maxpool2dF16Comp(reinterpret_cast<const uint32_t *>(vk_maxpool2d_f16_comp), vk_maxpool2d_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> PadF16Comp(reinterpret_cast<const uint32_t *>(vk_pad_f16_comp), vk_pad_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> ResizeF16Comp(reinterpret_cast<const uint32_t *>(vk_resize_f16_comp), vk_resize_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> SubpixelF16Comp(reinterpret_cast<const uint32_t *>(vk_subpixel_f16_comp), vk_subpixel_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> UnaryF16Comp(reinterpret_cast<const uint32_t *>(vk_unary_f16_comp), vk_unary_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Upsampling2dBilinearF16Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_bilinear_f16_comp), vk_upsampling2d_bilinear_f16_comp_len / sizeof(uint32_t));
SpanView<uint32_t> Upsampling2dNearestF16Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_nearest_f16_comp), vk_upsampling2d_nearest_f16_comp_len / sizeof(uint32_t));

// precision-agnostic matrix/vector math shaders (single variant each)
SpanView<uint32_t> AddVectorToMatrixRowsComp(reinterpret_cast<const uint32_t *>(AddVectorToMatrixRows_comp), AddVectorToMatrixRows_comp_len / sizeof(uint32_t));
SpanView<uint32_t> BufferNormComp(reinterpret_cast<const uint32_t *>(BufferNorm_comp), BufferNorm_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyMatrixByMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByMatrix_comp), MultiplyMatrixByMatrix_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyMatrixByMatrixBordersComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByMatrixBorders_comp), MultiplyMatrixByMatrixBorders_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyMatrixByTransposedMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByTransposedMatrix_comp), MultiplyMatrixByTransposedMatrix_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyMatrixByTransposedMatrixBordersComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByTransposedMatrixBorders_comp), MultiplyMatrixByTransposedMatrixBorders_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyTransposedMatrixByMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyTransposedMatrixByMatrix_comp), MultiplyTransposedMatrixByMatrix_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyTransposedMatrixByMatrixBordersComp(reinterpret_cast<const uint32_t *>(MultiplyTransposedMatrixByMatrixBorders_comp), MultiplyTransposedMatrixByMatrixBorders_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MatrixSoftmaxByRowsComp(reinterpret_cast<const uint32_t *>(MatrixSoftmaxByRows_comp), MatrixSoftmaxByRows_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorAddFloat1Comp(reinterpret_cast<const uint32_t *>(VectorAddFloat1_comp), VectorAddFloat1_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorAddFloat4Comp(reinterpret_cast<const uint32_t *>(VectorAddFloat4_comp), VectorAddFloat4_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorReLUComp(reinterpret_cast<const uint32_t *>(VectorReLU_comp), VectorReLU_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorReLU4Comp(reinterpret_cast<const uint32_t *>(VectorReLU4_comp), VectorReLU4_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorReLUDiffComp(reinterpret_cast<const uint32_t *>(VectorReLUDiff_comp), VectorReLUDiff_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorLogComp(reinterpret_cast<const uint32_t *>(VectorLog_comp), VectorLog_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorDotProductComp(reinterpret_cast<const uint32_t *>(VectorDotProduct_comp), VectorDotProduct_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorMultiplyFloatComp(reinterpret_cast<const uint32_t *>(VectorMultiplyFloat_comp), VectorMultiplyFloat_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorMultiplyAndAddComp(reinterpret_cast<const uint32_t *>(VectorMultiplyAndAdd_comp), VectorMultiplyAndAdd_comp_len / sizeof(uint32_t));
SpanView<uint32_t> VectorSubComp(reinterpret_cast<const uint32_t *>(VectorSubFloat_comp), VectorSubFloat_comp_len / sizeof(uint32_t));
SpanView<uint32_t> SumMatrixColumnsComp(reinterpret_cast<const uint32_t *>(SumMatrixColumns_comp), SumMatrixColumns_comp_len / sizeof(uint32_t));
SpanView<uint32_t> SumMatrixRowsComp(reinterpret_cast<const uint32_t *>(SumMatrixRows_comp), SumMatrixRows_comp_len / sizeof(uint32_t));
SpanView<uint32_t> MultiplyDiagMatrixByMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyDiagMatrixByMatrix_comp), MultiplyDiagMatrixByMatrix_comp_len / sizeof(uint32_t));

// statistics shaders
SpanView<uint32_t> StatNormComp(reinterpret_cast<const uint32_t *>(StatNorm_comp), StatNorm_comp_len / sizeof(uint32_t));
SpanView<uint32_t> StatClassMapComp(reinterpret_cast<const uint32_t *>(StatClassMap_comp), StatClassMap_comp_len / sizeof(uint32_t));
SpanView<uint32_t> StatClassPercentComp(reinterpret_cast<const uint32_t *>(StatClassPercent_comp), StatClassPercent_comp_len / sizeof(uint32_t));
SpanView<uint32_t> StatAnalysisComp(reinterpret_cast<const uint32_t *>(StatAnalysis_comp), StatAnalysis_comp_len / sizeof(uint32_t));
167 :
// Deduces the compute precision to use for an attachment from the
// per-channel bit width of its image format.
//
// Returns:
//  - Precision::F8  for 8-bit-per-channel formats,
//  - Precision::F16 for 10- and 16-bit-per-channel formats,
//  - Precision::F32 for 32-bit-per-channel formats,
//  - Precision::F32 also for 64-bit-per-channel formats — no wider
//    precision is selected here (NOTE(review): looks intentional since no
//    F64 value is used in this file — confirm against the Precision enum),
//  - Precision::Unknown for any other format or for non-image attachments.
Precision getAttachmentPrecision(const core::AttachmentData *data) {
	if (data->type == core::AttachmentType::Image) {
		// only image attachments carry a format; buffers fall through to Unknown
		auto img = static_cast<core::ImageAttachment *>(data->attachment.get());
		auto fmt = img->getImageInfo().format;
		switch (fmt) {
		// 8-bit per channel -> F8
		case core::ImageFormat::R8_UNORM:
		case core::ImageFormat::R8_SNORM:
		case core::ImageFormat::R8_USCALED:
		case core::ImageFormat::R8_SSCALED:
		case core::ImageFormat::R8_UINT:
		case core::ImageFormat::R8_SINT:
		case core::ImageFormat::R8_SRGB:
		case core::ImageFormat::R8G8_UNORM:
		case core::ImageFormat::R8G8_SNORM:
		case core::ImageFormat::R8G8_USCALED:
		case core::ImageFormat::R8G8_SSCALED:
		case core::ImageFormat::R8G8_UINT:
		case core::ImageFormat::R8G8_SINT:
		case core::ImageFormat::R8G8_SRGB:
		case core::ImageFormat::R8G8B8_UNORM:
		case core::ImageFormat::R8G8B8_SNORM:
		case core::ImageFormat::R8G8B8_USCALED:
		case core::ImageFormat::R8G8B8_SSCALED:
		case core::ImageFormat::R8G8B8_UINT:
		case core::ImageFormat::R8G8B8_SINT:
		case core::ImageFormat::R8G8B8_SRGB:
		case core::ImageFormat::B8G8R8_UNORM:
		case core::ImageFormat::B8G8R8_SNORM:
		case core::ImageFormat::B8G8R8_USCALED:
		case core::ImageFormat::B8G8R8_SSCALED:
		case core::ImageFormat::B8G8R8_UINT:
		case core::ImageFormat::B8G8R8_SINT:
		case core::ImageFormat::B8G8R8_SRGB:
		case core::ImageFormat::R8G8B8A8_UNORM:
		case core::ImageFormat::R8G8B8A8_SNORM:
		case core::ImageFormat::R8G8B8A8_USCALED:
		case core::ImageFormat::R8G8B8A8_SSCALED:
		case core::ImageFormat::R8G8B8A8_UINT:
		case core::ImageFormat::R8G8B8A8_SINT:
		case core::ImageFormat::R8G8B8A8_SRGB:
		case core::ImageFormat::B8G8R8A8_UNORM:
		case core::ImageFormat::B8G8R8A8_SNORM:
		case core::ImageFormat::B8G8R8A8_USCALED:
		case core::ImageFormat::B8G8R8A8_SSCALED:
		case core::ImageFormat::B8G8R8A8_UINT:
		case core::ImageFormat::B8G8R8A8_SINT:
		case core::ImageFormat::B8G8R8A8_SRGB:
		case core::ImageFormat::A8B8G8R8_UNORM_PACK32:
		case core::ImageFormat::A8B8G8R8_SNORM_PACK32:
		case core::ImageFormat::A8B8G8R8_USCALED_PACK32:
		case core::ImageFormat::A8B8G8R8_SSCALED_PACK32:
		case core::ImageFormat::A8B8G8R8_UINT_PACK32:
		case core::ImageFormat::A8B8G8R8_SINT_PACK32:
		case core::ImageFormat::A8B8G8R8_SRGB_PACK32:
			return Precision::F8;
			break;
		// 10-bit packed and 16-bit per channel -> F16
		case core::ImageFormat::A2R10G10B10_UNORM_PACK32:
		case core::ImageFormat::A2R10G10B10_SNORM_PACK32:
		case core::ImageFormat::A2R10G10B10_USCALED_PACK32:
		case core::ImageFormat::A2R10G10B10_SSCALED_PACK32:
		case core::ImageFormat::A2R10G10B10_UINT_PACK32:
		case core::ImageFormat::A2R10G10B10_SINT_PACK32:
		case core::ImageFormat::A2B10G10R10_UNORM_PACK32:
		case core::ImageFormat::A2B10G10R10_SNORM_PACK32:
		case core::ImageFormat::A2B10G10R10_USCALED_PACK32:
		case core::ImageFormat::A2B10G10R10_SSCALED_PACK32:
		case core::ImageFormat::A2B10G10R10_UINT_PACK32:
		case core::ImageFormat::A2B10G10R10_SINT_PACK32:
		case core::ImageFormat::R16_UNORM:
		case core::ImageFormat::R16_SNORM:
		case core::ImageFormat::R16_USCALED:
		case core::ImageFormat::R16_SSCALED:
		case core::ImageFormat::R16_UINT:
		case core::ImageFormat::R16_SINT:
		case core::ImageFormat::R16_SFLOAT:
		case core::ImageFormat::R16G16_UNORM:
		case core::ImageFormat::R16G16_SNORM:
		case core::ImageFormat::R16G16_USCALED:
		case core::ImageFormat::R16G16_SSCALED:
		case core::ImageFormat::R16G16_UINT:
		case core::ImageFormat::R16G16_SINT:
		case core::ImageFormat::R16G16_SFLOAT:
		case core::ImageFormat::R16G16B16_UNORM:
		case core::ImageFormat::R16G16B16_SNORM:
		case core::ImageFormat::R16G16B16_USCALED:
		case core::ImageFormat::R16G16B16_SSCALED:
		case core::ImageFormat::R16G16B16_UINT:
		case core::ImageFormat::R16G16B16_SINT:
		case core::ImageFormat::R16G16B16_SFLOAT:
		case core::ImageFormat::R16G16B16A16_UNORM:
		case core::ImageFormat::R16G16B16A16_SNORM:
		case core::ImageFormat::R16G16B16A16_USCALED:
		case core::ImageFormat::R16G16B16A16_SSCALED:
		case core::ImageFormat::R16G16B16A16_UINT:
		case core::ImageFormat::R16G16B16A16_SINT:
		case core::ImageFormat::R16G16B16A16_SFLOAT:
			return Precision::F16;
			break;
		// 32-bit per channel -> F32
		case core::ImageFormat::R32_UINT:
		case core::ImageFormat::R32_SINT:
		case core::ImageFormat::R32_SFLOAT:
		case core::ImageFormat::R32G32_UINT:
		case core::ImageFormat::R32G32_SINT:
		case core::ImageFormat::R32G32_SFLOAT:
		case core::ImageFormat::R32G32B32_UINT:
		case core::ImageFormat::R32G32B32_SINT:
		case core::ImageFormat::R32G32B32_SFLOAT:
		case core::ImageFormat::R32G32B32A32_UINT:
		case core::ImageFormat::R32G32B32A32_SINT:
		case core::ImageFormat::R32G32B32A32_SFLOAT:
			return Precision::F32;
			break;
		// 64-bit per channel: clamped to F32 (widest precision used here)
		case core::ImageFormat::R64_UINT:
		case core::ImageFormat::R64_SINT:
		case core::ImageFormat::R64_SFLOAT:
		case core::ImageFormat::R64G64_UINT:
		case core::ImageFormat::R64G64_SINT:
		case core::ImageFormat::R64G64_SFLOAT:
		case core::ImageFormat::R64G64B64_UINT:
		case core::ImageFormat::R64G64B64_SINT:
		case core::ImageFormat::R64G64B64_SFLOAT:
		case core::ImageFormat::R64G64B64A64_UINT:
		case core::ImageFormat::R64G64B64A64_SINT:
		case core::ImageFormat::R64G64B64A64_SFLOAT:
			return Precision::F32;
			break;
		default:
			// depth/stencil, compressed and other formats are not supported
			return Precision::Unknown;
			break;
		}
	}
	return Precision::Unknown;
}
301 :
// Returns the SPIR-V binary for the requested layer shader at the requested
// precision.
//
// Layer shaders (Gen..Upsampling2dNearest) have dedicated f16/f32 builds and
// are selected from the matching table when `p` is F16 or F32. The matrix,
// vector and stat shaders exist in a single precision-agnostic build, so the
// same span is returned from every branch for them.
//
// For any other precision (F8, Unknown, ...) only the precision-agnostic
// shaders are available via the trailing fallback switch; requesting a
// precision-specific layer shader there yields an empty span.
SpanView<uint32_t> getShader(LayerShader sh, Precision p) {
	switch (p) {
	case Precision::F16:
		// half-precision variants of the layer shaders
		switch (sh) {
		case LayerShader::Gen: return GenF16Comp; break;
		case LayerShader::Norm: return NormF16Comp; break;
		case LayerShader::Activation: return ActivationF16Comp; break;
		case LayerShader::Add: return AddF16Comp; break;
		case LayerShader::Avgpool2d: return Avgpool2dF16Comp; break;
		case LayerShader::Batchnorm: return BatchnormF16Comp; break;
		case LayerShader::Concat: return ConcatF16Comp; break;
		case LayerShader::Conv2d1x1: return Conv2d1x1F16Comp; break;
		case LayerShader::Conv2d: return Conv2dF16Comp; break;
		case LayerShader::Dense: return DenseF16Comp; break;
		case LayerShader::Depthwise: return DepthwiseF16Comp; break;
		case LayerShader::Flatten: return FlattenF16Comp; break;
		case LayerShader::Instancenorm: return InstancenormF16Comp; break;
		case LayerShader::Maxpool2d: return Maxpool2dF16Comp; break;
		case LayerShader::Pad: return PadF16Comp; break;
		case LayerShader::Resize: return ResizeF16Comp; break;
		case LayerShader::Subpixel: return SubpixelF16Comp; break;
		case LayerShader::Unary: return UnaryF16Comp; break;
		case LayerShader::Upsampling2dBilinear: return Upsampling2dBilinearF16Comp; break;
		case LayerShader::Upsampling2dNearest: return Upsampling2dNearestF16Comp; break;
		// the remaining shaders have no precision-specific builds
		case LayerShader::AddVectorToMatrixRows: return AddVectorToMatrixRowsComp; break;
		case LayerShader::BufferNorm: return BufferNormComp; break;
		case LayerShader::MultiplyMatrixByMatrix: return MultiplyMatrixByMatrixComp; break;
		case LayerShader::MultiplyMatrixByMatrixBorder: return MultiplyMatrixByMatrixBordersComp; break;
		case LayerShader::MultiplyMatrixByTransposedMatrix: return MultiplyMatrixByTransposedMatrixComp; break;
		case LayerShader::MultiplyMatrixByTransposedMatrixBorder: return MultiplyMatrixByTransposedMatrixBordersComp; break;
		case LayerShader::MultiplyTransposedMatrixByMatrix: return MultiplyTransposedMatrixByMatrixComp; break;
		case LayerShader::MultiplyTransposedMatrixByMatrixBorder: return MultiplyTransposedMatrixByMatrixBordersComp; break;
		case LayerShader::MatrixSoftmaxByRows: return MatrixSoftmaxByRowsComp; break;
		case LayerShader::VectorAddFloat1: return VectorAddFloat1Comp; break;
		case LayerShader::VectorAddFloat4: return VectorAddFloat4Comp; break;
		case LayerShader::VectorReLU: return VectorReLUComp; break;
		case LayerShader::VectorReLU4: return VectorReLU4Comp; break;
		case LayerShader::VectorReLUDiff: return VectorReLUDiffComp; break;
		case LayerShader::VectorLog: return VectorLogComp; break;
		case LayerShader::VectorDotProduct: return VectorDotProductComp; break;
		case LayerShader::VectorEltwiseMultiply: return VectorMultiplyFloatComp; break;
		case LayerShader::VectorMultiplyAndAdd: return VectorMultiplyAndAddComp; break;
		case LayerShader::VectorSub: return VectorSubComp; break;
		case LayerShader::SumMatrixColumns: return SumMatrixColumnsComp; break;
		case LayerShader::SumMatrixRows: return SumMatrixRowsComp; break;
		case LayerShader::MultiplyDiagMatrixByMatrix: return MultiplyDiagMatrixByMatrixComp; break;
		case LayerShader::StatNorm: return StatNormComp; break;
		case LayerShader::StatClassMap: return StatClassMapComp; break;
		case LayerShader::StatClassPercent: return StatClassPercentComp; break;
		case LayerShader::StatAnalysis: return StatAnalysisComp; break;
		}
		break;
	case Precision::F32:
		// single-precision variants of the layer shaders
		switch (sh) {
		case LayerShader::Gen: return GenF32Comp; break;
		case LayerShader::Norm: return NormF32Comp; break;
		case LayerShader::Activation: return ActivationF32Comp; break;
		case LayerShader::Add: return AddF32Comp; break;
		case LayerShader::Avgpool2d: return Avgpool2dF32Comp; break;
		case LayerShader::Batchnorm: return BatchnormF32Comp; break;
		case LayerShader::Concat: return ConcatF32Comp; break;
		case LayerShader::Conv2d1x1: return Conv2d1x1F32Comp; break;
		case LayerShader::Conv2d: return Conv2dF32Comp; break;
		case LayerShader::Dense: return DenseF32Comp; break;
		case LayerShader::Depthwise: return DepthwiseF32Comp; break;
		case LayerShader::Flatten: return FlattenF32Comp; break;
		case LayerShader::Instancenorm: return InstancenormF32Comp; break;
		case LayerShader::Maxpool2d: return Maxpool2dF32Comp; break;
		case LayerShader::Pad: return PadF32Comp; break;
		case LayerShader::Resize: return ResizeF32Comp; break;
		case LayerShader::Subpixel: return SubpixelF32Comp; break;
		case LayerShader::Unary: return UnaryF32Comp; break;
		case LayerShader::Upsampling2dBilinear: return Upsampling2dBilinearF32Comp; break;
		case LayerShader::Upsampling2dNearest: return Upsampling2dNearestF32Comp; break;
		// the remaining shaders have no precision-specific builds
		case LayerShader::AddVectorToMatrixRows: return AddVectorToMatrixRowsComp; break;
		case LayerShader::BufferNorm: return BufferNormComp; break;
		case LayerShader::MultiplyMatrixByMatrix: return MultiplyMatrixByMatrixComp; break;
		case LayerShader::MultiplyMatrixByMatrixBorder: return MultiplyMatrixByMatrixBordersComp; break;
		case LayerShader::MultiplyMatrixByTransposedMatrix: return MultiplyMatrixByTransposedMatrixComp; break;
		case LayerShader::MultiplyMatrixByTransposedMatrixBorder: return MultiplyMatrixByTransposedMatrixBordersComp; break;
		case LayerShader::MultiplyTransposedMatrixByMatrix: return MultiplyTransposedMatrixByMatrixComp; break;
		case LayerShader::MultiplyTransposedMatrixByMatrixBorder: return MultiplyTransposedMatrixByMatrixBordersComp; break;
		case LayerShader::MatrixSoftmaxByRows: return MatrixSoftmaxByRowsComp; break;
		case LayerShader::VectorAddFloat1: return VectorAddFloat1Comp; break;
		case LayerShader::VectorAddFloat4: return VectorAddFloat4Comp; break;
		case LayerShader::VectorReLU: return VectorReLUComp; break;
		case LayerShader::VectorReLU4: return VectorReLU4Comp; break;
		case LayerShader::VectorReLUDiff: return VectorReLUDiffComp; break;
		case LayerShader::VectorLog: return VectorLogComp; break;
		case LayerShader::VectorDotProduct: return VectorDotProductComp; break;
		case LayerShader::VectorEltwiseMultiply: return VectorMultiplyFloatComp; break;
		case LayerShader::VectorMultiplyAndAdd: return VectorMultiplyAndAddComp; break;
		case LayerShader::VectorSub: return VectorSubComp; break;
		case LayerShader::SumMatrixColumns: return SumMatrixColumnsComp; break;
		case LayerShader::SumMatrixRows: return SumMatrixRowsComp; break;
		case LayerShader::MultiplyDiagMatrixByMatrix: return MultiplyDiagMatrixByMatrixComp; break;
		case LayerShader::StatNorm: return StatNormComp; break;
		case LayerShader::StatClassMap: return StatClassMapComp; break;
		case LayerShader::StatClassPercent: return StatClassPercentComp; break;
		case LayerShader::StatAnalysis: return StatAnalysisComp; break;
		}
		break;
	default:
		break;
	}

	// fallback for unspecified precision: only precision-agnostic shaders
	switch (sh) {
	case LayerShader::AddVectorToMatrixRows: return AddVectorToMatrixRowsComp; break;
	case LayerShader::BufferNorm: return BufferNormComp; break;
	case LayerShader::MultiplyMatrixByMatrix: return MultiplyMatrixByMatrixComp; break;
	case LayerShader::MultiplyMatrixByMatrixBorder: return MultiplyMatrixByMatrixBordersComp; break;
	case LayerShader::MultiplyMatrixByTransposedMatrix: return MultiplyMatrixByTransposedMatrixComp; break;
	case LayerShader::MultiplyMatrixByTransposedMatrixBorder: return MultiplyMatrixByTransposedMatrixBordersComp; break;
	case LayerShader::MultiplyTransposedMatrixByMatrix: return MultiplyTransposedMatrixByMatrixComp; break;
	case LayerShader::MultiplyTransposedMatrixByMatrixBorder: return MultiplyTransposedMatrixByMatrixBordersComp; break;
	case LayerShader::MatrixSoftmaxByRows: return MatrixSoftmaxByRowsComp; break;
	case LayerShader::VectorAddFloat1: return VectorAddFloat1Comp; break;
	case LayerShader::VectorAddFloat4: return VectorAddFloat4Comp; break;
	case LayerShader::VectorReLU: return VectorReLUComp; break;
	case LayerShader::VectorReLU4: return VectorReLU4Comp; break;
	case LayerShader::VectorReLUDiff: return VectorReLUDiffComp; break;
	case LayerShader::VectorLog: return VectorLogComp; break;
	case LayerShader::VectorDotProduct: return VectorDotProductComp; break;
	case LayerShader::VectorEltwiseMultiply: return VectorMultiplyFloatComp; break;
	case LayerShader::VectorMultiplyAndAdd: return VectorMultiplyAndAddComp; break;
	case LayerShader::VectorSub: return VectorSubComp; break;
	case LayerShader::SumMatrixColumns: return SumMatrixColumnsComp; break;
	case LayerShader::SumMatrixRows: return SumMatrixRowsComp; break;
	case LayerShader::MultiplyDiagMatrixByMatrix: return MultiplyDiagMatrixByMatrixComp; break;

	case LayerShader::StatNorm: return StatNormComp; break;
	case LayerShader::StatClassMap: return StatClassMapComp; break;
	case LayerShader::StatClassPercent: return StatClassPercentComp; break;
	case LayerShader::StatAnalysis: return StatAnalysisComp; break;
	default: break;
	}

	// no shader available for this (shader, precision) combination
	return SpanView<uint32_t>();
}
441 :
// Fills `buf` with repeated copies of the float value `val`.
//
// `size` is in bytes; only `size / sizeof(float)` whole floats are written,
// so any trailing bytes beyond the last whole float are left untouched.
// `buf` must be suitably aligned for float access (NOTE(review): the cast
// assumes float alignment — callers allocate aligned GPU staging memory,
// confirm for any new call site).
void FillFloatBuffer(uint8_t *buf, uint64_t size, float val) {
	auto target = reinterpret_cast<float *>(buf);
	std::fill_n(target, size / sizeof(float), val);
}
450 :
451 : }
452 :
453 : #include "XLSnnVkActivationLayer.cc"
454 : #include "XLSnnVkGenerationLayer.cc"
455 : #include "XLSnnVkInputLayer.cc"
456 : #include "XLSnnVkConvLayer.cc"
457 : #include "XLSnnVkSubpixelLayer.cc"
458 : #include "XLSnnVkMatrixMulLayer.cc"
459 : #include "XLSnnVkStatPercentLayer.cc"
460 : #include "XLSnnVkLossLayer.cc"
461 : #include "XLSnnVkTrainableLayer.cc"
462 :
463 : #include "XLSnnVkNmeMath.cc"
|