Line data Source code
1 : /**
2 : Copyright (c) 2023 Stappler LLC <admin@stappler.dev>
3 :
4 : Permission is hereby granted, free of charge, to any person obtaining a copy
5 : of this software and associated documentation files (the "Software"), to deal
6 : in the Software without restriction, including without limitation the rights
7 : to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 : copies of the Software, and to permit persons to whom the Software is
9 : furnished to do so, subject to the following conditions:
10 :
11 : The above copyright notice and this permission notice shall be included in
12 : all copies or substantial portions of the Software.
13 :
14 : THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 : IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 : FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 : AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 : LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 : OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 : THE SOFTWARE.
21 : **/
22 :
23 : #include "XLSnnVkNmeMath.h"
24 :
25 : namespace stappler::xenolith::vk::shadernn {
26 :
// Push-constant block filled by AddVectorToMatrixRows(). The struct bytes are
// handed to cmdPushConstants() as-is, so the field order and size matter;
// presumably they must mirror the shader-side declaration (confirm there).
struct AddVectorToMatrixRowsData {
	int batchSize;
	int matrixHeight;
	int matrixWidth;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(AddVectorToMatrixRowsData) == 3 * sizeof(int),
		"AddVectorToMatrixRowsData must be exactly three tightly packed ints");
32 :
// Push-constant block filled by multiplyMatrixByMatrix() for its `mul`
// pipeline. The bytes are pushed raw, so field order and size matter;
// presumably they must mirror the shader-side declaration (confirm there).
struct MultiplyMatrixByMatrixData {
	int batchSize;
	int firstHeight;
	int firstWidth;
	int firstRowSize;
	int secondWidth;
	int secondRowSize;
	int resultRowSize;
	int toAdd; // written as 1 or 0 from the caller's boolean flag
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(MultiplyMatrixByMatrixData) == 8 * sizeof(int),
		"MultiplyMatrixByMatrixData must be exactly eight tightly packed ints");
43 :
// Push-constant block filled by multiplyMatrixByMatrix() for its `borders`
// pipeline. The bytes are pushed raw, so field order and size matter;
// presumably they must mirror the shader-side declaration (confirm there).
struct BatchMultiplyMatrixByMatrixBordersData {
	int batchSize;
	int firstHeight;
	int firstWidth;
	int firstRowSize;
	int secondWidth;
	int secondRowSize;
	int resultRowSize;
	int leftOffset; // secondWidth rounded down to a multiple of 4
	int topOffset; // firstHeight rounded down to a multiple of 4
	int toAdd; // written as 1 or 0 from the caller's boolean flag
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(BatchMultiplyMatrixByMatrixBordersData) == 10 * sizeof(int),
		"BatchMultiplyMatrixByMatrixBordersData must be exactly ten tightly packed ints");
56 :
// Push-constant block filled by BatchMultiplyMatrixByTransposedMatrix() for
// its `mul` pipeline. The bytes are pushed raw, so field order and size
// matter; presumably they must mirror the shader-side declaration (confirm).
struct MultiplyMatrixByTransposedMatrixData {
	int batchSize;
	int firstHeight;
	int firstWidth;
	int firstRowSize;
	int secondHeight;
	int secondRowSize;
	int resultRowSize;
	int toAdd; // written as 1 or 0 from the caller's boolean flag
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(MultiplyMatrixByTransposedMatrixData) == 8 * sizeof(int),
		"MultiplyMatrixByTransposedMatrixData must be exactly eight tightly packed ints");
67 :
// Push-constant block filled by BatchMultiplyMatrixByTransposedMatrix() for
// its `borders` pipeline. The bytes are pushed raw, so field order and size
// matter; presumably they must mirror the shader-side declaration (confirm).
struct MultiplyMatrixByTransposedMatrixBordersData {
	int batchSize;
	int firstHeight;
	int firstWidth;
	int firstRowSize;
	int secondHeight;
	int secondRowSize;
	int resultRowSize;
	int leftOffset; // secondHeight rounded down to a multiple of 4
	int topOffset; // firstHeight rounded down to a multiple of 4
	int toAdd; // filled from the caller's boolean flag
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(MultiplyMatrixByTransposedMatrixBordersData) == 10 * sizeof(int),
		"MultiplyMatrixByTransposedMatrixBordersData must be exactly ten tightly packed ints");
80 :
// Push-constant block filled by batchMultiplyTransposedMatrixByMatrix() for
// its `mul` pipeline. The bytes are pushed raw, so field order and size
// matter; presumably they must mirror the shader-side declaration (confirm).
struct BatchMultiplyTransposedMatrixByMatrixData {
	int batchSize;
	int firstHeight;
	int firstWidth;
	int firstRowSize;
	int secondWidth;
	int secondRowSize;
	int resultRowSize;
	int toAdd; // written as 1 or 0 from the caller's boolean flag
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(BatchMultiplyTransposedMatrixByMatrixData) == 8 * sizeof(int),
		"BatchMultiplyTransposedMatrixByMatrixData must be exactly eight tightly packed ints");
91 :
// Push-constant block filled by batchMultiplyTransposedMatrixByMatrix() for
// its `borders` pipeline. The bytes are pushed raw, so field order and size
// matter; presumably they must mirror the shader-side declaration (confirm).
struct BatchMultiplyTransposedMatrixByMatrixBordersData {
	int batchSize;
	int firstHeight;
	int firstWidth;
	int firstRowSize;
	int secondWidth;
	int secondRowSize;
	int resultRowSize;
	int leftOffset; // secondWidth rounded down to a multiple of 4
	int topOffset; // firstWidth rounded down to a multiple of 4
	int toAdd; // written as 1 or 0 from the caller's boolean flag
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(BatchMultiplyTransposedMatrixByMatrixBordersData) == 10 * sizeof(int),
		"BatchMultiplyTransposedMatrixByMatrixBordersData must be exactly ten tightly packed ints");
104 :
// Push-constant block filled by MatrixSoftmaxByRows() as { height, width }.
// The bytes are pushed raw, so field order and size matter; presumably they
// must mirror the shader-side declaration (confirm there).
struct MatrixSoftmaxByRowsData {
	int matrixHeight;
	int matrixWidth;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(MatrixSoftmaxByRowsData) == 2 * sizeof(int),
		"MatrixSoftmaxByRowsData must be exactly two tightly packed ints");
109 :
// Push-constant block filled by MultiplyDiagMatrixByMatrix() as
// { firstSize, secondWidth }. The bytes are pushed raw, so field order and
// size matter; presumably they must mirror the shader-side declaration.
struct MultiplyDiagMatrixByMatrixData {
	int height;
	int width;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(MultiplyDiagMatrixByMatrixData) == 2 * sizeof(int),
		"MultiplyDiagMatrixByMatrixData must be exactly two tightly packed ints");
114 :
// Push-constant block filled by SumMatrixColumns() as
// { matrixWidth, matrixHeight } - note that width comes first. The bytes are
// pushed raw; presumably they must mirror the shader-side declaration.
struct SumMatrixColumnsData {
	int width;
	int height;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(SumMatrixColumnsData) == 2 * sizeof(int),
		"SumMatrixColumnsData must be exactly two tightly packed ints");
119 :
// Push-constant block shared by SumMatrixRows() and SumMatrixRowsAdd(),
// filled as { matrixWidth, matrixHeight, batchSize, toAdd }. The bytes are
// pushed raw; presumably they must mirror the shader-side declaration.
struct SumMatrixRowsData {
	int width;
	int height;
	int batchSize;

	int toAdd; // 1 from SumMatrixRowsAdd(), 0 from SumMatrixRows()
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(SumMatrixRowsData) == 4 * sizeof(int),
		"SumMatrixRowsData must be exactly four tightly packed ints");
127 :
// Push-constant block filled by VectorNegLog(), which sets `neg` to 1.
// The bytes are pushed raw; presumably they must mirror the shader-side
// declaration (confirm there).
struct VectorLogData {
	int neg;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(VectorLogData) == sizeof(int),
		"VectorLogData must be exactly one int");
131 :
// Push-constant block shared by VectorEltwiseMultiply() (filled as {0, 0, 0})
// and VectorMultiply() (filled as {1, 0, 0}). The bytes are pushed raw;
// presumably they must mirror the shader-side declaration (confirm there).
struct VectorMultiplyFloatData {
	int isSecondValue;
	int isNeg;
	int toAdd;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(VectorMultiplyFloatData) == 3 * sizeof(int),
		"VectorMultiplyFloatData must be exactly three tightly packed ints");
137 :
// Parameter block for the vector dot-product shader. It is not referenced by
// any function in the visible part of this file (VectorDotProduct() pushes no
// constants), so the field meanings should be confirmed against the shader.
struct VectorDotData {
	int targetOffset;
	int hasMult;
	int multOffset;
};

// Guard the raw-byte layout against accidental padding or reordering edits.
static_assert(sizeof(VectorDotData) == 3 * sizeof(int),
		"VectorDotData must be exactly three tightly packed ints");
143 :
// Integer division of `val` by `discret`, rounded towards positive infinity.
//
// Built from the quotient and remainder rather than (val + discret - 1), so
// the intermediate value cannot overflow for large `val`. For a positive
// divisor the results match the previous implementation; negative divisors
// are now rounded correctly as well.
//
// `discret` must be non-zero.
inline int Ceil(int val, int discret) {
	const int quotient = val / discret;
	const int remainder = val % discret;

	// C++ division truncates towards zero, which already is the ceiling when
	// the exact quotient is negative or whole. Only a positive, inexact
	// quotient (remainder has the divisor's sign) needs rounding up.
	if (remainder != 0 && ((remainder > 0) == (discret > 0))) {
		return quotient + 1;
	}
	return quotient;
}
150 :
// Upper bound on the X group count used when dispatching a 1D (vector) shader.
// When more groups are needed, runVectorShader() spreads them over the
// Y dimension as well.
constexpr int VulkanMaxVectorXGroupCount = 8 * 1024;

// Number of combined operations; the vector helpers dispatch
// Ceil(vectorSize, VectorCombine) work items (presumably one shader
// invocation covers this many elements - confirm against the shaders).
constexpr int VectorCombine = 4;
157 :
158 31200 : static void runVectorShader(CommandBuffer &buf, ComputePipeline *pipeline, BytesView pcb, int count) {
159 31200 : int groupCountX = Ceil(count, pipeline->getLocalX());
160 31200 : int groupCountY = Ceil(groupCountX, VulkanMaxVectorXGroupCount);
161 31200 : groupCountX = std::min<int>(groupCountX, VulkanMaxVectorXGroupCount);
162 :
163 31200 : if (!pcb.empty()) {
164 15600 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, pcb);
165 : }
166 :
167 31200 : buf.cmdBindPipeline(pipeline);
168 31200 : buf.cmdDispatch(groupCountX, groupCountY, 1);
169 31200 : }
170 :
171 3600 : static void BatchMultiplyMatrixByTransposedMatrix(
172 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
173 : bool toAdd, int batchSize,
174 : int firstHeight, int firstWidth, int firstRowSize,
175 : int secondHeight, int secondRowSize,
176 : int resultRowSize, int resultBufferSize ) {
177 :
178 3600 : if( firstHeight >= 4 && secondHeight >= 4 ) {
179 : MultiplyMatrixByTransposedMatrixData param = { batchSize, firstHeight, firstWidth, firstRowSize,
180 3600 : secondHeight, secondRowSize, resultRowSize, ( toAdd ) ? 1 : 0 };
181 :
182 3600 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
183 3600 : buf.cmdDispatchPipeline(mul, firstHeight / 4, secondHeight / 4, batchSize);
184 : }
185 :
186 3600 : int leftOffset = secondHeight - secondHeight % 4;
187 3600 : int topOffset = firstHeight - firstHeight % 4;
188 3600 : int count = secondHeight * firstHeight - leftOffset * topOffset;
189 3600 : if ( count > 0 ) {
190 : MultiplyMatrixByTransposedMatrixBordersData param = { batchSize, firstHeight, firstWidth, firstRowSize,
191 1200 : secondHeight, secondRowSize, resultRowSize, leftOffset, topOffset, toAdd };
192 1200 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
193 1200 : buf.cmdDispatchPipeline(borders, count, batchSize, 1);
194 : }
195 3600 : }
196 :
197 3600 : static void batchMultiplyTransposedMatrixByMatrix(
198 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
199 : bool toAdd, int batchSize, int firstHeight, int firstWidth, int firstRowSize,
200 : int secondWidth, int secondRowSize, int resultRowSize, int resultBufferSize ) {
201 3600 : if( firstWidth >= 4 && secondWidth >= 4 ) {
202 : BatchMultiplyTransposedMatrixByMatrixData param = { batchSize, firstHeight, firstWidth, firstRowSize,
203 3600 : secondWidth, secondRowSize, resultRowSize, toAdd ? 1 : 0 };
204 :
205 3600 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
206 3600 : buf.cmdDispatchPipeline(mul, secondWidth / 4, firstWidth / 4, batchSize);
207 : }
208 :
209 3600 : int leftOffset = secondWidth - secondWidth % 4;
210 3600 : int topOffset = firstWidth - firstWidth % 4;
211 3600 : int count = secondWidth * firstWidth - leftOffset * topOffset;
212 3600 : if( count > 0 ) {
213 : BatchMultiplyTransposedMatrixByMatrixBordersData param = { batchSize, firstHeight, firstWidth, firstRowSize,
214 1200 : secondWidth, secondRowSize, resultRowSize, leftOffset, topOffset, toAdd ? 1 : 0 };
215 :
216 1200 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
217 1200 : buf.cmdDispatchPipeline(borders, count, batchSize, 1);
218 : }
219 3600 : }
220 :
221 2400 : static void multiplyMatrixByMatrix(
222 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
223 : bool toAdd, int batchSize, int firstHeight, int firstWidth, int firstRowSize, int secondWidth,
224 : int secondRowSize, int resultRowSize, int resultBufferSize ) {
225 :
226 2400 : if( firstHeight >= 4 && secondWidth >= 4 ) {
227 : MultiplyMatrixByMatrixData param = { batchSize, firstHeight, firstWidth, firstRowSize,
228 2400 : secondWidth, secondRowSize, resultRowSize, toAdd ? 1 : 0 };
229 :
230 2400 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
231 2400 : buf.cmdDispatchPipeline(mul, secondWidth / 4, firstHeight / 4, batchSize);
232 : }
233 :
234 2400 : int leftOffset = secondWidth - secondWidth % 4;
235 2400 : int topOffset = firstHeight - firstHeight % 4;
236 2400 : int count = secondWidth * firstHeight - leftOffset * topOffset;
237 2400 : if( count > 0 ) {
238 : BatchMultiplyMatrixByMatrixBordersData param = { batchSize, firstHeight, firstWidth, firstRowSize,
239 0 : secondWidth, secondRowSize, resultRowSize, leftOffset, topOffset, toAdd ? 1 : 0 };
240 :
241 0 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
242 0 : buf.cmdDispatchPipeline(borders, count, batchSize, 1);
243 : }
244 2400 : }
245 :
246 2400 : void MultiplyMatrixByMatrix(
247 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
248 : int batchSize, int firstHeight, int firstWidth, int secondWidth, int resultBufferSize) {
249 2400 : multiplyMatrixByMatrix(buf, mul, borders, false, batchSize, firstHeight, firstWidth, firstWidth,
250 : secondWidth, secondWidth, secondWidth, resultBufferSize );
251 2400 : }
252 :
253 3600 : void MultiplyMatrixByTransposedMatrix(
254 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
255 : int firstHeight, int firstWidth, int firstRowSize,
256 : int secondHeight, int secondRowSize,
257 : int resultRowSize, int resultBufferSize ) {
258 :
259 3600 : BatchMultiplyMatrixByTransposedMatrix( buf, mul, borders, false, 1,
260 : firstHeight, firstWidth, firstRowSize,
261 : secondHeight, secondRowSize,
262 : resultRowSize, resultBufferSize );
263 3600 : }
264 :
265 0 : void MultiplyMatrixByTransposedMatrix(
266 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
267 : int batchSize, int firstHeight, int firstWidth,
268 : int secondHeight, int resultBufferSize )
269 : {
270 0 : BatchMultiplyMatrixByTransposedMatrix( buf, mul, borders, false, batchSize,
271 : firstHeight, firstWidth, firstWidth,
272 : secondHeight, firstWidth,
273 : secondHeight, resultBufferSize );
274 0 : }
275 :
276 0 : void MultiplyTransposedMatrixByMatrixAndAdd(
277 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
278 : int firstHeight, int firstWidth, int firstRowSize,
279 : int secondWidth, int secondRowSize,
280 : int resultRowSize, int resultBufferSize) {
281 0 : batchMultiplyTransposedMatrixByMatrix( buf, mul, borders, true, 1, firstHeight, firstWidth, firstRowSize,
282 : secondWidth, secondRowSize, resultRowSize, resultBufferSize );
283 0 : }
284 :
285 3600 : void MultiplyTransposedMatrixByMatrix(
286 : CommandBuffer &buf, ComputePipeline *mul, ComputePipeline *borders,
287 : int firstHeight, int firstWidth, int firstRowSize,
288 : int secondWidth, int secondRowSize,
289 : int resultRowSize, int resultBufferSize) {
290 3600 : batchMultiplyTransposedMatrixByMatrix( buf, mul, borders, false, 1, firstHeight, firstWidth, firstRowSize,
291 : secondWidth, secondRowSize, resultRowSize, resultBufferSize );
292 3600 : }
293 :
294 3600 : void AddVectorToMatrixRows(CommandBuffer &buf, ComputePipeline *p, int batchSize,
295 : int matrixHeight, int matrixWidth)
296 : {
297 3600 : AddVectorToMatrixRowsData param = { batchSize, matrixHeight, matrixWidth };
298 :
299 3600 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
300 3600 : buf.cmdDispatchPipeline(p, matrixWidth, Ceil( matrixHeight, 4 ), batchSize);
301 3600 : }
302 :
303 7200 : void VectorAdd(CommandBuffer &buf, ComputePipeline *add4, ComputePipeline *add1, int vectorSize) {
304 7200 : int countQuad = ( vectorSize / 16 ) * 4;
305 7200 : if( countQuad > 0 ) {
306 6000 : runVectorShader(buf, add4, BytesView(), countQuad);
307 : }
308 :
309 7200 : int countSingle = vectorSize % 16;
310 7200 : if( countSingle > 0 ) {
311 1200 : int offset = vectorSize - countSingle;
312 :
313 : struct {
314 : int offset;
315 1200 : } param = { offset };
316 :
317 1200 : runVectorShader(buf, add1, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), countSingle);
318 : }
319 7200 : }
320 :
321 2400 : void VectorReLU(CommandBuffer &buf, ComputePipeline *relu4, ComputePipeline *relu, int vectorSize, float threshold) {
322 2400 : int countQuad = ( vectorSize / 16 ) * 4;
323 2400 : if( countQuad > 0 ) {
324 : struct {
325 : float value;
326 2400 : } param = { threshold };
327 :
328 2400 : runVectorShader(buf, relu4, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), countQuad);
329 : }
330 :
331 2400 : int countSingle = vectorSize % 16;
332 2400 : if( countSingle > 0 ) {
333 0 : int offset = vectorSize - countSingle;
334 :
335 : struct {
336 : float value;
337 : int offset;
338 0 : } param = { threshold, offset };
339 :
340 0 : runVectorShader(buf, relu, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), countSingle);
341 : }
342 2400 : }
343 :
344 2400 : void VectorReLUDiff(CommandBuffer &buf, ComputePipeline *relu, int vectorSize, float threshold) {
345 : struct {
346 : float value;
347 2400 : } param = { threshold };
348 :
349 2400 : runVectorShader(buf, relu, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), Ceil(vectorSize, VectorCombine));
350 2400 : }
351 :
352 1200 : void MatrixSoftmaxByRows(CommandBuffer &buf, ComputePipeline *p, int height, int width) {
353 1200 : MatrixSoftmaxByRowsData param = { height, width };
354 :
355 1200 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0,
356 : BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
357 1200 : buf.cmdDispatchPipeline(p, width, height, 1);
358 1200 : }
359 :
360 1200 : void VectorNegLog(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
361 1200 : VectorLogData param = { 1 };
362 :
363 1200 : runVectorShader(buf, p, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), Ceil(vectorSize, VectorCombine));
364 1200 : }
365 :
366 1200 : void VectorEltwiseMultiply(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
367 1200 : VectorMultiplyFloatData param = { 0, 0, 0 };
368 :
369 1200 : runVectorShader(buf, p, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), Ceil(vectorSize, VectorCombine));
370 1200 : }
371 :
372 7200 : void VectorMultiply(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
373 7200 : VectorMultiplyFloatData param = { 1, 0, 0 };
374 :
375 7200 : runVectorShader(buf, p, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)), Ceil(vectorSize, VectorCombine));
376 7200 : }
377 :
378 7200 : void VectorMultiplyAndAdd(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
379 7200 : runVectorShader(buf, p, BytesView(), Ceil(vectorSize, VectorCombine));
380 7200 : }
381 :
382 1200 : void VectorSub(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
383 1200 : runVectorShader(buf, p, BytesView(), Ceil(vectorSize, VectorCombine));
384 1200 : }
385 :
386 2400 : void SumMatrixColumns(CommandBuffer &buf, ComputePipeline *p, int matrixHeight, int matrixWidth) {
387 2400 : SumMatrixColumnsData param = { matrixWidth, matrixHeight };
388 :
389 2400 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
390 2400 : buf.cmdDispatchPipeline(p, matrixHeight, 1, 1);
391 2400 : }
392 :
393 0 : void SumMatrixRowsAdd(CommandBuffer &buf, ComputePipeline *p, int batchSize, int matrixHeight, int matrixWidth ) {
394 0 : struct SumMatrixRowsData param = { matrixWidth, matrixHeight, batchSize, 1 };
395 0 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
396 0 : buf.cmdDispatchPipeline(p, matrixWidth, 1, batchSize);
397 0 : }
398 :
399 3600 : void SumMatrixRows(CommandBuffer &buf, ComputePipeline *p, int batchSize, int matrixHeight, int matrixWidth ) {
400 3600 : struct SumMatrixRowsData param = { matrixWidth, matrixHeight, batchSize, 0 };
401 3600 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
402 3600 : buf.cmdDispatchPipeline(p, matrixWidth, 1, batchSize);
403 3600 : }
404 :
405 2400 : void MultiplyDiagMatrixByMatrix(CommandBuffer &buf, ComputePipeline *p, int firstSize, int secondWidth, int resultBufferSize ) {
406 2400 : MultiplyDiagMatrixByMatrixData param = { firstSize, secondWidth };
407 :
408 2400 : buf.cmdPushConstants(VK_SHADER_STAGE_COMPUTE_BIT, 0, BytesView(reinterpret_cast<uint8_t *>(¶m), sizeof(param)));
409 2400 : buf.cmdDispatchPipeline(p, Ceil( firstSize, 4 ), secondWidth, 1);
410 2400 : }
411 :
412 1200 : void VectorDotProduct(CommandBuffer &buf, ComputePipeline *p, int vectorSize) {
413 1200 : runVectorShader(buf, p, BytesView(),
414 1200 : p->getLocalX() * p->getLocalY() * p->getLocalZ());
415 1200 : }
416 :
417 : }
|