matrix_format_converter_impl_block.inl
#include <muda/cub/device/device_merge_sort.h>
#include <muda/cub/device/device_radix_sort.h>
#include <muda/cub/device/device_run_length_encode.h>
#include <muda/cub/device/device_scan.h>
#include <muda/cub/device/device_segmented_reduce.h>
#include <muda/cub/device/device_reduce.h>
#include <muda/launch.h>

// equality for int2 keys, required by DeviceRunLengthEncode::Encode
MUDA_GENERIC constexpr bool operator==(const int2& a, const int2& b)
{
    return a.x == b.x && a.y == b.y;
}

namespace muda::details
{
//using T = float;
//constexpr int N = 3;

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceTripletMatrix<T, N>& from,
                                          DeviceBCOOMatrix<T, N>&          to)
{
    to.reshape(from.block_rows(), from.block_cols());
    to.resize_triplets(from.triplet_count());

    if(to.triplet_count() == 0)
        return;
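
    // Presumably the N <= 3 cutoff exists because the radix path packs (row, col)
    // into one 64-bit key and folds duplicates with a single ReduceByKey over the
    // block values, which is cheap for small N*N payloads, while larger blocks
    // are better served by merge-sorting indices and a segmented reduce.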
    if constexpr(N <= 3)
    {
        radix_sort_indices_and_blocks(from, to);
        make_unique_indices_and_blocks(from, to);
    }
    else
    {
        merge_sort_indices_and_blocks(from, to);
        make_unique_indices(from, to);
        make_unique_blocks(from, to);
    }
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::radix_sort_indices_and_blocks(
    const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
{
    auto src_row_indices = from.block_row_indices();
    auto src_col_indices = from.block_col_indices();
    auto src_blocks      = from.block_values();

    loose_resize(ij_hash_input, src_row_indices.size());
    loose_resize(sort_index_input, src_row_indices.size());

    loose_resize(ij_hash, src_row_indices.size());
    loose_resize(sort_index, src_row_indices.size());
    ij_pairs.resize(src_row_indices.size());

    // hash ij
    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(src_row_indices.size(),
               [row_indices = src_row_indices.cviewer().name("row_indices"),
                col_indices = src_col_indices.cviewer().name("col_indices"),
                ij_hash     = ij_hash_input.viewer().name("ij_hash"),
                sort_index = sort_index_input.viewer().name("sort_index")] __device__(int i) mutable
               {
                   ij_hash(i) =
                       (uint64_t{row_indices(i)} << 32) + uint64_t{col_indices(i)};
                   sort_index(i) = i;
               });
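
    // e.g. (row = 3, col = 7) packs to key 0x0000'0003'0000'0007, so sorting the
    // 64-bit keys orders triplets by row first, then by column.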

    DeviceRadixSort().SortPairs(ij_hash_input.data(),
                                ij_hash.data(),
                                sort_index_input.data(),
                                sort_index.data(),
                                ij_hash.size());

    // set ij_hash back to row_indices and col_indices
    auto dst_row_indices = to.block_row_indices();
    auto dst_col_indices = to.block_col_indices();

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(dst_row_indices.size(),
               [ij_hash  = ij_hash.viewer().name("ij_hash"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   auto hash      = ij_hash(i);
                   auto row_index = static_cast<int>(hash >> 32);
                   auto col_index = static_cast<int>(hash & 0xFFFFFFFF);
                   ij_pairs(i).x  = row_index;
                   ij_pairs(i).y  = col_index;
               });

    // sort the block values
    {
        loose_resize(blocks_sorted, from.block_values().size());
        ParallelFor(256)
            .kernel_name(__FUNCTION__)
            .apply(src_blocks.size(),
                   [src_blocks = src_blocks.cviewer().name("blocks"),
                    sort_index = sort_index.cviewer().name("sort_index"),
                    dst_blocks = blocks_sorted.viewer().name("block_values")] __device__(int i) mutable
                   { dst_blocks(i) = src_blocks(sort_index(i)); });
    }
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::make_unique_indices_and_blocks(
    const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
{
    // alias to reuse the memory
    auto& unique_ij_hash = ij_hash_input;

    muda::DeviceReduce().ReduceByKey(
        ij_hash.data(),
        unique_ij_hash.data(),
        blocks_sorted.data(),
        to.block_values().data(),
        count.data(),
        [] CUB_RUNTIME_FUNCTION(const BlockMatrix& l, const BlockMatrix& r) -> BlockMatrix
        { return l + r; },
        ij_hash.size());
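
    // ReduceByKey folds every run of equal keys into one entry: duplicate
    // (row, col) triplets are summed into a single block, and `count` receives
    // the number of unique keys.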

    int h_count = count;

    to.resize_triplets(h_count);

    // set ij_hash back to row_indices and col_indices
    ParallelFor()
        .kernel_name("set col row indices")
        .apply(to.block_row_indices().size(),
               [ij_hash     = unique_ij_hash.viewer().name("ij_hash"),
                row_indices = to.block_row_indices().viewer().name("row_indices"),
                col_indices = to.block_col_indices().viewer().name("col_indices")] __device__(int i) mutable
               {
                   auto hash      = ij_hash(i);
                   auto row_index = static_cast<int>(hash >> 32);
                   auto col_index = static_cast<int>(hash & 0xFFFFFFFF);
                   row_indices(i) = row_index;
                   col_indices(i) = col_index;
               });
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::merge_sort_indices_and_blocks(
    const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
{
    using namespace muda;

    auto src_row_indices = from.block_row_indices();
    auto src_col_indices = from.block_col_indices();
    auto src_blocks      = from.block_values();

    loose_resize(sort_index, src_row_indices.size());
    loose_resize(ij_pairs, src_row_indices.size());

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(src_row_indices.size(),
               [row_indices = src_row_indices.cviewer().name("row_indices"),
                col_indices = src_col_indices.cviewer().name("col_indices"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   ij_pairs(i).x = row_indices(i);
                   ij_pairs(i).y = col_indices(i);
               });

    ParallelFor(256)
        .kernel_name(__FUNCTION__)  //
        .apply(src_row_indices.size(),
               [sort_index = sort_index.viewer().name("sort_index")] __device__(int i) mutable
               { sort_index(i) = i; });

    DeviceMergeSort().SortPairs(ij_pairs.data(),
                                sort_index.data(),
                                ij_pairs.size(),
                                [] __device__(const int2& a, const int2& b) {
                                    return a.x < b.x || (a.x == b.x && a.y < b.y);
                                });
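
    // the comparator orders pairs lexicographically (by row, then by column),
    // matching the row-major order produced by the radix-sort path.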

    // set ij_pairs back to row_indices and col_indices
    auto dst_row_indices = to.block_row_indices();
    auto dst_col_indices = to.block_col_indices();

    ParallelFor(256)
        .kernel_name("set col row indices")
        .apply(dst_row_indices.size(),
               [row_indices = dst_row_indices.viewer().name("row_indices"),
                col_indices = dst_col_indices.viewer().name("col_indices"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   row_indices(i) = ij_pairs(i).x;
                   col_indices(i) = ij_pairs(i).y;
               });

    // sort the block values
    loose_resize(unique_blocks, from.m_block_values.size());

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(src_blocks.size(),
               [src_blocks = src_blocks.cviewer().name("blocks"),
                sort_index = sort_index.cviewer().name("sort_index"),
                dst_blocks = unique_blocks.viewer().name("block_values")] __device__(int i) mutable
               { dst_blocks(i) = src_blocks(sort_index(i)); });
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::make_unique_indices(
    const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
{
    using namespace muda;

    auto& row_indices = to.m_block_row_indices;
    auto& col_indices = to.m_block_col_indices;

    loose_resize(unique_ij_pairs, ij_pairs.size());
    loose_resize(unique_counts, ij_pairs.size());

    DeviceRunLengthEncode().Encode(ij_pairs.data(),
                                   unique_ij_pairs.data(),
                                   unique_counts.data(),
                                   count.data(),
                                   ij_pairs.size());

    int h_count = count;

    unique_ij_pairs.resize(h_count);
    unique_counts.resize(h_count);

    offsets.resize(unique_counts.size() + 1);  // +1 for the last offset_end

    DeviceScan().ExclusiveSum(
        unique_counts.data(), offsets.data(), unique_counts.size());
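
    // e.g. unique_counts = [3, 1, 2] scans to offsets = [0, 3, 4]; the trailing
    // offset_end (6) is appended later in make_unique_blocks.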

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(unique_counts.size(),
               [unique_ij_pairs = unique_ij_pairs.viewer().name("unique_ij_pairs"),
                row_indices = row_indices.viewer().name("row_indices"),
                col_indices = col_indices.viewer().name("col_indices")] __device__(int i) mutable
               {
                   row_indices(i) = unique_ij_pairs(i).x;
                   col_indices(i) = unique_ij_pairs(i).y;
               });

    row_indices.resize(h_count);
    col_indices.resize(h_count);
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::make_unique_blocks(
    const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
{
    using namespace muda;

    auto& row_indices = to.m_block_row_indices;
    auto& blocks      = to.m_block_values;
    blocks.resize(row_indices.size());

    // append the last offset_end, completing the segment boundaries
    Launch()
        .kernel_name(__FUNCTION__)
        .apply([offsets = offsets.viewer().name("offset"),
                counts  = unique_counts.cviewer().name("counts"),
                last    = unique_counts.size() - 1] __device__() mutable
               { offsets(last + 1) = offsets(last) + counts(last); });

    // then we do a segmented reduce to get the unique blocks
    DeviceSegmentedReduce().Reduce(
        unique_blocks.data(),
        blocks.data(),
        blocks.size(),
        offsets.data(),
        offsets.data() + 1,
        [] __host__ __device__(const BlockMatrix& a, const BlockMatrix& b) -> BlockMatrix
        { return a + b; },
        BlockMatrix::Zero().eval());
}
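
// Each segment [offsets(i), offsets(i + 1)) covers one run of duplicate blocks,
// so the segmented reduce sums every run into a single unique block value.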

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceBCOOMatrix<T, N>& from,
                                          DeviceDenseMatrix<T>&         to,
                                          bool clear_dense_matrix)
{
    using namespace muda;
    auto size = N * from.block_rows();
    to.reshape(size, size);

    if(clear_dense_matrix)
        to.fill(0);

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(from.block_values().size(),
               [blocks = from.cviewer().name("src_sparse_matrix"),
                dst = to.viewer().name("dst_dense_matrix")] __device__(int i) mutable
               {
                   auto block = blocks(i);
                   auto row   = block.block_row_index * N;
                   auto col   = block.block_col_index * N;
                   dst.block<N, N>(row, col).as_eigen() += block.block_value;
               });
}
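
// After conversion a BCOO matrix has unique (row, col) pairs, so each dense
// block is touched by exactly one thread and the unguarded += is race-free.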

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceBCOOMatrix<T, N>& from,
                                          DeviceBSRMatrix<T, N>&        to)
{
    calculate_block_offsets(from, to);

    to.m_block_col_indices = from.m_block_col_indices;
    to.m_block_values      = from.m_block_values;
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(DeviceBCOOMatrix<T, N>&& from,
                                          DeviceBSRMatrix<T, N>&   to)
{
    calculate_block_offsets(from, to);
    to.m_block_col_indices = std::move(from.m_block_col_indices);
    to.m_block_values      = std::move(from.m_block_values);
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::calculate_block_offsets(
    const DeviceBCOOMatrix<T, N>& from, DeviceBSRMatrix<T, N>& to)
{
    using namespace muda;
    to.reshape(from.block_rows(), from.block_cols());

    auto& dst_row_offsets = to.m_block_row_offsets;

    // alias the offsets to the col_counts_per_row (reuse)
    auto& col_counts_per_row = offsets;
    col_counts_per_row.resize(to.m_block_row_offsets.size());
    col_counts_per_row.fill(0);

    unique_indices.resize(from.non_zero_blocks());
    unique_counts.resize(from.non_zero_blocks());

    // run length encode the row
    DeviceRunLengthEncode().Encode(from.m_block_row_indices.data(),
                                   unique_indices.data(),
                                   unique_counts.data(),
                                   count.data(),
                                   from.non_zero_blocks());
    int h_count = count;

    unique_indices.resize(h_count);
    unique_counts.resize(h_count);

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(unique_counts.size(),
               [unique_indices = unique_indices.cviewer().name("offset"),
                counts = unique_counts.viewer().name("counts"),
                col_counts_per_row = col_counts_per_row.viewer().name(
                    "col_counts_per_row")] __device__(int i) mutable
               {
                   auto row = unique_indices(i);
                   col_counts_per_row(row) = counts(i);
               });

    // calculate the offsets
    DeviceScan().ExclusiveSum(col_counts_per_row.data(),
                              dst_row_offsets.data(),
                              col_counts_per_row.size());
}
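
// e.g. sorted row indices [0, 0, 2] encode to runs (row 0: 2, row 2: 1);
// scattering the counts gives per-row totals [2, 0, 1, 0], and the exclusive
// scan turns them into the BSR row offsets [0, 2, 2, 3].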

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceBCOOVector<T, N>& from,
                                          DeviceDenseVector<T>&         to,
                                          bool clear_dense_vector)
{
    to.resize(N * from.segment_count());
    set_unique_segments_to_dense_vector(from, to, clear_dense_vector);
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceDoubletVector<T, N>& from,
                                          DeviceBCOOVector<T, N>&          to)
{
    to.reshape(from.segment_count());
    to.resize_doublets(from.doublet_count());
    merge_sort_indices_and_segments(from, to);
    make_unique_indices(from, to);
    make_unique_segments(from, to);
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::merge_sort_indices_and_segments(
    const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to)
{
    using namespace muda;

    auto& indices = sort_index;  // alias sort_index to index

    // copy as temp
    indices       = from.m_segment_indices;
    temp_segments = from.m_segment_values;

    DeviceMergeSort().SortPairs(indices.data(),
                                temp_segments.data(),
                                indices.size(),
                                [] __device__(const int& a, const int& b)
                                { return a < b; });
}
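
// Here the segment values ride along as the merge-sort payload, so no separate
// gather pass is needed (unlike the block-matrix path, which sorts a
// permutation index and gathers the large blocks afterwards).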

template <typename T, int N>
void MatrixFormatConverter<T, N>::make_unique_indices(
    const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to)
{
    using namespace muda;

    auto& indices        = sort_index;  // alias sort_index to index
    auto& unique_indices = to.m_segment_indices;

    loose_resize(unique_indices, indices.size());
    loose_resize(unique_counts, indices.size());

    DeviceRunLengthEncode().Encode(indices.data(),
                                   unique_indices.data(),
                                   unique_counts.data(),
                                   count.data(),
                                   indices.size());

    int h_count = count;

    unique_indices.resize(h_count);
    unique_counts.resize(h_count);

    loose_resize(offsets, unique_counts.size() + 1);

    DeviceScan().ExclusiveSum(
        unique_counts.data(), offsets.data(), unique_counts.size());

    // append the last offset_end
    Launch()
        .kernel_name(__FUNCTION__)
        .apply([offset = offsets.viewer().name("offset"),
                count  = unique_counts.cviewer().name("counts"),
                last   = unique_counts.size() - 1] __device__() mutable
               { offset(last + 1) = offset(last) + count(last); });
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::make_unique_segments(
    const DeviceDoubletVector<T, N>& from, DeviceBCOOVector<T, N>& to)
{
    using namespace muda;

    auto& begin_offset = offsets;

    auto& unique_indices  = to.m_segment_indices;
    auto& unique_segments = to.m_segment_values;

    unique_segments.resize(unique_indices.size());

    DeviceSegmentedReduce().Reduce(
        temp_segments.data(),
        unique_segments.data(),
        unique_segments.size(),
        begin_offset.data(),
        begin_offset.data() + 1,
        [] __host__ __device__(const SegmentVector& a, const SegmentVector& b) -> SegmentVector
        { return a + b; },
        SegmentVector::Zero().eval());
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::set_unique_segments_to_dense_vector(
    const DeviceBCOOVector<T, N>& from, DeviceDenseVector<T>& to, bool clear_dense_vector)
{
    using namespace muda;

    if(clear_dense_vector)
        to.fill(0);

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(from.non_zero_segments(),
               [unique_segments = from.m_segment_values.cviewer().name("unique_segments"),
                unique_indices = from.m_segment_indices.cviewer().name("unique_indices"),
                dst = to.viewer().name("dst_dense_vector")] __device__(int i) mutable
               {
                   auto index = unique_indices(i);
                   dst.segment<N>(index * N).as_eigen() += unique_segments(i);
               });
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceDoubletVector<T, N>& from,
                                          DeviceDenseVector<T>&            to,
                                          bool clear_dense_vector)
{
    using namespace muda;

    to.resize(N * from.segment_count());

    if(clear_dense_vector)
        to.fill(0);

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(from.doublet_count(),
               [src = from.viewer().name("src_sparse_vector"),
                dst = to.viewer().name("dst_dense_vector")] __device__(int i) mutable
               {
                   auto&& [index, value] = src(i);
                   dst.segment<N>(index * N).atomic_add(value);
               });
}
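
// Unlike the BCOO path above, a doublet vector may repeat the same segment
// index, so concurrent writes must be resolved with atomic_add instead of +=.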

template <typename T>
void bsr2csr(cusparseHandle_t          handle,
             int                       mb,
             int                       nb,
             int                       blockDim,
             cusparseMatDescr_t        descrA,
             const T*                  bsrValA,
             const int*                bsrRowPtrA,
             const int*                bsrColIndA,
             int                       nnzb,
             muda::DeviceCSRMatrix<T>& to,
             muda::DeviceBuffer<int>&  row_offsets,
             muda::DeviceBuffer<int>&  col_indices,
             muda::DeviceBuffer<T>&    values)
{
    using namespace muda;
    cusparseDirection_t dir = CUSPARSE_DIRECTION_COLUMN;
    int m   = mb * blockDim;
    int nnz = nnzb * blockDim * blockDim;  // number of elements
    to.reshape(m, m);
    col_indices.resize(nnz);
    values.resize(nnz);
    if constexpr(std::is_same_v<T, float>)
    {
        checkCudaErrors(cusparseSbsr2csr(handle,
                                         dir,
                                         mb,
                                         nb,
                                         descrA,
                                         bsrValA,
                                         bsrRowPtrA,
                                         bsrColIndA,
                                         blockDim,
                                         to.legacy_descr(),
                                         values.data(),
                                         row_offsets.data(),
                                         col_indices.data()));
    }
    else if constexpr(std::is_same_v<T, double>)
    {
        checkCudaErrors(cusparseDbsr2csr(handle,
                                         dir,
                                         mb,
                                         nb,
                                         descrA,
                                         bsrValA,
                                         bsrRowPtrA,
                                         bsrColIndA,
                                         blockDim,
                                         to.legacy_descr(),
                                         values.data(),
                                         row_offsets.data(),
                                         col_indices.data()));
    }
}
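
// Only float and double map to cusparse{S,D}bsr2csr here; other scalar types
// fall through silently. CUSPARSE_DIRECTION_COLUMN matches the column-major
// storage of the N x N blocks (presumably Eigen's default layout).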

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceBCOOMatrix<T, N>& from,
                                          DeviceCOOMatrix<T>&           to)
{
    expand_blocks(from, to);
    sort_indices_and_values(from, to);
}

template <typename T, int N>
void MatrixFormatConverter<T, N>::expand_blocks(
    const DeviceBCOOMatrix<T, N>& from, DeviceCOOMatrix<T>& to)
{
    using namespace muda;

    constexpr int N2 = N * N;

    to.reshape(from.block_rows() * N, from.block_cols() * N);
    to.resize_triplets(from.non_zero_blocks() * N2);

    auto& row_indices = to.m_row_indices;
    auto& col_indices = to.m_col_indices;
    auto& values      = to.m_values;

    auto& block_row_indices = from.m_block_row_indices;
    auto& block_col_indices = from.m_block_col_indices;
    auto& block_values      = from.m_block_values;

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(block_row_indices.size(),
               [block_row_indices = block_row_indices.cviewer().name("block_row_indices"),
                block_col_indices = block_col_indices.cviewer().name("block_col_indices"),
                block_values = block_values.cviewer().name("block_values"),
                row_indices = row_indices.viewer().name("row_indices"),
                col_indices = col_indices.viewer().name("col_indices"),
                values = values.viewer().name("values")] __device__(int i) mutable
               {
                   auto block_row_index = block_row_indices(i);
                   auto block_col_index = block_col_indices(i);
                   auto block           = block_values(i);

                   auto row = block_row_index * N;
                   auto col = block_col_index * N;

                   auto index = i * N2;
#pragma unroll
                   for(int r = 0; r < N; ++r)
                   {
#pragma unroll
                       for(int c = 0; c < N; ++c)
                       {
                           row_indices(index) = row + r;
                           col_indices(index) = col + c;
                           values(index)      = block(r, c);
                           ++index;
                       }
                   }
               });
}
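
// Each N x N block fans out to N^2 scalar triplets; e.g. with N = 3, block
// coordinate (1, 2) fills scalar rows 3..5 and columns 6..8.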

template <typename T, int N>
void MatrixFormatConverter<T, N>::sort_indices_and_values(
    const DeviceBCOOMatrix<T, N>& from, DeviceCOOMatrix<T>& to)
{
    using namespace muda;

    auto& row_indices = to.m_row_indices;
    auto& col_indices = to.m_col_indices;
    auto& values      = to.m_values;

    ij_pairs.resize(row_indices.size());

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(row_indices.size(),
               [row_indices = row_indices.cviewer().name("row_indices"),
                col_indices = col_indices.cviewer().name("col_indices"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   ij_pairs(i).x = row_indices(i);
                   ij_pairs(i).y = col_indices(i);
               });

    DeviceMergeSort().SortPairs(ij_pairs.data(),
                                to.m_values.data(),
                                ij_pairs.size(),
                                [] __device__(const int2& a, const int2& b) {
                                    return a.x < b.x || (a.x == b.x && a.y < b.y);
                                });

    // set ij_pairs back to row_indices and col_indices
    auto dst_row_indices = to.row_indices();
    auto dst_col_indices = to.col_indices();

    ParallelFor(256)
        .kernel_name(__FUNCTION__)
        .apply(dst_row_indices.size(),
               [row_indices = dst_row_indices.viewer().name("row_indices"),
                col_indices = dst_col_indices.viewer().name("col_indices"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   row_indices(i) = ij_pairs(i).x;
                   col_indices(i) = ij_pairs(i).y;
               });
}
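
// The scalar values are cheap to move, so they are sorted directly as the
// merge-sort payload; no duplicate folding is needed because the source BCOO
// matrix already has unique block coordinates.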

template <typename T, int N>
void MatrixFormatConverter<T, N>::convert(const DeviceBSRMatrix<T, N>& from,
                                          DeviceCSRMatrix<T>&          to)
{
    using namespace muda;

    bsr2csr(cusparse(),
            from.block_rows(),
            from.block_cols(),
            N,
            from.legacy_descr(),
            (const T*)from.m_block_values.data(),
            from.m_block_row_offsets.data(),
            from.m_block_col_indices.data(),
            from.non_zero_blocks(),
            to,
            to.m_row_offsets,
            to.m_col_indices,
            to.m_values);
}
}  // namespace muda::details