LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkShaders.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 34 152 22.4 %
Date: 2024-05-06 04:51:23 Functions: 2 3 66.7 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLCommon.h"
      24             : #include "XLSnnVkShaders.h"
      25             : 
      26             : namespace stappler::xenolith::vk::shadernn {
      27             : 
      28             : #include "gen_f32.comp.h"
      29             : #include "norm_f32.comp.h"
      30             : #include "vk_activation_f32.comp.h"
      31             : #include "vk_add_f32.comp.h"
      32             : #include "vk_avgpool2d_f32.comp.h"
      33             : #include "vk_batchnorm_f32.comp.h"
      34             : #include "vk_concat_f32.comp.h"
      35             : #include "vk_conv2d_1x1_f32.comp.h"
      36             : #include "vk_conv2d_f32.comp.h"
      37             : #include "vk_dense_f32.comp.h"
      38             : #include "vk_depthwise_f32.comp.h"
      39             : #include "vk_flatten_f32.comp.h"
      40             : #include "vk_instancenorm_f32.comp.h"
      41             : #include "vk_maxpool2d_f32.comp.h"
      42             : #include "vk_pad_f32.comp.h"
      43             : #include "vk_resize_f32.comp.h"
      44             : #include "vk_subpixel_f32.comp.h"
      45             : #include "vk_unary_f32.comp.h"
      46             : #include "vk_upsampling2d_bilinear_f32.comp.h"
      47             : #include "vk_upsampling2d_nearest_f32.comp.h"
      48             : 
      49             : #include "gen_f16.comp.h"
      50             : #include "norm_f16.comp.h"
      51             : #include "vk_activation_f16.comp.h"
      52             : #include "vk_add_f16.comp.h"
      53             : #include "vk_avgpool2d_f16.comp.h"
      54             : #include "vk_batchnorm_f16.comp.h"
      55             : #include "vk_concat_f16.comp.h"
      56             : #include "vk_conv2d_1x1_f16.comp.h"
      57             : #include "vk_conv2d_f16.comp.h"
      58             : #include "vk_dense_f16.comp.h"
      59             : #include "vk_depthwise_f16.comp.h"
      60             : #include "vk_flatten_f16.comp.h"
      61             : #include "vk_instancenorm_f16.comp.h"
      62             : #include "vk_maxpool2d_f16.comp.h"
      63             : #include "vk_pad_f16.comp.h"
      64             : #include "vk_resize_f16.comp.h"
      65             : #include "vk_subpixel_f16.comp.h"
      66             : #include "vk_unary_f16.comp.h"
      67             : #include "vk_upsampling2d_bilinear_f16.comp.h"
      68             : #include "vk_upsampling2d_nearest_f16.comp.h"
      69             : 
      70             : #include "AddVectorToMatrixRows.comp.h"
      71             : #include "BufferNorm.comp.h"
      72             : #include "MultiplyMatrixByMatrix.comp.h"
      73             : #include "MultiplyMatrixByMatrixBorders.comp.h"
      74             : #include "MultiplyMatrixByTransposedMatrix.comp.h"
      75             : #include "MultiplyMatrixByTransposedMatrixBorders.comp.h"
      76             : #include "MultiplyTransposedMatrixByMatrix.comp.h"
      77             : #include "MultiplyTransposedMatrixByMatrixBorders.comp.h"
      78             : #include "MatrixSoftmaxByRows.comp.h"
      79             : #include "VectorAddFloat1.comp.h"
      80             : #include "VectorAddFloat4.comp.h"
      81             : #include "VectorReLU.comp.h"
      82             : #include "VectorReLU4.comp.h"
      83             : #include "VectorReLUDiff.comp.h"
      84             : #include "VectorLog.comp.h"
      85             : #include "VectorDotProduct.comp.h"
      86             : #include "VectorMultiplyFloat.comp.h"
      87             : #include "VectorMultiplyAndAdd.comp.h"
      88             : #include "VectorSubFloat.comp.h"
      89             : #include "SumMatrixColumns.comp.h"
      90             : #include "SumMatrixRows.comp.h"
      91             : #include "MultiplyDiagMatrixByMatrix.comp.h"
      92             : 
      93             : #include "StatNorm.comp.h"
      94             : #include "StatClassMap.comp.h"
      95             : #include "StatClassPercent.comp.h"
      96             : #include "StatAnalysis.comp.h"
      97             : 
      98             : SpanView<uint32_t> GenF32Comp(reinterpret_cast<const uint32_t *>(gen_f32_comp), gen_f32_comp_len / sizeof(uint32_t));
      99             : SpanView<uint32_t> NormF32Comp(reinterpret_cast<const uint32_t *>(norm_f32_comp), norm_f32_comp_len / sizeof(uint32_t));
     100             : SpanView<uint32_t> ActivationF32Comp(reinterpret_cast<const uint32_t *>(vk_activation_f32_comp), vk_activation_f32_comp_len / sizeof(uint32_t));
     101             : SpanView<uint32_t> AddF32Comp(reinterpret_cast<const uint32_t *>(vk_add_f32_comp), vk_add_f32_comp_len / sizeof(uint32_t));
     102             : SpanView<uint32_t> Avgpool2dF32Comp(reinterpret_cast<const uint32_t *>(vk_avgpool2d_f32_comp), vk_avgpool2d_f32_comp_len / sizeof(uint32_t));
     103             : SpanView<uint32_t> BatchnormF32Comp(reinterpret_cast<const uint32_t *>(vk_batchnorm_f32_comp), vk_batchnorm_f32_comp_len / sizeof(uint32_t));
     104             : SpanView<uint32_t> ConcatF32Comp(reinterpret_cast<const uint32_t *>(vk_concat_f32_comp), vk_concat_f32_comp_len / sizeof(uint32_t));
     105             : SpanView<uint32_t> Conv2d1x1F32Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_1x1_f32_comp), vk_conv2d_1x1_f32_comp_len / sizeof(uint32_t));
     106             : SpanView<uint32_t> Conv2dF32Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_f32_comp), vk_conv2d_f32_comp_len / sizeof(uint32_t));
     107             : SpanView<uint32_t> DenseF32Comp(reinterpret_cast<const uint32_t *>(vk_dense_f32_comp), vk_dense_f32_comp_len / sizeof(uint32_t));
     108             : SpanView<uint32_t> DepthwiseF32Comp(reinterpret_cast<const uint32_t *>(vk_depthwise_f32_comp), vk_depthwise_f32_comp_len / sizeof(uint32_t));
     109             : SpanView<uint32_t> FlattenF32Comp(reinterpret_cast<const uint32_t *>(vk_flatten_f32_comp), vk_flatten_f32_comp_len / sizeof(uint32_t));
     110             : SpanView<uint32_t> InstancenormF32Comp(reinterpret_cast<const uint32_t *>(vk_instancenorm_f32_comp), vk_instancenorm_f32_comp_len / sizeof(uint32_t));
     111             : SpanView<uint32_t> Maxpool2dF32Comp(reinterpret_cast<const uint32_t *>(vk_maxpool2d_f32_comp), vk_maxpool2d_f32_comp_len / sizeof(uint32_t));
     112             : SpanView<uint32_t> PadF32Comp(reinterpret_cast<const uint32_t *>(vk_pad_f32_comp), vk_pad_f32_comp_len / sizeof(uint32_t));
     113             : SpanView<uint32_t> ResizeF32Comp(reinterpret_cast<const uint32_t *>(vk_resize_f32_comp), vk_resize_f32_comp_len / sizeof(uint32_t));
     114             : SpanView<uint32_t> SubpixelF32Comp(reinterpret_cast<const uint32_t *>(vk_subpixel_f32_comp), vk_subpixel_f32_comp_len / sizeof(uint32_t));
     115             : SpanView<uint32_t> UnaryF32Comp(reinterpret_cast<const uint32_t *>(vk_unary_f32_comp), vk_unary_f32_comp_len / sizeof(uint32_t));
     116             : SpanView<uint32_t> Upsampling2dBilinearF32Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_bilinear_f32_comp), vk_upsampling2d_bilinear_f32_comp_len / sizeof(uint32_t));
     117             : SpanView<uint32_t> Upsampling2dNearestF32Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_nearest_f32_comp), vk_upsampling2d_nearest_f32_comp_len / sizeof(uint32_t));
     118             : 
     119             : SpanView<uint32_t> GenF16Comp(reinterpret_cast<const uint32_t *>(gen_f16_comp), gen_f16_comp_len / sizeof(uint32_t));
     120             : SpanView<uint32_t> NormF16Comp(reinterpret_cast<const uint32_t *>(norm_f16_comp), norm_f16_comp_len / sizeof(uint32_t));
     121             : SpanView<uint32_t> ActivationF16Comp(reinterpret_cast<const uint32_t *>(vk_activation_f16_comp), vk_activation_f16_comp_len / sizeof(uint32_t));
     122             : SpanView<uint32_t> AddF16Comp(reinterpret_cast<const uint32_t *>(vk_add_f16_comp), vk_add_f16_comp_len / sizeof(uint32_t));
     123             : SpanView<uint32_t> Avgpool2dF16Comp(reinterpret_cast<const uint32_t *>(vk_avgpool2d_f16_comp), vk_avgpool2d_f16_comp_len / sizeof(uint32_t));
     124             : SpanView<uint32_t> BatchnormF16Comp(reinterpret_cast<const uint32_t *>(vk_batchnorm_f16_comp), vk_batchnorm_f16_comp_len / sizeof(uint32_t));
     125             : SpanView<uint32_t> ConcatF16Comp(reinterpret_cast<const uint32_t *>(vk_concat_f16_comp), vk_concat_f16_comp_len / sizeof(uint32_t));
     126             : SpanView<uint32_t> Conv2d1x1F16Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_1x1_f16_comp), vk_conv2d_1x1_f16_comp_len / sizeof(uint32_t));
     127             : SpanView<uint32_t> Conv2dF16Comp(reinterpret_cast<const uint32_t *>(vk_conv2d_f16_comp), vk_conv2d_f16_comp_len / sizeof(uint32_t));
     128             : SpanView<uint32_t> DenseF16Comp(reinterpret_cast<const uint32_t *>(vk_dense_f16_comp), vk_dense_f16_comp_len / sizeof(uint32_t));
     129             : SpanView<uint32_t> DepthwiseF16Comp(reinterpret_cast<const uint32_t *>(vk_depthwise_f16_comp), vk_depthwise_f16_comp_len / sizeof(uint32_t));
     130             : SpanView<uint32_t> FlattenF16Comp(reinterpret_cast<const uint32_t *>(vk_flatten_f16_comp), vk_flatten_f16_comp_len / sizeof(uint32_t));
     131             : SpanView<uint32_t> InstancenormF16Comp(reinterpret_cast<const uint32_t *>(vk_instancenorm_f16_comp), vk_instancenorm_f16_comp_len / sizeof(uint32_t));
     132             : SpanView<uint32_t> Maxpool2dF16Comp(reinterpret_cast<const uint32_t *>(vk_maxpool2d_f16_comp), vk_maxpool2d_f16_comp_len / sizeof(uint32_t));
     133             : SpanView<uint32_t> PadF16Comp(reinterpret_cast<const uint32_t *>(vk_pad_f16_comp), vk_pad_f16_comp_len / sizeof(uint32_t));
     134             : SpanView<uint32_t> ResizeF16Comp(reinterpret_cast<const uint32_t *>(vk_resize_f16_comp), vk_resize_f16_comp_len / sizeof(uint32_t));
     135             : SpanView<uint32_t> SubpixelF16Comp(reinterpret_cast<const uint32_t *>(vk_subpixel_f16_comp), vk_subpixel_f16_comp_len / sizeof(uint32_t));
     136             : SpanView<uint32_t> UnaryF16Comp(reinterpret_cast<const uint32_t *>(vk_unary_f16_comp), vk_unary_f16_comp_len / sizeof(uint32_t));
     137             : SpanView<uint32_t> Upsampling2dBilinearF16Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_bilinear_f16_comp), vk_upsampling2d_bilinear_f16_comp_len / sizeof(uint32_t));
     138             : SpanView<uint32_t> Upsampling2dNearestF16Comp(reinterpret_cast<const uint32_t *>(vk_upsampling2d_nearest_f16_comp), vk_upsampling2d_nearest_f16_comp_len / sizeof(uint32_t));
     139             : 
     140             : SpanView<uint32_t> AddVectorToMatrixRowsComp(reinterpret_cast<const uint32_t *>(AddVectorToMatrixRows_comp), AddVectorToMatrixRows_comp_len / sizeof(uint32_t));
     141             : SpanView<uint32_t> BufferNormComp(reinterpret_cast<const uint32_t *>(BufferNorm_comp), BufferNorm_comp_len / sizeof(uint32_t));
     142             : SpanView<uint32_t> MultiplyMatrixByMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByMatrix_comp), MultiplyMatrixByMatrix_comp_len / sizeof(uint32_t));
     143             : SpanView<uint32_t> MultiplyMatrixByMatrixBordersComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByMatrixBorders_comp), MultiplyMatrixByMatrixBorders_comp_len / sizeof(uint32_t));
     144             : SpanView<uint32_t> MultiplyMatrixByTransposedMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByTransposedMatrix_comp), MultiplyMatrixByTransposedMatrix_comp_len / sizeof(uint32_t));
     145             : SpanView<uint32_t> MultiplyMatrixByTransposedMatrixBordersComp(reinterpret_cast<const uint32_t *>(MultiplyMatrixByTransposedMatrixBorders_comp), MultiplyMatrixByTransposedMatrixBorders_comp_len / sizeof(uint32_t));
     146             : SpanView<uint32_t> MultiplyTransposedMatrixByMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyTransposedMatrixByMatrix_comp), MultiplyTransposedMatrixByMatrix_comp_len / sizeof(uint32_t));
     147             : SpanView<uint32_t> MultiplyTransposedMatrixByMatrixBordersComp(reinterpret_cast<const uint32_t *>(MultiplyTransposedMatrixByMatrixBorders_comp), MultiplyTransposedMatrixByMatrixBorders_comp_len / sizeof(uint32_t));
     148             : SpanView<uint32_t> MatrixSoftmaxByRowsComp(reinterpret_cast<const uint32_t *>(MatrixSoftmaxByRows_comp), MatrixSoftmaxByRows_comp_len / sizeof(uint32_t));
     149             : SpanView<uint32_t> VectorAddFloat1Comp(reinterpret_cast<const uint32_t *>(VectorAddFloat1_comp), VectorAddFloat1_comp_len / sizeof(uint32_t));
     150             : SpanView<uint32_t> VectorAddFloat4Comp(reinterpret_cast<const uint32_t *>(VectorAddFloat4_comp), VectorAddFloat4_comp_len / sizeof(uint32_t));
     151             : SpanView<uint32_t> VectorReLUComp(reinterpret_cast<const uint32_t *>(VectorReLU_comp), VectorReLU_comp_len / sizeof(uint32_t));
     152             : SpanView<uint32_t> VectorReLU4Comp(reinterpret_cast<const uint32_t *>(VectorReLU4_comp), VectorReLU4_comp_len / sizeof(uint32_t));
     153             : SpanView<uint32_t> VectorReLUDiffComp(reinterpret_cast<const uint32_t *>(VectorReLUDiff_comp), VectorReLUDiff_comp_len / sizeof(uint32_t));
     154             : SpanView<uint32_t> VectorLogComp(reinterpret_cast<const uint32_t *>(VectorLog_comp), VectorLog_comp_len / sizeof(uint32_t));
     155             : SpanView<uint32_t> VectorDotProductComp(reinterpret_cast<const uint32_t *>(VectorDotProduct_comp), VectorDotProduct_comp_len / sizeof(uint32_t));
     156             : SpanView<uint32_t> VectorMultiplyFloatComp(reinterpret_cast<const uint32_t *>(VectorMultiplyFloat_comp), VectorMultiplyFloat_comp_len / sizeof(uint32_t));
     157             : SpanView<uint32_t> VectorMultiplyAndAddComp(reinterpret_cast<const uint32_t *>(VectorMultiplyAndAdd_comp), VectorMultiplyAndAdd_comp_len / sizeof(uint32_t));
     158             : SpanView<uint32_t> VectorSubComp(reinterpret_cast<const uint32_t *>(VectorSubFloat_comp), VectorSubFloat_comp_len / sizeof(uint32_t));
     159             : SpanView<uint32_t> SumMatrixColumnsComp(reinterpret_cast<const uint32_t *>(SumMatrixColumns_comp), SumMatrixColumns_comp_len / sizeof(uint32_t));
     160             : SpanView<uint32_t> SumMatrixRowsComp(reinterpret_cast<const uint32_t *>(SumMatrixRows_comp), SumMatrixRows_comp_len / sizeof(uint32_t));
     161             : SpanView<uint32_t> MultiplyDiagMatrixByMatrixComp(reinterpret_cast<const uint32_t *>(MultiplyDiagMatrixByMatrix_comp), MultiplyDiagMatrixByMatrix_comp_len / sizeof(uint32_t));
     162             : 
     163             : SpanView<uint32_t> StatNormComp(reinterpret_cast<const uint32_t *>(StatNorm_comp), StatNorm_comp_len / sizeof(uint32_t));
     164             : SpanView<uint32_t> StatClassMapComp(reinterpret_cast<const uint32_t *>(StatClassMap_comp), StatClassMap_comp_len / sizeof(uint32_t));
     165             : SpanView<uint32_t> StatClassPercentComp(reinterpret_cast<const uint32_t *>(StatClassPercent_comp), StatClassPercent_comp_len / sizeof(uint32_t));
     166             : SpanView<uint32_t> StatAnalysisComp(reinterpret_cast<const uint32_t *>(StatAnalysis_comp), StatAnalysis_comp_len / sizeof(uint32_t));
     167             : 
     168           0 : Precision getAttachmentPrecision(const core::AttachmentData *data) {
     169           0 :         if (data->type == core::AttachmentType::Image) {
     170           0 :                 auto img = static_cast<core::ImageAttachment *>(data->attachment.get());
     171           0 :                 auto fmt = img->getImageInfo().format;
     172           0 :                 switch (fmt) {
     173           0 :                 case core::ImageFormat::R8_UNORM:
     174             :                 case core::ImageFormat::R8_SNORM:
     175             :                 case core::ImageFormat::R8_USCALED:
     176             :                 case core::ImageFormat::R8_SSCALED:
     177             :                 case core::ImageFormat::R8_UINT:
     178             :                 case core::ImageFormat::R8_SINT:
     179             :                 case core::ImageFormat::R8_SRGB:
     180             :                 case core::ImageFormat::R8G8_UNORM:
     181             :                 case core::ImageFormat::R8G8_SNORM:
     182             :                 case core::ImageFormat::R8G8_USCALED:
     183             :                 case core::ImageFormat::R8G8_SSCALED:
     184             :                 case core::ImageFormat::R8G8_UINT:
     185             :                 case core::ImageFormat::R8G8_SINT:
     186             :                 case core::ImageFormat::R8G8_SRGB:
     187             :                 case core::ImageFormat::R8G8B8_UNORM:
     188             :                 case core::ImageFormat::R8G8B8_SNORM:
     189             :                 case core::ImageFormat::R8G8B8_USCALED:
     190             :                 case core::ImageFormat::R8G8B8_SSCALED:
     191             :                 case core::ImageFormat::R8G8B8_UINT:
     192             :                 case core::ImageFormat::R8G8B8_SINT:
     193             :                 case core::ImageFormat::R8G8B8_SRGB:
     194             :                 case core::ImageFormat::B8G8R8_UNORM:
     195             :                 case core::ImageFormat::B8G8R8_SNORM:
     196             :                 case core::ImageFormat::B8G8R8_USCALED:
     197             :                 case core::ImageFormat::B8G8R8_SSCALED:
     198             :                 case core::ImageFormat::B8G8R8_UINT:
     199             :                 case core::ImageFormat::B8G8R8_SINT:
     200             :                 case core::ImageFormat::B8G8R8_SRGB:
     201             :                 case core::ImageFormat::R8G8B8A8_UNORM:
     202             :                 case core::ImageFormat::R8G8B8A8_SNORM:
     203             :                 case core::ImageFormat::R8G8B8A8_USCALED:
     204             :                 case core::ImageFormat::R8G8B8A8_SSCALED:
     205             :                 case core::ImageFormat::R8G8B8A8_UINT:
     206             :                 case core::ImageFormat::R8G8B8A8_SINT:
     207             :                 case core::ImageFormat::R8G8B8A8_SRGB:
     208             :                 case core::ImageFormat::B8G8R8A8_UNORM:
     209             :                 case core::ImageFormat::B8G8R8A8_SNORM:
     210             :                 case core::ImageFormat::B8G8R8A8_USCALED:
     211             :                 case core::ImageFormat::B8G8R8A8_SSCALED:
     212             :                 case core::ImageFormat::B8G8R8A8_UINT:
     213             :                 case core::ImageFormat::B8G8R8A8_SINT:
     214             :                 case core::ImageFormat::B8G8R8A8_SRGB:
     215             :                 case core::ImageFormat::A8B8G8R8_UNORM_PACK32:
     216             :                 case core::ImageFormat::A8B8G8R8_SNORM_PACK32:
     217             :                 case core::ImageFormat::A8B8G8R8_USCALED_PACK32:
     218             :                 case core::ImageFormat::A8B8G8R8_SSCALED_PACK32:
     219             :                 case core::ImageFormat::A8B8G8R8_UINT_PACK32:
     220             :                 case core::ImageFormat::A8B8G8R8_SINT_PACK32:
     221             :                 case core::ImageFormat::A8B8G8R8_SRGB_PACK32:
     222           0 :                         return Precision::F8;
     223             :                         break;
     224           0 :                 case core::ImageFormat::A2R10G10B10_UNORM_PACK32:
     225             :                 case core::ImageFormat::A2R10G10B10_SNORM_PACK32:
     226             :                 case core::ImageFormat::A2R10G10B10_USCALED_PACK32:
     227             :                 case core::ImageFormat::A2R10G10B10_SSCALED_PACK32:
     228             :                 case core::ImageFormat::A2R10G10B10_UINT_PACK32:
     229             :                 case core::ImageFormat::A2R10G10B10_SINT_PACK32:
     230             :                 case core::ImageFormat::A2B10G10R10_UNORM_PACK32:
     231             :                 case core::ImageFormat::A2B10G10R10_SNORM_PACK32:
     232             :                 case core::ImageFormat::A2B10G10R10_USCALED_PACK32:
     233             :                 case core::ImageFormat::A2B10G10R10_SSCALED_PACK32:
     234             :                 case core::ImageFormat::A2B10G10R10_UINT_PACK32:
     235             :                 case core::ImageFormat::A2B10G10R10_SINT_PACK32:
     236             :                 case core::ImageFormat::R16_UNORM:
     237             :                 case core::ImageFormat::R16_SNORM:
     238             :                 case core::ImageFormat::R16_USCALED:
     239             :                 case core::ImageFormat::R16_SSCALED:
     240             :                 case core::ImageFormat::R16_UINT:
     241             :                 case core::ImageFormat::R16_SINT:
     242             :                 case core::ImageFormat::R16_SFLOAT:
     243             :                 case core::ImageFormat::R16G16_UNORM:
     244             :                 case core::ImageFormat::R16G16_SNORM:
     245             :                 case core::ImageFormat::R16G16_USCALED:
     246             :                 case core::ImageFormat::R16G16_SSCALED:
     247             :                 case core::ImageFormat::R16G16_UINT:
     248             :                 case core::ImageFormat::R16G16_SINT:
     249             :                 case core::ImageFormat::R16G16_SFLOAT:
     250             :                 case core::ImageFormat::R16G16B16_UNORM:
     251             :                 case core::ImageFormat::R16G16B16_SNORM:
     252             :                 case core::ImageFormat::R16G16B16_USCALED:
     253             :                 case core::ImageFormat::R16G16B16_SSCALED:
     254             :                 case core::ImageFormat::R16G16B16_UINT:
     255             :                 case core::ImageFormat::R16G16B16_SINT:
     256             :                 case core::ImageFormat::R16G16B16_SFLOAT:
     257             :                 case core::ImageFormat::R16G16B16A16_UNORM:
     258             :                 case core::ImageFormat::R16G16B16A16_SNORM:
     259             :                 case core::ImageFormat::R16G16B16A16_USCALED:
     260             :                 case core::ImageFormat::R16G16B16A16_SSCALED:
     261             :                 case core::ImageFormat::R16G16B16A16_UINT:
     262             :                 case core::ImageFormat::R16G16B16A16_SINT:
     263             :                 case core::ImageFormat::R16G16B16A16_SFLOAT:
     264           0 :                         return Precision::F16;
     265             :                         break;
     266           0 :                 case core::ImageFormat::R32_UINT:
     267             :                 case core::ImageFormat::R32_SINT:
     268             :                 case core::ImageFormat::R32_SFLOAT:
     269             :                 case core::ImageFormat::R32G32_UINT:
     270             :                 case core::ImageFormat::R32G32_SINT:
     271             :                 case core::ImageFormat::R32G32_SFLOAT:
     272             :                 case core::ImageFormat::R32G32B32_UINT:
     273             :                 case core::ImageFormat::R32G32B32_SINT:
     274             :                 case core::ImageFormat::R32G32B32_SFLOAT:
     275             :                 case core::ImageFormat::R32G32B32A32_UINT:
     276             :                 case core::ImageFormat::R32G32B32A32_SINT:
     277             :                 case core::ImageFormat::R32G32B32A32_SFLOAT:
     278           0 :                         return Precision::F32;
     279             :                         break;
     280           0 :                 case core::ImageFormat::R64_UINT:
     281             :                 case core::ImageFormat::R64_SINT:
     282             :                 case core::ImageFormat::R64_SFLOAT:
     283             :                 case core::ImageFormat::R64G64_UINT:
     284             :                 case core::ImageFormat::R64G64_SINT:
     285             :                 case core::ImageFormat::R64G64_SFLOAT:
     286             :                 case core::ImageFormat::R64G64B64_UINT:
     287             :                 case core::ImageFormat::R64G64B64_SINT:
     288             :                 case core::ImageFormat::R64G64B64_SFLOAT:
     289             :                 case core::ImageFormat::R64G64B64A64_UINT:
     290             :                 case core::ImageFormat::R64G64B64A64_SINT:
     291             :                 case core::ImageFormat::R64G64B64A64_SFLOAT:
     292           0 :                         return Precision::F32;
     293             :                         break;
     294           0 :                 default:
     295           0 :                         return Precision::Unknown;
     296             :                         break;
     297             :                 }
     298             :         }
     299           0 :         return Precision::Unknown;
     300             : }
     301             : 
     302          65 : SpanView<uint32_t> getShader(LayerShader sh, Precision p) {
     303          65 :         switch (p) {
     304           0 :         case Precision::F16:
     305             :                 switch (sh) {
     306           0 :                 case LayerShader::Gen: return GenF16Comp; break;
     307           0 :                 case LayerShader::Norm: return NormF16Comp; break;
     308           0 :                 case LayerShader::Activation: return ActivationF16Comp; break;
     309           0 :                 case LayerShader::Add: return AddF16Comp; break;
     310           0 :                 case LayerShader::Avgpool2d: return Avgpool2dF16Comp; break;
     311           0 :                 case LayerShader::Batchnorm: return BatchnormF16Comp; break;
     312           0 :                 case LayerShader::Concat: return ConcatF16Comp; break;
     313           0 :                 case LayerShader::Conv2d1x1: return Conv2d1x1F16Comp; break;
     314           0 :                 case LayerShader::Conv2d: return Conv2dF16Comp; break;
     315           0 :                 case LayerShader::Dense: return DenseF16Comp; break;
     316           0 :                 case LayerShader::Depthwise: return DepthwiseF16Comp; break;
     317           0 :                 case LayerShader::Flatten: return FlattenF16Comp; break;
     318           0 :                 case LayerShader::Instancenorm: return InstancenormF16Comp; break;
     319           0 :                 case LayerShader::Maxpool2d: return Maxpool2dF16Comp; break;
     320           0 :                 case LayerShader::Pad: return PadF16Comp; break;
     321           0 :                 case LayerShader::Resize: return ResizeF16Comp; break;
     322           0 :                 case LayerShader::Subpixel: return SubpixelF16Comp; break;
     323           0 :                 case LayerShader::Unary: return UnaryF16Comp; break;
     324           0 :                 case LayerShader::Upsampling2dBilinear: return Upsampling2dBilinearF16Comp; break;
     325           0 :                 case LayerShader::Upsampling2dNearest: return Upsampling2dNearestF16Comp; break;
     326           0 :                 case LayerShader::AddVectorToMatrixRows: return AddVectorToMatrixRowsComp; break;
     327           0 :                 case LayerShader::BufferNorm: return BufferNormComp; break;
     328           0 :                 case LayerShader::MultiplyMatrixByMatrix: return MultiplyMatrixByMatrixComp; break;
     329           0 :                 case LayerShader::MultiplyMatrixByMatrixBorder: return MultiplyMatrixByMatrixBordersComp; break;
     330           0 :                 case LayerShader::MultiplyMatrixByTransposedMatrix: return MultiplyMatrixByTransposedMatrixComp; break;
     331           0 :                 case LayerShader::MultiplyMatrixByTransposedMatrixBorder: return MultiplyMatrixByTransposedMatrixBordersComp; break;
     332           0 :                 case LayerShader::MultiplyTransposedMatrixByMatrix: return MultiplyTransposedMatrixByMatrixComp; break;
     333           0 :                 case LayerShader::MultiplyTransposedMatrixByMatrixBorder: return MultiplyTransposedMatrixByMatrixBordersComp; break;
     334           0 :                 case LayerShader::MatrixSoftmaxByRows: return MatrixSoftmaxByRowsComp; break;
     335           0 :                 case LayerShader::VectorAddFloat1: return VectorAddFloat1Comp; break;
     336           0 :                 case LayerShader::VectorAddFloat4: return VectorAddFloat4Comp; break;
     337           0 :                 case LayerShader::VectorReLU: return VectorReLUComp; break;
     338           0 :                 case LayerShader::VectorReLU4: return VectorReLU4Comp; break;
     339           0 :                 case LayerShader::VectorReLUDiff: return VectorReLUDiffComp; break;
     340           0 :                 case LayerShader::VectorLog: return VectorLogComp; break;
     341           0 :                 case LayerShader::VectorDotProduct: return VectorDotProductComp; break;
     342           0 :                 case LayerShader::VectorEltwiseMultiply: return VectorMultiplyFloatComp; break;
     343           0 :                 case LayerShader::VectorMultiplyAndAdd: return VectorMultiplyAndAddComp; break;
     344           0 :                 case LayerShader::VectorSub: return VectorSubComp; break;
     345           0 :                 case LayerShader::SumMatrixColumns: return SumMatrixColumnsComp; break;
     346           0 :                 case LayerShader::SumMatrixRows: return SumMatrixRowsComp; break;
     347           0 :                 case LayerShader::MultiplyDiagMatrixByMatrix: return MultiplyDiagMatrixByMatrixComp; break;
     348           0 :                 case LayerShader::StatNorm: return StatNormComp; break;
     349           0 :                 case LayerShader::StatClassMap: return StatClassMapComp; break;
     350           0 :                 case LayerShader::StatClassPercent: return StatClassPercentComp; break;
     351           0 :                 case LayerShader::StatAnalysis: return StatAnalysisComp; break;
     352             :                 }
     353           0 :                 break;
     354           0 :         case Precision::F32:
     355             :                 switch (sh) {
     356           0 :                 case LayerShader::Gen: return GenF32Comp; break;
     357           0 :                 case LayerShader::Norm: return NormF32Comp; break;
     358           0 :                 case LayerShader::Activation: return ActivationF32Comp; break;
     359           0 :                 case LayerShader::Add: return AddF32Comp; break;
     360           0 :                 case LayerShader::Avgpool2d: return Avgpool2dF32Comp; break;
     361           0 :                 case LayerShader::Batchnorm: return BatchnormF32Comp; break;
     362           0 :                 case LayerShader::Concat: return ConcatF32Comp; break;
     363           0 :                 case LayerShader::Conv2d1x1: return Conv2d1x1F32Comp; break;
     364           0 :                 case LayerShader::Conv2d: return Conv2dF32Comp; break;
     365           0 :                 case LayerShader::Dense: return DenseF32Comp; break;
     366           0 :                 case LayerShader::Depthwise: return DepthwiseF32Comp; break;
     367           0 :                 case LayerShader::Flatten: return FlattenF32Comp; break;
     368           0 :                 case LayerShader::Instancenorm: return InstancenormF32Comp; break;
     369           0 :                 case LayerShader::Maxpool2d: return Maxpool2dF32Comp; break;
     370           0 :                 case LayerShader::Pad: return PadF32Comp; break;
     371           0 :                 case LayerShader::Resize: return ResizeF32Comp; break;
     372           0 :                 case LayerShader::Subpixel: return SubpixelF32Comp; break;
     373           0 :                 case LayerShader::Unary: return UnaryF32Comp; break;
     374           0 :                 case LayerShader::Upsampling2dBilinear: return Upsampling2dBilinearF32Comp; break;
     375           0 :                 case LayerShader::Upsampling2dNearest: return Upsampling2dNearestF32Comp; break;
     376           0 :                 case LayerShader::AddVectorToMatrixRows: return AddVectorToMatrixRowsComp; break;
     377           0 :                 case LayerShader::BufferNorm: return BufferNormComp; break;
     378           0 :                 case LayerShader::MultiplyMatrixByMatrix: return MultiplyMatrixByMatrixComp; break;
     379           0 :                 case LayerShader::MultiplyMatrixByMatrixBorder: return MultiplyMatrixByMatrixBordersComp; break;
     380           0 :                 case LayerShader::MultiplyMatrixByTransposedMatrix: return MultiplyMatrixByTransposedMatrixComp; break;
     381           0 :                 case LayerShader::MultiplyMatrixByTransposedMatrixBorder: return MultiplyMatrixByTransposedMatrixBordersComp; break;
     382           0 :                 case LayerShader::MultiplyTransposedMatrixByMatrix: return MultiplyTransposedMatrixByMatrixComp; break;
     383           0 :                 case LayerShader::MultiplyTransposedMatrixByMatrixBorder: return MultiplyTransposedMatrixByMatrixBordersComp; break;
     384           0 :                 case LayerShader::MatrixSoftmaxByRows: return MatrixSoftmaxByRowsComp; break;
     385           0 :                 case LayerShader::VectorAddFloat1: return VectorAddFloat1Comp; break;
     386           0 :                 case LayerShader::VectorAddFloat4: return VectorAddFloat4Comp; break;
     387           0 :                 case LayerShader::VectorReLU: return VectorReLUComp; break;
     388           0 :                 case LayerShader::VectorReLU4: return VectorReLU4Comp; break;
     389           0 :                 case LayerShader::VectorReLUDiff: return VectorReLUDiffComp; break;
     390           0 :                 case LayerShader::VectorLog: return VectorLogComp; break;
     391           0 :                 case LayerShader::VectorDotProduct: return VectorDotProductComp; break;
     392           0 :                 case LayerShader::VectorEltwiseMultiply: return VectorMultiplyFloatComp; break;
     393           0 :                 case LayerShader::VectorMultiplyAndAdd: return VectorMultiplyAndAddComp; break;
     394           0 :                 case LayerShader::VectorSub: return VectorSubComp; break;
     395           0 :                 case LayerShader::SumMatrixColumns: return SumMatrixColumnsComp; break;
     396           0 :                 case LayerShader::SumMatrixRows: return SumMatrixRowsComp; break;
     397           0 :                 case LayerShader::MultiplyDiagMatrixByMatrix: return MultiplyDiagMatrixByMatrixComp; break;
     398           0 :                 case LayerShader::StatNorm: return StatNormComp; break;
     399           0 :                 case LayerShader::StatClassMap: return StatClassMapComp; break;
     400           0 :                 case LayerShader::StatClassPercent: return StatClassPercentComp; break;
     401           0 :                 case LayerShader::StatAnalysis: return StatAnalysisComp; break;
     402             :                 }
     403           0 :                 break;
     404          65 :         default:
     405          65 :                 break;
     406             :         }
     407             : 
     408          65 :         switch (sh) {
     409           3 :         case LayerShader::AddVectorToMatrixRows: return AddVectorToMatrixRowsComp; break;
     410           2 :         case LayerShader::BufferNorm: return BufferNormComp; break;
     411           2 :         case LayerShader::MultiplyMatrixByMatrix: return MultiplyMatrixByMatrixComp; break;
     412           2 :         case LayerShader::MultiplyMatrixByMatrixBorder: return MultiplyMatrixByMatrixBordersComp; break;
     413           3 :         case LayerShader::MultiplyMatrixByTransposedMatrix: return MultiplyMatrixByTransposedMatrixComp; break;
     414           3 :         case LayerShader::MultiplyMatrixByTransposedMatrixBorder: return MultiplyMatrixByTransposedMatrixBordersComp; break;
     415           3 :         case LayerShader::MultiplyTransposedMatrixByMatrix: return MultiplyTransposedMatrixByMatrixComp; break;
     416           3 :         case LayerShader::MultiplyTransposedMatrixByMatrixBorder: return MultiplyTransposedMatrixByMatrixBordersComp; break;
     417           1 :         case LayerShader::MatrixSoftmaxByRows: return MatrixSoftmaxByRowsComp; break;
     418           6 :         case LayerShader::VectorAddFloat1: return VectorAddFloat1Comp; break;
     419           6 :         case LayerShader::VectorAddFloat4: return VectorAddFloat4Comp; break;
     420           3 :         case LayerShader::VectorReLU: return VectorReLUComp; break;
     421           3 :         case LayerShader::VectorReLU4: return VectorReLU4Comp; break;
     422           2 :         case LayerShader::VectorReLUDiff: return VectorReLUDiffComp; break;
     423           1 :         case LayerShader::VectorLog: return VectorLogComp; break;
     424           1 :         case LayerShader::VectorDotProduct: return VectorDotProductComp; break;
     425           7 :         case LayerShader::VectorEltwiseMultiply: return VectorMultiplyFloatComp; break;
     426           6 :         case LayerShader::VectorMultiplyAndAdd: return VectorMultiplyAndAddComp; break;
     427           1 :         case LayerShader::VectorSub: return VectorSubComp; break;
     428           2 :         case LayerShader::SumMatrixColumns: return SumMatrixColumnsComp; break;
     429           3 :         case LayerShader::SumMatrixRows: return SumMatrixRowsComp; break;
     430           2 :         case LayerShader::MultiplyDiagMatrixByMatrix: return MultiplyDiagMatrixByMatrixComp; break;
     431             : 
     432           0 :         case LayerShader::StatNorm: return StatNormComp; break;
     433           0 :         case LayerShader::StatClassMap: return StatClassMapComp; break;
     434           0 :         case LayerShader::StatClassPercent: return StatClassPercentComp; break;
     435           0 :         case LayerShader::StatAnalysis: return StatAnalysisComp; break;
     436           0 :         default: break;
     437             :         }
     438             : 
     439           0 :         return SpanView<uint32_t>();
     440             : }
     441             : 
     442           9 : void FillFloatBuffer(uint8_t *buf, uint64_t size, float val) {
     443           9 :         auto target = (float *)buf;
     444           9 :         size /= sizeof(float);
     445     1334979 :         while (size > 0) {
     446     1334970 :                 *(target ++) = val;
     447     1334970 :                 -- size;
     448             :         }
     449           9 : }
     450             : 
     451             : }
     452             : 
     453             : #include "XLSnnVkActivationLayer.cc"
     454             : #include "XLSnnVkGenerationLayer.cc"
     455             : #include "XLSnnVkInputLayer.cc"
     456             : #include "XLSnnVkConvLayer.cc"
     457             : #include "XLSnnVkSubpixelLayer.cc"
     458             : #include "XLSnnVkMatrixMulLayer.cc"
     459             : #include "XLSnnVkStatPercentLayer.cc"
     460             : #include "XLSnnVkLossLayer.cc"
     461             : #include "XLSnnVkTrainableLayer.cc"
     462             : 
     463             : #include "XLSnnVkNmeMath.cc"

Generated by: LCOV version 1.14