1#include <muda/cub/device/device_merge_sort.h>
2#include <muda/cub/device/device_radix_sort.h>
3#include <muda/cub/device/device_run_length_encode.h>
4#include <muda/cub/device/device_scan.h>
5#include <muda/cub/device/device_segmented_reduce.h>
6#include <muda/cub/device/device_reduce.h>
7#include <muda/launch.h>
// Equality comparison for two int2 values: true only when both the x and the
// y members are equal.
// NOTE(review): MUDA_GENERIC is a project macro defined outside this listing;
// presumably it supplies the host/device qualifiers — confirm.
// NOTE(review): the braces of the body are missing from this listing.
10MUDA_GENERIC
constexpr bool operator==(
const int2& a,
const int2& b)
12 return a.x == b.x && a.y == b.y;
15namespace muda::details
// Converts a DeviceTripletMatrix (`from`) into a DeviceBCOOMatrix (`to`).
//
// Visible steps: `to` takes over the block dimensions and the triplet count of
// `from`, an empty-result check follows, and then two helper pipelines are
// invoked (radix-sort based and merge-sort based).
//
// NOTE(review): this listing has lost lines (braces, the statement guarded by
// the zero-count check, and whatever condition chooses between the two
// pipelines). As shown they appear back-to-back — confirm against the
// original file before relying on the control flow.
20template <
typename T,
int N>
21void MatrixFormatConverter<T, N>::convert(
const DeviceTripletMatrix<T, N>& from,
22 DeviceBCOOMatrix<T, N>& to)
// Shape and size the destination to match the source.
24 to.reshape(from.block_rows(), from.block_cols());
25 to.resize_triplets(from.triplet_count());
// Empty-input check; the guarded statement is not visible here.
28 if(to.triplet_count() == 0)
// Pipeline 1: sort via packed 64-bit keys, then make indices and blocks unique together.
33 radix_sort_indices_and_blocks(from, to);
34 make_unique_indices_and_blocks(from, to);
// Pipeline 2: sort via int2 pairs, then make indices and blocks unique in two separate steps.
38 merge_sort_indices_and_blocks(from, to);
39 make_unique_indices(from, to);
40 make_unique_blocks(from, to);
// Sorts the source triplets by a packed (row, col) key and reorders the block
// values to match.
//
// Visible stages:
//   1. build one 64-bit key per triplet from its row and column index,
//   2. hand keys and a sort-index buffer to DeviceRadixSort().SortPairs,
//   3. unpack the keys held in `ij_hash` into `ij_pairs`,
//   4. gather the source blocks into `blocks_sorted` through `sort_index`.
//
// NOTE(review): the launcher expressions in front of each `.kernel_name(...)`,
// the braces, and several statements/arguments are missing from this listing.
44template <
typename T,
int N>
45void MatrixFormatConverter<T, N>::radix_sort_indices_and_blocks(
46 const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
// Source index/value views.
48 auto src_row_indices = from.block_row_indices();
49 auto src_col_indices = from.block_col_indices();
50 auto src_blocks = from.block_values();
// Size the sort inputs and outputs to one entry per source triplet.
52 loose_resize(ij_hash_input, src_row_indices.size());
53 loose_resize(sort_index_input, src_row_indices.size());
55 loose_resize(ij_hash, src_row_indices.size());
56 loose_resize(sort_index, src_row_indices.size());
57 ij_pairs.resize(src_row_indices.size());
// Kernel 1: one invocation per source triplet.
62 .kernel_name(__FUNCTION__)
63 .apply(src_row_indices.size(),
64 [row_indices = src_row_indices.cviewer().name(
"row_indices"),
65 col_indices = src_col_indices.cviewer().name(
"col_indices"),
66 ij_hash = ij_hash_input.viewer().name(
"ij_hash"),
67 sort_index = sort_index_input.viewer().name(
"sort_index")] __device__(
int i)
mutable
// Key layout: row index in the upper 32 bits, column index in the lower 32.
// NOTE(review): the assignment target of this expression and the write to
// `sort_index` are not visible — presumably ij_hash(i) and sort_index(i) = i.
70 (uint64_t{row_indices(i)} << 32) + uint64_t{col_indices(i)};
// Pair sort; only two of the arguments survive in this listing.
74 DeviceRadixSort().SortPairs(ij_hash_input.data(),
76 sort_index_input.data(),
81 auto dst_row_indices = to.block_row_indices();
82 auto dst_col_indices = to.block_col_indices();
// Kernel 2: split each 64-bit key from `ij_hash` back into (row, col).
85 .kernel_name(__FUNCTION__)
86 .apply(dst_row_indices.size(),
87 [ij_hash = ij_hash.viewer().name(
"ij_hash"),
88 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
90 auto hash = ij_hash(i);
// Upper 32 bits -> row, lower 32 bits -> column (inverse of the packing above).
91 auto row_index =
int{hash >> 32};
92 auto col_index =
int{hash & 0xFFFFFFFF};
93 ij_pairs(i).x = row_index;
94 ij_pairs(i).y = col_index;
// Kernel 3: gather blocks so that blocks_sorted(i) is the source block at
// position sort_index(i).
100 loose_resize(blocks_sorted, from.block_values().size());
102 .kernel_name(__FUNCTION__)
103 .apply(src_blocks.size(),
104 [src_blocks = src_blocks.cviewer().name(
"blocks"),
105 sort_index = sort_index.cviewer().name(
"sort_index"),
106 dst_blocks = blocks_sorted.viewer().name(
"block_values")] __device__(
int i)
mutable
107 { dst_blocks(i) = src_blocks(sort_index(i)); });
// Second half of the radix-sort pipeline: combines entries of `blocks_sorted`
// into `to.block_values()`, shrinks `to` to `h_count` triplets, and decodes
// the resulting 64-bit keys into the row/column index arrays of `to`.
//
// NOTE(review): the name of the call that receives the arguments below, the
// body of the BlockMatrix combiner, and the declaration of `h_count` are
// missing from this listing — presumably a reduce-by-key that sums blocks
// sharing a key; confirm against the original file.
111template <
typename T,
int N>
112void MatrixFormatConverter<T, N>::make_unique_indices_and_blocks(
113 const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
// Reuse the radix-sort input buffer as storage for the unique keys.
116 auto& unique_ij_hash = ij_hash_input;
120 unique_ij_hash.data(),
121 blocks_sorted.data(),
122 to.block_values().data(),
// Binary combiner over two blocks (body not visible).
124 [] CUB_RUNTIME_FUNCTION(
const BlockMatrix& l,
const BlockMatrix& r) -> BlockMatrix
// Shrink the destination to the number of entries reported in h_count.
130 to.resize_triplets(h_count);
// Decode each unique key: upper 32 bits -> row index, lower 32 bits -> column index.
134 .kernel_name(
"set col row indices")
135 .apply(to.block_row_indices().size(),
136 [ij_hash = unique_ij_hash.viewer().name(
"ij_hash"),
137 row_indices = to.block_row_indices().viewer().name(
"row_indices"),
138 col_indices = to.block_col_indices().viewer().name(
"col_indices")] __device__(
int i)
mutable
140 auto hash = ij_hash(i);
141 auto row_index =
int{hash >> 32};
142 auto col_index =
int{hash & 0xFFFFFFFF};
143 row_indices(i) = row_index;
144 col_indices(i) = col_index;
// Merge-sort variant of the sorting stage: orders the (row, col) pairs of
// `from`, writes the ordered indices into `to`, and reorders the block values
// through `sort_index`.
//
// Visible stages:
//   1. copy row/col indices into `ij_pairs` as int2 (x = row, y = col),
//   2. initialise `sort_index` to the identity permutation,
//   3. a sort that uses the lexicographic comparator below,
//   4. write the ordered pairs into the row/col index arrays of `to`,
//   5. gather the source blocks into `unique_blocks` through `sort_index`.
//
// NOTE(review): the launcher expressions, the braces and the sort call that
// receives the comparator are missing from this listing.
149template <
typename T,
int N>
150void MatrixFormatConverter<T, N>::merge_sort_indices_and_blocks(
151 const DeviceTripletMatrix<T, N>& from, DeviceBCOOMatrix<T, N>& to)
153 using namespace muda;
155 auto src_row_indices = from.block_row_indices();
156 auto src_col_indices = from.block_col_indices();
157 auto src_blocks = from.block_values();
// One slot per source triplet in both scratch buffers.
159 loose_resize(sort_index, src_row_indices.size());
160 loose_resize(ij_pairs, src_row_indices.size());
// Stage 1: pack (row, col) into int2.
163 .kernel_name(__FUNCTION__)
164 .apply(src_row_indices.size(),
165 [row_indices = src_row_indices.cviewer().name(
"row_indices"),
166 col_indices = src_col_indices.cviewer().name(
"col_indices"),
167 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
169 ij_pairs(i).x = row_indices(i);
170 ij_pairs(i).y = col_indices(i);
// Stage 2: sort_index(i) = i, so the sort can carry the original positions.
174 .kernel_name(__FUNCTION__)
175 .apply(src_row_indices.size(),
176 [sort_index = sort_index.viewer().name(
"sort_index")] __device__(
int i)
mutable
177 { sort_index(i) = i; });
// Stage 3 comparator: lexicographic order, x (row) first, then y (column).
182 [] __device__(
const int2& a,
const int2& b) {
183 return a.x < b.x || (a.x == b.x && a.y < b.y);
189 auto dst_row_indices = to.block_row_indices();
190 auto dst_col_indices = to.block_col_indices();
// Stage 4: unpack the int2 pairs into the destination index arrays.
193 .kernel_name(
"set col row indices")
194 .apply(dst_row_indices.size(),
195 [row_indices = dst_row_indices.viewer().name(
"row_indices"),
196 col_indices = dst_col_indices.viewer().name(
"col_indices"),
197 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
199 row_indices(i) = ij_pairs(i).x;
200 col_indices(i) = ij_pairs(i).y;
// Stage 5: gather blocks in sorted order. Despite its name, `unique_blocks`
// receives one block per source entry here (same size as from.m_block_values).
206 loose_resize(unique_blocks, from.m_block_values.size());
209 .kernel_name(__FUNCTION__)
210 .apply(src_blocks.size(),
211 [src_blocks = src_blocks.cviewer().name(
"blocks"),
212 sort_index = sort_index.cviewer().name(
"sort_index"),
213 dst_blocks = unique_blocks.viewer().name(
"block_values")] __device__(
int i)
mutable
214 { dst_blocks(i) = src_blocks(sort_index(i)); });
217template <
typename T,
int N>
221 using namespace muda;
223 auto& row_indices = to.m_block_row_indices;
224 auto& col_indices = to.m_block_col_indices;
226 loose_resize(unique_ij_pairs, ij_pairs.size());
227 loose_resize(unique_counts, ij_pairs.size());
231 unique_ij_pairs.data(),
232 unique_counts.data(),
238 unique_ij_pairs.resize(h_count);
239 unique_counts.resize(h_count);
241 offsets.resize(unique_counts.size() + 1);
244 unique_counts.data(), offsets.data(), unique_counts.size());
248 .kernel_name(__FUNCTION__)
249 .apply(unique_counts.size(),
250 [unique_ij_pairs = unique_ij_pairs.viewer().name(
"unique_ij_pairs"),
251 row_indices = row_indices.viewer().name(
"row_indices"),
252 col_indices = col_indices.viewer().name(
"col_indices")] __device__(
int i)
mutable
254 row_indices(i) = unique_ij_pairs(i).x;
255 col_indices(i) = unique_ij_pairs(i).y;
258 row_indices.resize(h_count);
259 col_indices.resize(h_count);
262template <
typename T,
int N>
266 using namespace muda;
268 auto& row_indices = to.m_block_row_indices;
269 auto& blocks = to.m_block_values;
270 blocks.resize(row_indices.size());
274 .kernel_name(__FUNCTION__)
275 .apply([offsets = offsets.viewer().name(
"offset"),
276 counts = unique_counts.cviewer().name(
"counts"),
277 last = unique_counts.size() - 1] __device__()
mutable
278 { offsets(last + 1) = offsets(last) + counts(last); });
280 auto& begin_offset = offsets;
281 auto& end_offset = unique_counts;
286 unique_blocks.data(),
291 [] __host__ __device__(
const BlockMatrix& a,
const BlockMatrix& b) -> BlockMatrix
293 BlockMatrix::Zero().eval());
296template <
typename T,
int N>
299 bool clear_dense_matrix)
301 using namespace muda;
302 auto size = N * from.block_rows();
303 to.reshape(size, size);
305 if(clear_dense_matrix)
309 .kernel_name(__FUNCTION__)
310 .apply(from.block_values().size(),
311 [blocks = from.cviewer().name(
"src_sparse_matrix"),
312 dst = to.viewer().name(
"dst_dense_matrix")] __device__(
int i)
mutable
314 auto block = blocks(i);
315 auto row = block.block_row_index * N;
316 auto col = block.block_col_index * N;
317 dst.block<N, N>(row, col).as_eigen() += block.block_value;
321template <
typename T,
int N>
325 calculate_block_offsets(from, to);
327 to.m_block_col_indices = from.m_block_col_indices;
328 to.m_block_values = from.m_block_values;
331template <
typename T,
int N>
335 calculate_block_offsets(from, to);
336 to.m_block_col_indices = std::move(from.m_block_col_indices);
337 to.m_block_values = std::move(from.m_block_values);
340template <
typename T,
int N>
344 using namespace muda;
345 to.reshape(from.block_rows(), from.block_cols());
347 auto& dst_row_offsets = to.m_block_row_offsets;
350 auto& col_counts_per_row = offsets;
351 col_counts_per_row.resize(to.m_block_row_offsets.size());
352 col_counts_per_row.fill(0);
354 unique_indices.resize(from.non_zero_blocks());
355 unique_counts.resize(from.non_zero_blocks());
359 unique_indices.data(),
360 unique_counts.data(),
362 from.non_zero_blocks());
365 unique_indices.resize(h_count);
366 unique_counts.resize(h_count);
369 .kernel_name(__FUNCTION__)
370 .apply(unique_counts.size(),
371 [unique_indices = unique_indices.cviewer().name(
"offset"),
372 counts = unique_counts.viewer().name(
"counts"),
373 col_counts_per_row = col_counts_per_row.viewer().name(
374 "col_counts_per_row")] __device__(
int i)
mutable
376 auto row = unique_indices(i);
377 col_counts_per_row(row) = counts(i);
381 DeviceScan().ExclusiveSum(col_counts_per_row.data(),
382 dst_row_offsets.data(),
383 col_counts_per_row.size());
385template <
typename T,
int N>
388 bool clear_dense_vector)
390 to.resize(N * from.segment_count());
391 set_unique_segments_to_dense_vector(from, to, clear_dense_vector);
394template <
typename T,
int N>
398 to.reshape(from.segment_count());
399 to.resize_doublets(from.doublet_count());
400 merge_sort_indices_and_segments(from, to);
401 make_unique_indices(from, to);
402 make_unique_segments(from, to);
405template <
typename T,
int N>
409 using namespace muda;
411 auto& indices = sort_index;
414 indices = from.m_segment_indices;
415 temp_segments = from.m_segment_values;
418 temp_segments.data(),
420 [] __device__(
const int& a,
const int& b)
424template <
typename T,
int N>
428 using namespace muda;
430 auto& indices = sort_index;
431 auto& unique_indices = to.m_segment_indices;
433 loose_resize(unique_indices, indices.size());
434 loose_resize(unique_counts, indices.size());
437 unique_indices.data(),
438 unique_counts.data(),
444 unique_indices.resize(h_count);
445 unique_counts.resize(h_count);
447 loose_resize(offsets, unique_counts.size() + 1);
450 unique_counts.data(), offsets.data(), unique_counts.size());
454 auto& begin_offset = offsets;
457 .kernel_name(__FUNCTION__)
458 .apply([offset = offsets.viewer().name(
"offset"),
459 count = unique_counts.cviewer().name(
"counts"),
460 last = unique_counts.size() - 1] __device__()
mutable
461 { offset(last + 1) = offset(last) + count(last); });
464template <
typename T,
int N>
468 using namespace muda;
470 auto& begin_offset = offsets;
471 auto& end_offset = unique_counts;
473 auto& unique_indices = to.m_segment_indices;
474 auto& unique_segments = to.m_segment_values;
476 unique_segments.resize(unique_indices.size());
479 temp_segments.data(),
480 unique_segments.data(),
481 unique_segments.size(),
483 begin_offset.data() + 1,
484 [] __host__ __device__(
const SegmentVector& a,
const SegmentVector& b) -> SegmentVector
486 SegmentVector::Zero().eval());
489template <
typename T,
int N>
493 using namespace muda;
495 if(clear_dense_vector)
499 .kernel_name(__FUNCTION__)
500 .apply(from.non_zero_segments(),
501 [unique_segments = from.m_segment_values.cviewer().name(
"unique_segments"),
502 unique_indices = from.m_segment_indices.cviewer().name(
"unique_indices"),
503 dst = to.viewer().name(
"dst_dense_vector")] __device__(
int i)
mutable
505 auto index = unique_indices(i);
506 dst.segment<N>(index * N).as_eigen() += unique_segments(i);
510template <
typename T,
int N>
513 bool clear_dense_vector)
515 using namespace muda;
517 to.resize(N * from.segment_count());
519 if(clear_dense_vector)
523 .kernel_name(__FUNCTION__)
524 .apply(from.doublet_count(),
525 [src = from.viewer().name(
"src_sparse_vector"),
526 dst = to.viewer().name(
"dst_dense_vector")] __device__(
int i)
mutable
528 auto&& [index, value] = src(i);
529 dst.segment<N>(index * N).atomic_add(value);
// Wrapper that dispatches to the cuSPARSE BSR -> CSR conversion routine that
// matches the value type T: cusparseSbsr2csr for float, cusparseDbsr2csr for
// double. Each call is wrapped in checkCudaErrors.
//
// NOTE(review): the template header, most of the parameter list (mb, nnzb,
// blockDim, col_indices, ... are used below but not declared in the visible
// lines) and the bulk of the cuSPARSE argument lists are missing from this
// listing. No branch for other value types is visible.
534void bsr2csr(cusparseHandle_t handle,
538 cusparseMatDescr_t descrA,
540 const int* bsrRowPtrA,
541 const int* bsrColIndA,
548 using namespace muda;
// Blocks are passed to cuSPARSE as column-major.
549 cusparseDirection_t dir = CUSPARSE_DIRECTION_COLUMN;
// Scalar row count and scalar non-zero count of the expanded matrix:
// every block contributes blockDim rows and blockDim * blockDim values.
550 int m = mb * blockDim;
551 int nnz = nnzb * blockDim * blockDim;
// Make room for one column index per scalar non-zero.
553 col_indices.resize(nnz);
555 if constexpr(std::is_same_v<T, float>)
557 checkCudaErrors(cusparseSbsr2csr(handle,
569 col_indices.data()));
571 else if constexpr(std::is_same_v<T, double>)
573 checkCudaErrors(cusparseDbsr2csr(handle,
585 col_indices.data()));
590template <
typename T,
int N>
594 expand_blocks(from, to);
595 sort_indices_and_values(from, to);
598template <
typename T,
int N>
602 using namespace muda;
604 constexpr int N2 = N * N;
606 to.reshape(from.block_rows() * N, from.block_cols() * N);
607 to.resize_triplets(from.non_zero_blocks() * N2);
609 auto& row_indices = to.m_row_indices;
610 auto& col_indices = to.m_col_indices;
611 auto& values = to.m_values;
613 auto& block_row_indices = from.m_block_row_indices;
614 auto& block_col_indices = from.m_block_col_indices;
615 auto& block_values = from.m_block_values;
619 .kernel_name(__FUNCTION__)
620 .apply(block_row_indices.size(),
621 [block_row_indices = block_row_indices.cviewer().name(
"block_row_indices"),
622 block_col_indices = block_col_indices.cviewer().name(
"block_col_indices"),
623 block_values = block_values.cviewer().name(
"block_values"),
624 row_indices = row_indices.viewer().name(
"row_indices"),
625 col_indices = col_indices.viewer().name(
"col_indices"),
626 values = values.viewer().name(
"values")] __device__(
int i)
mutable
628 auto block_row_index = block_row_indices(i);
629 auto block_col_index = block_col_indices(i);
630 auto block = block_values(i);
632 auto row = block_row_index * N;
633 auto col = block_col_index * N;
637 for(
int r = 0; r < N; ++r)
640 for(
int c = 0; c < N; ++c)
642 row_indices(index) = row + r;
643 col_indices(index) = col + c;
644 values(index) = block(r, c);
651template <
typename T,
int N>
655 using namespace muda;
657 auto& row_indices = to.m_row_indices;
658 auto& col_indices = to.m_col_indices;
659 auto& values = to.m_values;
661 ij_pairs.resize(row_indices.size());
664 .kernel_name(__FUNCTION__)
665 .apply(row_indices.size(),
666 [row_indices = row_indices.cviewer().name(
"row_indices"),
667 col_indices = col_indices.cviewer().name(
"col_indices"),
668 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
670 ij_pairs(i).x = row_indices(i);
671 ij_pairs(i).y = col_indices(i);
677 [] __device__(
const int2& a,
const int2& b) {
678 return a.x < b.x || (a.x == b.x && a.y < b.y);
683 auto dst_row_indices = to.row_indices();
684 auto dst_col_indices = to.col_indices();
687 .kernel_name(__FUNCTION__)
688 .apply(dst_row_indices.size(),
689 [row_indices = dst_row_indices.viewer().name(
"row_indices"),
690 col_indices = dst_col_indices.viewer().name(
"col_indices"),
691 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
693 row_indices(i) = ij_pairs(i).x;
694 col_indices(i) = ij_pairs(i).y;
698template <
typename T,
int N>
702 using namespace muda;
709 (
const T*)from.m_block_values.data(),
710 from.m_block_row_offsets.data(),
711 from.m_block_col_indices.data(),
712 from.non_zero_blocks(),
Definition device_bcoo_matrix.h:18
Definition device_bcoo_vector.h:10
Definition device_bsr_matrix.h:16
A std::vector-like wrapper of CUDA device memory that allows the user to:
Definition device_buffer.h:46
Definition device_csr_matrix.h:16
Definition device_dense_matrix.h:16
Definition device_dense_vector.h:16
Definition device_doublet_vector.h:16
Definition device_merge_sort.h:12
Definition device_reduce.h:12
Definition device_run_length_encode.h:12
Definition device_scan.h:21
Definition device_segmented_reduce.h:12
Definition device_triplet_matrix.h:14
A muda-style wrapper around raw CUDA kernel launches that removes the <<<>>> syntax, for better IntelliSense...
Definition launch.h:86
A frequently used parallel for loop; the DynamicBlockDim and GridStrideLoop strategies are provided,...
Definition parallel_for.h:116