LCOV - code coverage report
Current view: top level - xenolith/utils/shadernn/src/backend/vk - XLSnnVkNmeMath.cc (source / functions) Hit Total Coverage
Test: coverage.info Lines: 126 144 87.5 %
Date: 2024-05-06 04:51:23 Functions: 22 25 88.0 %

          Line data    Source code
       1             : /**
       2             :  Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
       3             : 
       4             :  Permission is hereby granted, free of charge, to any person obtaining a copy
       5             :  of this software and associated documentation files (the "Software"), to deal
       6             :  in the Software without restriction, including without limitation the rights
       7             :  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
       8             :  copies of the Software, and to permit persons to whom the Software is
       9             :  furnished to do so, subject to the following conditions:
      10             : 
      11             :  The above copyright notice and this permission notice shall be included in
      12             :  all copies or substantial portions of the Software.
      13             : 
      14             :  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      15             :  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      16             :  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      17             :  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      18             :  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      19             :  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      20             :  THE SOFTWARE.
      21             :  **/
      22             : 
      23             : #include "XLSnnVkNmeMath.h"
      24             : 
      25             : namespace stappler::xenolith::vk::shadernn {
      26             : 
      27             : struct AddVectorToMatrixRowsData {
      28             :         int batchSize;
      29             :         int matrixHeight;
      30             :         int matrixWidth;
      31             : };
      32             : 
      33             : struct MultiplyMatrixByMatrixData {
      34             :         int batchSize;
      35             :         int firstHeight;
      36             :         int firstWidth;
      37             :         int firstRowSize;
      38             :         int secondWidth;
      39             :         int secondRowSize;
      40             :         int resultRowSize;
      41             :         int toAdd;
      42             : };
      43             : 
      44             : struct BatchMultiplyMatrixByMatrixBordersData {
      45             :         int batchSize;
      46             :         int firstHeight;
      47             :         int firstWidth;
      48             :         int firstRowSize;
      49             :         int secondWidth;
      50             :         int secondRowSize;
      51             :         int resultRowSize;
      52             :         int leftOffset;
      53             :         int topOffset;
      54             :         int toAdd;
      55             : };
      56             : 
      57             : struct MultiplyMatrixByTransposedMatrixData {
      58             :         int batchSize;
      59             :         int firstHeight;
      60             :         int firstWidth;
      61             :         int firstRowSize;
      62             :         int secondHeight;
      63             :         int secondRowSize;
      64             :         int resultRowSize;
      65             :         int toAdd;
      66             : };
      67             : 
      68             : struct MultiplyMatrixByTransposedMatrixBordersData {
      69             :         int batchSize;
      70             :         int firstHeight;
      71             :         int firstWidth;
      72             :         int firstRowSize;
      73             :         int secondHeight;
      74             :         int secondRowSize;
      75             :         int resultRowSize;
      76             :         int leftOffset;
      77             :         int topOffset;
      78             :         int toAdd;
      79             : };
      80             : 
      81             : struct BatchMultiplyTransposedMatrixByMatrixData {
      82             :         int batchSize;
      83             :         int firstHeight;
      84             :         int firstWidth;
      85             :         int firstRowSize;
      86             :         int secondWidth;
      87             :         int secondRowSize;
      88             :         int resultRowSize;
      89             :         int toAdd;
      90             : };
      91             : 
      92             : struct BatchMultiplyTransposedMatrixByMatrixBordersData {
      93             :         int batchSize;
      94             :         int firstHeight;
      95             :         int firstWidth;
      96             :         int firstRowSize;
      97             :         int secondWidth;
      98             :         int secondRowSize;
      99             :         int resultRowSize;
     100             :         int leftOffset;
     101             :         int topOffset;
     102             :         int toAdd;
     103             : };
     104             : 
     105             : struct MatrixSoftmaxByRowsData {
     106             :         int matrixHeight;
     107             :         int matrixWidth;
     108             : };
     109             : 
     110             : struct MultiplyDiagMatrixByMatrixData {
     111             :         int height;
     112             :         int width;
     113             : };
     114             : 
     115             : struct SumMatrixColumnsData {
     116             :         int width;
     117             :         int height;
     118             : };
     119             : 
     120             : struct SumMatrixRowsData {
     121             :         int width;
     122             :         int height;
     123             :         int batchSize;
     124             : 
     125             :         int toAdd;
     126             : };
     127             : 
     128             : struct VectorLogData {
     129             :         int neg;
     130             : };
     131             : 
     132             : struct VectorMultiplyFloatData {
     133             :         int isSecondValue;
     134             :         int isNeg;
     135             :         int toAdd;
     136             : };
     137             : 
     138             : struct VectorDotData {
     139             :         int targetOffset;
     140             :         int hasMult;
     141             :         int multOffset;
     142             : };
     143             : 
     144       88800 : inline int Ceil(int val, int discret) {
     145       88800 :         if (val > 0) {
     146       88800 :                 return (val + discret - 1) / discret;
     147             :         }
     148           0 :         return val / discret;
     149             : }
     150             : 
     151             : // The maximum number of groups over the X dimension when working with a 1D (vector) shader
     152             : // With larger sizes, the shader data will be represented in two dimensions
     153             : constexpr int VulkanMaxVectorXGroupCount = 8192;
     154             : 
     155             : // The number of combined operations
     156             : constexpr int VectorCombine = 4;
     157             : 
     158       31200 : static void runVectorShader(CommandBuffer &buf, ComputePipeline *pipeline, BytesView pcb, int count) {
     159       31200 :         int groupCountX = Ceil(count, pipeline->getLocalX());
     160       31200 :         int groupCountY = Ceil(groupCountX, VulkanMaxVectorXGroupCount);
     161       31200 :         groupCountX = std::min<int>(groupCountX, VulkanMaxVectorXGroupCount);
     162             : 
     163       31200 :         if (!pcb.empty()) {
     164       15600 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, pcb);
     165             :         }
     166             : 
     167       31200 :         buf.cmdBindPipeline(pipeline);
     168       31200 :         buf.cmdDispatch(groupCountX, groupCountY, 1);
     169       31200 : }
     170             : 
     171        3600 : static void BatchMultiplyMatrixByTransposedMatrix(
     172             :         CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     173             :         bool toAdd, int batchSize,
     174             :         int firstHeight, int firstWidth, int firstRowSize,
     175             :         int secondHeight, int secondRowSize,
     176             :         int resultRowSize, int resultBufferSize ) {
     177             : 
     178        3600 :         if( firstHeight >= 4 && secondHeight >= 4 ) {
     179             :                 MultiplyMatrixByTransposedMatrixData param = { batchSize, firstHeight, firstWidth, firstRowSize,
     180        3600 :                         secondHeight, secondRowSize, resultRowSize,  ( toAdd ) ? 1 : 0 };
     181             : 
     182        3600 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     183        3600 :                 buf.cmdDispatchPipeline(mul, firstHeight / 4, secondHeight / 4, batchSize);
     184             :         }
     185             : 
     186        3600 :         int leftOffset = secondHeight - secondHeight % 4;
     187        3600 :         int topOffset = firstHeight - firstHeight % 4;
     188        3600 :         int count = secondHeight * firstHeight - leftOffset * topOffset;
     189        3600 :         if ( count > 0 ) {
     190             :                 MultiplyMatrixByTransposedMatrixBordersData param = { batchSize, firstHeight, firstWidth, firstRowSize,
     191        1200 :                         secondHeight, secondRowSize, resultRowSize, leftOffset, topOffset, toAdd };
     192        1200 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     193        1200 :                 buf.cmdDispatchPipeline(borders, count, batchSize, 1);
     194             :         }
     195        3600 : }
     196             : 
     197        3600 : static void batchMultiplyTransposedMatrixByMatrix(
     198             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     199             :                 bool toAdd, int batchSize, int firstHeight, int firstWidth, int firstRowSize,
     200             :                 int secondWidth, int secondRowSize, int resultRowSize, int resultBufferSize ) {
     201        3600 :         if( firstWidth >= 4 && secondWidth >= 4 ) {
     202             :                 BatchMultiplyTransposedMatrixByMatrixData param = { batchSize, firstHeight, firstWidth, firstRowSize,
     203        3600 :                         secondWidth, secondRowSize, resultRowSize, toAdd ? 1 : 0 };
     204             : 
     205        3600 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     206        3600 :                 buf.cmdDispatchPipeline(mul, secondWidth / 4, firstWidth / 4, batchSize);
     207             :         }
     208             : 
     209        3600 :         int leftOffset = secondWidth - secondWidth % 4;
     210        3600 :         int topOffset = firstWidth - firstWidth % 4;
     211        3600 :         int count = secondWidth * firstWidth - leftOffset * topOffset;
     212        3600 :         if( count > 0 ) {
     213             :                 BatchMultiplyTransposedMatrixByMatrixBordersData param = { batchSize, firstHeight, firstWidth, firstRowSize,
     214        1200 :                         secondWidth, secondRowSize, resultRowSize, leftOffset, topOffset, toAdd ? 1 : 0 };
     215             : 
     216        1200 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     217        1200 :                 buf.cmdDispatchPipeline(borders, count, batchSize, 1);
     218             :         }
     219        3600 : }
     220             : 
     221        2400 : static void multiplyMatrixByMatrix(
     222             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     223             :                 bool toAdd, int batchSize, int firstHeight, int firstWidth, int firstRowSize, int secondWidth,
     224             :         int secondRowSize, int resultRowSize, int resultBufferSize ) {
     225             : 
     226        2400 :         if( firstHeight >= 4 && secondWidth >= 4 ) {
     227             :                 MultiplyMatrixByMatrixData param = { batchSize, firstHeight, firstWidth, firstRowSize,
     228        2400 :                         secondWidth, secondRowSize, resultRowSize, toAdd ? 1 : 0 };
     229             : 
     230        2400 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     231        2400 :                 buf.cmdDispatchPipeline(mul, secondWidth / 4, firstHeight / 4, batchSize);
     232             :         }
     233             : 
     234        2400 :         int leftOffset = secondWidth - secondWidth % 4;
     235        2400 :     int topOffset = firstHeight - firstHeight % 4;
     236        2400 :     int count = secondWidth * firstHeight - leftOffset * topOffset;
     237        2400 :     if( count > 0 ) {
     238             :                 BatchMultiplyMatrixByMatrixBordersData param = { batchSize, firstHeight, firstWidth, firstRowSize,
     239           0 :                         secondWidth, secondRowSize, resultRowSize, leftOffset, topOffset, toAdd ? 1 : 0 };
     240             : 
     241           0 :                 buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     242           0 :                 buf.cmdDispatchPipeline(borders, count, batchSize, 1);
     243             :         }
     244        2400 : }
     245             : 
     246        2400 : void MultiplyMatrixByMatrix(
     247             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     248             :                 int batchSize, int firstHeight, int firstWidth, int secondWidth, int resultBufferSize) {
     249        2400 :         multiplyMatrixByMatrix(buf, mul, borders, false, batchSize, firstHeight, firstWidth, firstWidth,
     250             :                         secondWidth, secondWidth, secondWidth, resultBufferSize );
     251        2400 : }
     252             : 
     253        3600 : void MultiplyMatrixByTransposedMatrix(
     254             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     255             :                 int firstHeight, int firstWidth, int firstRowSize,
     256             :                 int secondHeight, int secondRowSize,
     257             :                 int resultRowSize, int resultBufferSize ) {
     258             : 
     259        3600 :         BatchMultiplyMatrixByTransposedMatrix( buf, mul, borders, false, 1,
     260             :                 firstHeight, firstWidth, firstRowSize,
     261             :                 secondHeight, secondRowSize,
     262             :                 resultRowSize, resultBufferSize );
     263        3600 : }
     264             : 
     265           0 : void MultiplyMatrixByTransposedMatrix(
     266             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     267             :                 int batchSize, int firstHeight, int firstWidth,
     268             :                 int secondHeight, int resultBufferSize )
     269             : {
     270           0 :         BatchMultiplyMatrixByTransposedMatrix( buf, mul, borders, false, batchSize,
     271             :                 firstHeight, firstWidth, firstWidth,
     272             :                 secondHeight, firstWidth,
     273             :                 secondHeight, resultBufferSize );
     274           0 : }
     275             : 
     276           0 : void MultiplyTransposedMatrixByMatrixAndAdd(
     277             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     278             :                 int firstHeight, int firstWidth, int firstRowSize,
     279             :                 int secondWidth, int secondRowSize,
     280             :                 int resultRowSize, int resultBufferSize) {
     281           0 :         batchMultiplyTransposedMatrixByMatrix( buf, mul, borders, true, 1, firstHeight, firstWidth, firstRowSize,
     282             :                 secondWidth, secondRowSize, resultRowSize, resultBufferSize );
     283           0 : }
     284             : 
     285        3600 : void MultiplyTransposedMatrixByMatrix(
     286             :                 CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
     287             :                 int firstHeight, int firstWidth, int firstRowSize,
     288             :                 int secondWidth, int secondRowSize,
     289             :                 int resultRowSize, int resultBufferSize) {
     290        3600 :         batchMultiplyTransposedMatrixByMatrix( buf, mul, borders, false, 1, firstHeight, firstWidth, firstRowSize,
     291             :                 secondWidth, secondRowSize, resultRowSize, resultBufferSize );
     292        3600 : }
     293             : 
     294        3600 : void AddVectorToMatrixRows(CommandBuffer &buf, ComputePipeline *p, int batchSize,
     295             :         int matrixHeight, int matrixWidth)
     296             : {
     297        3600 :         AddVectorToMatrixRowsData param = { batchSize, matrixHeight, matrixWidth };
     298             : 
     299        3600 :         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     300        3600 :         buf.cmdDispatchPipeline(p, matrixWidth, Ceil( matrixHeight, 4 ), batchSize);
     301        3600 : }
     302             : 
     303        7200 : void VectorAdd(CommandBuffer &buf, ComputePipeline *add4, ComputePipeline *add1, int vectorSize) {
     304        7200 :         int countQuad = ( vectorSize / 16 ) * 4;
     305        7200 :         if( countQuad > 0 ) {
     306        6000 :                 runVectorShader(buf, add4, BytesView(), countQuad);
     307             :         }
     308             : 
     309        7200 :         int countSingle = vectorSize % 16;
     310        7200 :         if( countSingle > 0 ) {
     311        1200 :                 int offset = vectorSize - countSingle;
     312             : 
     313             :                 struct {
     314             :                         int offset;
     315        1200 :                 } param = { offset };
     316             : 
     317        1200 :                 runVectorShader(buf, add1, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), countSingle);
     318             :         }
     319        7200 : }
     320             : 
     321        2400 : void VectorReLU(CommandBuffer &buf, ComputePipeline *relu4, ComputePipeline *relu, int vectorSize, float threshold) {
     322        2400 :         int countQuad = ( vectorSize / 16 ) * 4;
     323        2400 :         if( countQuad > 0 ) {
     324             :                 struct {
     325             :                         float value;
     326        2400 :                 } param = { threshold };
     327             : 
     328        2400 :                 runVectorShader(buf, relu4, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), countQuad);
     329             :         }
     330             : 
     331        2400 :         int countSingle = vectorSize % 16;
     332        2400 :         if( countSingle > 0 ) {
     333           0 :                 int offset = vectorSize - countSingle;
     334             : 
     335             :                 struct {
     336             :                         float value;
     337             :                         int offset;
     338           0 :                 } param = { threshold, offset };
     339             : 
     340           0 :                 runVectorShader(buf, relu, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), countSingle);
     341             :         }
     342        2400 : }
     343             : 
     344        2400 : void VectorReLUDiff(CommandBuffer &buf, ComputePipeline *relu, int vectorSize, float threshold) {
     345             :         struct {
     346             :                 float value;
     347        2400 :         } param = { threshold };
     348             : 
     349        2400 :         runVectorShader(buf, relu, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), Ceil(vectorSize, VectorCombine));
     350        2400 : }
     351             : 
     352        1200 : void MatrixSoftmaxByRows(CommandBuffer &buf, ComputePipeline *p, int height, int width) {
     353        1200 :         MatrixSoftmaxByRowsData param = { height, width };
     354             : 
     355        1200 :         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0,
     356             :                         BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     357        1200 :         buf.cmdDispatchPipeline(p, width, height, 1);
     358        1200 : }
     359             : 
     360        1200 : void VectorNegLog(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
     361        1200 :         VectorLogData param = { 1 };
     362             : 
     363        1200 :         runVectorShader(buf, p, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), Ceil(vectorSize, VectorCombine));
     364        1200 : }
     365             : 
     366        1200 : void VectorEltwiseMultiply(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
     367        1200 :         VectorMultiplyFloatData param = { 0, 0, 0 };
     368             : 
     369        1200 :         runVectorShader(buf, p, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), Ceil(vectorSize, VectorCombine));
     370        1200 : }
     371             : 
     372        7200 : void VectorMultiply(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
     373        7200 :         VectorMultiplyFloatData param = { 1, 0, 0 };
     374             : 
     375        7200 :         runVectorShader(buf, p, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)), Ceil(vectorSize, VectorCombine));
     376        7200 : }
     377             : 
     378        7200 : void VectorMultiplyAndAdd(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
     379        7200 :         runVectorShader(buf, p, BytesView(), Ceil(vectorSize, VectorCombine));
     380        7200 : }
     381             : 
     382        1200 : void VectorSub(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
     383        1200 :         runVectorShader(buf, p, BytesView(), Ceil(vectorSize, VectorCombine));
     384        1200 : }
     385             : 
     386        2400 : void SumMatrixColumns(CommandBuffer &buf, ComputePipeline *p, int matrixHeight, int matrixWidth) {
     387        2400 :         SumMatrixColumnsData param = { matrixWidth, matrixHeight };
     388             : 
     389        2400 :         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     390        2400 :         buf.cmdDispatchPipeline(p, matrixHeight, 1, 1);
     391        2400 : }
     392             : 
     393           0 : void SumMatrixRowsAdd(CommandBuffer &buf, ComputePipeline *p, int batchSize, int matrixHeight, int matrixWidth ) {
     394           0 :         struct SumMatrixRowsData param = { matrixWidth, matrixHeight, batchSize, 1 };
     395           0 :         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     396           0 :         buf.cmdDispatchPipeline(p, matrixWidth, 1, batchSize);
     397           0 : }
     398             : 
     399        3600 : void SumMatrixRows(CommandBuffer &buf, ComputePipeline *p, int batchSize, int matrixHeight, int matrixWidth ) {
     400        3600 :         struct SumMatrixRowsData param = { matrixWidth, matrixHeight, batchSize, 0 };
     401        3600 :         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     402        3600 :         buf.cmdDispatchPipeline(p, matrixWidth, 1, batchSize);
     403        3600 : }
     404             : 
     405        2400 : void MultiplyDiagMatrixByMatrix(CommandBuffer &buf, ComputePipeline *p, int firstSize, int secondWidth, int resultBufferSize ) {
     406        2400 :         MultiplyDiagMatrixByMatrixData param = { firstSize, secondWidth };
     407             : 
     408        2400 :         buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(&param), sizeof(param)));
     409        2400 :         buf.cmdDispatchPipeline(p, Ceil( firstSize, 4 ), secondWidth, 1);
     410        2400 : }
     411             : 
     412        1200 : void VectorDotProduct(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
     413        1200 :         runVectorShader(buf, p, BytesView(),
     414        1200 :                         p->getLocalX() * p->getLocalY() * p->getLocalZ());
     415        1200 : }
     416             : 
     417             : }

Generated by: LCOV version 1.14