1#include <muda/cub/device/device_merge_sort.h>
2#include <muda/cub/device/device_radix_sort.h>
3#include <muda/cub/device/device_run_length_encode.h>
4#include <muda/cub/device/device_scan.h>
5#include <muda/cub/device/device_segmented_reduce.h>
6#include <muda/cub/device/device_reduce.h>
7#include <muda/launch.h>
// Converts a triplet (COO-style) matrix `from` into a BCOO matrix `to`.
// `to` takes `from`'s shape and is sized to `from.triplet_count()` triplets.
// The three helper passes below then sort the entries by (row, col), collapse
// duplicate index pairs, and (presumably) combine the values of duplicates.
// NOTE(review): the statement guarded by the `triplet_count() == 0` check and
// the braces are missing from this extraction; presumably it returns early on
// empty input — confirm against the original source.
13void MatrixFormatConverter<T, 1>::convert(
const DeviceTripletMatrix<T, 1>& from,
14 DeviceBCOOMatrix<T, 1>& to)
16 to.reshape(from.rows(), from.cols());
17 to.resize_triplets(from.triplet_count());
18 if(to.triplet_count() == 0)
20 merge_sort_indices_and_values(from, to);
21 make_unique_indices(from, to);
22 make_unique_values(from, to);
// Sorts the triplets of `from` by (row, col) and writes the sorted row/col
// indices into `to`. The values are permuted into the `unique_values` buffer
// in the same sorted order; they are not deduplicated here.
// Steps visible below:
//   1. pack each (row, col) pair into an int2 in `ij_pairs`;
//   2. fill `sort_index` with the identity permutation 0..n-1;
//   3. sort `ij_pairs` with a lexicographic (row, then col) comparator. The
//      sort call itself is not visible in this extraction; presumably
//      `sort_index` is carried along so it becomes the sorting permutation;
//   4. unpack the sorted pairs into `to`'s row/col indices;
//   5. gather `from`'s values through `sort_index` into `unique_values`.
26void MatrixFormatConverter<T, 1>::merge_sort_indices_and_values(
27 const DeviceTripletMatrix<T, 1>& from, DeviceBCOOMatrix<T, 1>& to)
31 auto src_row_indices = from.row_indices();
32 auto src_col_indices = from.col_indices();
33 auto src_values = from.values();
35 loose_resize(sort_index, src_row_indices.size());
36 loose_resize(ij_pairs, src_row_indices.size());
// Step 1: ij_pairs(i) = (row_i, col_i).
39 .kernel_name(
"set ij pairs")
40 .apply(src_row_indices.size(),
41 [row_indices = src_row_indices.cviewer().name(
"row_indices"),
42 col_indices = src_col_indices.cviewer().name(
"col_indices"),
43 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
45 ij_pairs(i).x = row_indices(i);
46 ij_pairs(i).y = col_indices(i);
// Step 2: sort_index(i) = i (identity permutation).
51 .apply(src_row_indices.size(),
52 [sort_index = sort_index.viewer().name(
"sort_index")] __device__(
int i)
mutable
53 { sort_index(i) = i; });
// Step 3: comparator used by the sort; orders by row (x) first, then col (y).
58 [] __device__(
const int2& a,
const int2& b) {
59 return a.x < b.x || (a.x == b.x && a.y < b.y);
// Step 4: copy the sorted (row, col) pairs into the destination BCOO indices.
64 auto dst_row_indices = to.row_indices();
65 auto dst_col_indices = to.col_indices();
68 .kernel_name(
"set col row indices")
69 .apply(dst_row_indices.size(),
70 [row_indices = dst_row_indices.viewer().name(
"row_indices"),
71 col_indices = dst_col_indices.viewer().name(
"col_indices"),
72 ij_pairs = ij_pairs.viewer().name(
"ij_pairs")] __device__(
int i)
mutable
74 row_indices(i) = ij_pairs(i).x;
75 col_indices(i) = ij_pairs(i).y;
// Step 5: dst_values(i) = src_values(sort_index(i)) — values in sorted order.
80 loose_resize(unique_values, from.m_values.size());
83 .kernel_name(
"set block values")
84 .apply(src_values.size(),
85 [src_values = src_values.cviewer().name(
"blocks"),
86 sort_index = sort_index.cviewer().name(
"sort_index"),
87 dst_values = unique_values.viewer().name(
"values")] __device__(
int i)
mutable
88 { dst_values(i) = src_values(sort_index(i)); });
// Body of what appears to be make_unique_indices for the triplet -> BCOO path.
// The function signature is not visible in this extraction.
// Collapses runs of equal (row, col) pairs in the sorted `ij_pairs`:
//   - a run-length encode (call only partially visible) presumably writes each
//     distinct pair to `unique_ij_pairs` and its multiplicity to
//     `unique_counts`. `h_count` is the number of distinct pairs; how it
//     reaches the host is not visible here;
//   - `offsets` receives a prefix sum of `unique_counts` (presumably an
//     exclusive sum, i.e. the start of each run);
//   - the distinct pairs are written to `to`'s row/col indices, which are then
//     shrunk to `h_count` entries.
97 auto& row_indices = to.m_row_indices;
98 auto& col_indices = to.m_col_indices;
100 loose_resize(unique_ij_pairs, ij_pairs.size());
101 loose_resize(unique_counts, ij_pairs.size());
104 unique_ij_pairs.data(),
105 unique_counts.data(),
// Keep only the h_count runs produced by the encode.
111 unique_ij_pairs.resize(h_count);
112 unique_counts.resize(h_count);
114 loose_resize(offsets, unique_counts.size());
117 unique_counts.data(), offsets.data(), unique_counts.size());
// Write the distinct pairs into the destination; the destination buffers are
// still at least triplet_count() long here and are shrunk afterwards.
121 .kernel_name(
"make unique indices")
122 .apply(unique_counts.size(),
123 [unique_ij_pairs = unique_ij_pairs.viewer().name(
"unique_ij_pairs"),
124 row_indices = row_indices.viewer().name(
"row_indices"),
125 col_indices = col_indices.viewer().name(
"col_indices")] __device__(
int i)
mutable
127 row_indices(i) = unique_ij_pairs(i).x;
128 col_indices(i) = unique_ij_pairs(i).y;
131 row_indices.resize(h_count);
132 col_indices.resize(h_count);
// Body of what appears to be make_unique_values for the triplet -> BCOO path.
// The signature and the trailing reduction are not visible in this extraction.
// Resizes `to.m_values` to one value per distinct (row, col) pair. It then
// turns `unique_counts` into per-run end offsets in place (count += begin
// offset). Assuming `offsets` holds exclusive-scan begin offsets,
// `offsets` / `unique_counts` then bracket each run as [begin, end).
// NOTE(review): the begin_offset/end_offset aliases suggest a segmented
// reduction over the sorted values follows; that call is not visible here.
139 using namespace muda;
141 auto& row_indices = to.m_row_indices;
142 auto& values = to.m_values;
143 values.resize(row_indices.size());
147 .kernel_name(
"calculate offset_ends")
148 .apply(unique_counts.size(),
149 [offset = offsets.cviewer().name(
"offset"),
150 counts = unique_counts.viewer().name(
"counts")] __device__(
int i)
mutable
151 { counts(i) += offset(i); });
153 auto& begin_offset = offsets;
154 auto& end_offset = unique_counts;
// Body of an alternative triplet -> BCOO conversion based on 64-bit keys.
// Its signature is not visible in this extraction.
// Steps visible below:
//   1. encode each (row, col) as the key (row << 32) + col in `ij_hash_input`.
//      The identity fill of `sort_index_input` is presumably in the missing
//      lines of that kernel;
//   2. sort the keys with `sort_index_input` as payload (radix sort call only
//      partially visible), presumably into `ij_hash` / `sort_index`;
//   3. gather the values into `values_sorted` in sorted order;
//   4. reduce by key with `l + r` (call only partially visible). The distinct
//      keys go to `unique_ij_hash`, which reuses the `ij_hash_input` buffer;
//      `h_count` is the number of distinct keys;
//   5. resize `to` to `h_count` triplets and decode each key back to (row, col).
// NOTE(review): `uint64_t{row_indices(i)}` and `int{hash >> 32}` are braced
// initializations from non-constant expressions of a different integer type.
// C++ treats these as narrowing conversions, which are ill-formed in
// list-initialization; consider `static_cast` instead.
// NOTE(review): the encoding assumes row and col are non-negative. A negative
// col would sign-extend into the row bits of the key.
// NOTE(review): `ij_pairs` is resized here but not used in the visible lines.
168 auto src_row_indices = from.row_indices();
169 auto src_col_indices = from.col_indices();
170 auto src_blocks = from.values();
172 loose_resize(ij_hash_input, src_row_indices.size());
173 loose_resize(sort_index_input, src_row_indices.size());
175 loose_resize(ij_hash, src_row_indices.size());
176 loose_resize(sort_index, src_row_indices.size());
177 ij_pairs.resize(src_row_indices.size());
// Step 1: build the 64-bit (row, col) keys.
182 .kernel_name(__FUNCTION__)
183 .apply(src_row_indices.size(),
184 [row_indices = src_row_indices.cviewer().name(
"row_indices"),
185 col_indices = src_col_indices.cviewer().name(
"col_indices"),
186 ij_hash = ij_hash_input.viewer().name(
"ij_hash"),
187 sort_index = sort_index_input.viewer().name(
"sort_index")] __device__(
int i)
mutable
190 (uint64_t{row_indices(i)} << 32) + uint64_t{col_indices(i)};
196 sort_index_input.data(),
// Step 3: dst_blocks(i) = src_blocks(sort_index(i)).
203 loose_resize(values_sorted, from.values().size());
205 .kernel_name(__FUNCTION__)
206 .apply(src_blocks.size(),
207 [src_blocks = src_blocks.cviewer().name(
"blocks"),
208 sort_index = sort_index.cviewer().name(
"sort_index"),
209 dst_blocks = values_sorted.viewer().name(
"block_values")] __device__(
int i)
mutable
210 { dst_blocks(i) = src_blocks(sort_index(i)); });
// Step 4: reuse the input key buffer as the output of the by-key reduction.
219 auto& unique_ij_hash = ij_hash_input;
223 unique_ij_hash.data(),
224 values_sorted.data(),
227 [] CUB_RUNTIME_FUNCTION(
const T& l,
const T& r) -> T {
return l + r; },
// Step 5: decode the distinct keys into row (high 32 bits) and col (low 32 bits).
232 to.resize_triplets(h_count);
236 .kernel_name(
"set col row indices")
237 .apply(to.row_indices().size(),
238 [ij_hash = unique_ij_hash.viewer().name(
"ij_hash"),
239 row_indices = to.row_indices().viewer().name(
"row_indices"),
240 col_indices = to.col_indices().viewer().name(
"col_indices")] __device__(
int i)
mutable
242 auto hash = ij_hash(i);
243 auto row_index =
int{hash >> 32};
244 auto col_index =
int{hash & 0xFFFFFFFF};
245 row_indices(i) = row_index;
246 col_indices(i) = col_index;
// Scatters a BCOO matrix into a dense matrix: each entry's value is added into
// dst(row, col). When `clear_dense_matrix` is set, `to` is presumably cleared
// first; the clearing statement is not visible in this extraction.
// NOTE(review): `to` is reshaped to rows() x rows(). For a non-square `from`
// (cols() != rows()) this yields the wrong shape, and dst(row, col) may index
// past the last column. `to.reshape(from.rows(), from.cols())` looks intended.
// NOTE(review): the `+=` is not atomic, so this relies on `from` holding unique
// (row, col) entries; duplicated entries would race.
bool clear_dense_matrix)
256 using namespace muda;
257 auto size = from.rows();
258 to.reshape(size, size);
260 if(clear_dense_matrix)
264 .kernel_name(__FUNCTION__)
265 .apply(from.values().size(),
266 [values = from.cviewer().name(
"src_sparse_matrix"),
267 dst = to.viewer().name(
"dst_dense_matrix")] __device__(
int i)
mutable
269 auto value = values(i);
270 auto row = value.row_index;
271 auto col = value.col_index;
272 dst(row, col) += value.value;
// BCOO -> CSR (copying overload; the signature is not visible in this
// extraction). Builds `to`'s row offsets from `from`, then copies the column
// indices and values verbatim. This is only valid if `from`'s entries are
// already ordered by row.
280 calculate_block_offsets(from, to);
281 to.m_col_indices = from.m_col_indices;
282 to.m_values = from.m_values;
// BCOO -> CSR (moving overload; the signature is not visible in this
// extraction). Builds `to`'s row offsets from `from`, then takes over `from`'s
// column-index and value buffers via std::move, leaving them in a moved-from
// state. As with a copy, this assumes `from`'s entries are ordered by row.
288 calculate_block_offsets(from, to);
289 to.m_col_indices = std::move(from.m_col_indices);
290 to.m_values = std::move(from.m_values);
// Body of what appears to be calculate_block_offsets. The signature is not
// visible in this extraction. Builds CSR row offsets for `to` from `from`'s
// row indices:
//   1. reshape `to`, and zero a per-row count buffer. The `offsets` buffer is
//      reused for this and sized like `to.m_row_offsets`;
//   2. run-length encode the row indices (input argument not visible) into the
//      distinct rows `unique_indices` and their entry counts `unique_counts`.
//      This only groups consecutive equal rows, so rows are presumably sorted;
//   3. scatter each count into col_counts_per_row[row]; rows with no entries
//      stay 0;
//   4. exclusive-scan the counts into `to.m_row_offsets`.
// NOTE(review): presumably m_row_offsets has rows()+1 entries, so the final
// offset equals the number of non-zeros; that sizing is not visible here.
// NOTE(review): the viewer of `unique_indices` is given the debug name "offset",
// which does not match its contents (row indices).
297 using namespace muda;
298 to.reshape(from.rows(), from.cols());
300 auto& dst_row_offsets = to.m_row_offsets;
303 auto& col_counts_per_row = offsets;
304 col_counts_per_row.resize(to.m_row_offsets.size());
305 col_counts_per_row.fill(0);
307 loose_resize(unique_indices, from.non_zeros());
308 loose_resize(unique_counts, from.non_zeros());
312 unique_indices.data(),
313 unique_counts.data(),
318 unique_indices.resize(h_count);
319 unique_counts.resize(h_count);
// col_counts_per_row[row] = number of entries in that row.
322 .kernel_name(__FUNCTION__)
323 .apply(unique_counts.size(),
324 [unique_indices = unique_indices.cviewer().name(
"offset"),
325 counts = unique_counts.viewer().name(
"counts"),
326 col_counts_per_row = col_counts_per_row.viewer().name(
327 "col_counts_per_row")] __device__(
int i)
mutable
329 auto row = unique_indices(i);
330 col_counts_per_row(row) = counts(i);
334 DeviceScan().ExclusiveSum(col_counts_per_row.data(),
335 dst_row_offsets.data(),
336 col_counts_per_row.size());
// Converts a doublet (index, value) vector into a BCOO vector. The signature is
// not visible in this extraction. `to` takes `from`'s size and doublet count.
// The three helper passes then sort the entries by index, collapse duplicate
// indices, and (presumably) combine the values of duplicates.
343 to.reshape(from.size());
344 to.resize_doublet(from.doublet_count());
346 merge_sort_indices_and_values(from, to);
347 make_unique_indices(from, to);
348 make_unique_values(from, to);
// Body of what appears to be merge_sort_indices_and_values for the doublet ->
// BCOO vector path; the signature is not visible. Copies `from`'s indices and
// values into the `sort_index` / `temp_values` buffers, then sorts them with an
// int comparator. The sort call and comparator body are not visible in this
// extraction; presumably the sort is ascending by index, with values carried
// as the payload.
355 using namespace muda;
357 auto& indices = sort_index;
358 auto& values = temp_values;
360 indices = from.m_indices;
361 values = from.m_values;
366 [] __device__(
const int& a,
const int& b)
// Body of what appears to be make_unique_indices for the vector path; the
// signature is not visible. The steps are:
//   - run-length encode the indices into distinct indices (written to
//     `to.m_indices`) and per-index counts `unique_counts`;
//   - shrink both to `h_count`;
//   - prefix-sum the counts into `offsets` (presumably exclusive, i.e. run
//     begins);
//   - add `offsets` to `unique_counts` in place, so that `offsets` /
//     `unique_counts` act as per-run begin / end offsets.
// NOTE(review): `indices` and `unique_indices` are both references to
// `to.m_indices`, so the resize below is a no-op. If the (not visible) encode
// input is `indices.data()`, two problems follow:
//   - the encode reads and writes the same buffer;
//   - it reads `to.m_indices` rather than the `sort_index` buffer that the
//     vector merge sort filled.
// Verify the missing argument line.
// NOTE(review): `values` is unused in the visible lines.
374 using namespace muda;
376 auto& indices = to.m_indices;
377 auto& values = to.m_values;
379 auto& unique_indices = to.m_indices;
380 unique_indices.resize(indices.size());
381 loose_resize(unique_counts, indices.size());
384 unique_indices.data(),
385 unique_counts.data(),
391 unique_indices.resize(h_count);
392 unique_counts.resize(h_count);
393 loose_resize(offsets, unique_counts.size());
396 unique_counts.data(), offsets.data(), unique_counts.size());
399 auto& begin_offset = offsets;
400 auto& end_offset = unique_counts;
// counts(i) becomes the end offset of run i.
403 .kernel_name(
"calculate offset_ends")
404 .apply(unique_counts.size(),
405 [offset = offsets.cviewer().name(
"offset"),
406 counts = unique_counts.viewer().name(
"counts")] __device__(
int i)
mutable
407 { counts(i) += offset(i); });
// Body of what appears to be make_unique_values for the vector path; the
// signature is not visible. Sizes `to.m_values` to one entry per distinct
// index. The begin/end offset aliases feed a reduction (only partially
// visible) that presumably sums each run of sorted values into `unique_values`.
// NOTE(review): `unique_indices` is declared in a line missing from this
// extraction.
414 using namespace muda;
416 auto& begin_offset = offsets;
417 auto& end_offset = unique_counts;
419 auto& unique_values = to.m_values;
420 unique_values.resize(unique_indices.size());
423 unique_values.data(),
424 unique_values.size(),
// BCOO vector -> dense vector. Resizes `to` to `from.size()` and delegates the
// scatter to set_unique_values_to_dense_vector, forwarding
// `clear_dense_vector`.
bool clear_dense_vector)
434 to.resize(from.size());
435 set_unique_values_to_dense_vector(from, to, clear_dense_vector);
// Body of what appears to be set_unique_values_to_dense_vector; the signature
// is not visible. Optionally clears `to` (the clearing statement is not
// visible), then adds each stored value of `from` into dst(index) for its
// stored index.
// NOTE(review): the `+=` is not atomic, so this relies on the BCOO indices
// being unique; duplicated indices would race.
442 using namespace muda;
444 if(clear_dense_vector)
447 auto& unique_values = from.m_values;
448 auto& unique_indices = from.m_indices;
451 .kernel_name(
"set unique values to dense vector")
452 .apply(unique_values.size(),
453 [unique_values = unique_values.cviewer().name(
"unique_values"),
454 unique_indices = unique_indices.cviewer().name(
"unique_indices"),
455 dst = to.viewer().name(
"dst_dense_vector")] __device__(
int i)
mutable
457 auto index = unique_indices(i);
458 dst(index) += unique_values(i);
// Doublet vector -> dense vector. Resizes `to` to `from.segment_count()` and
// optionally clears it (the clearing statement is not visible). It then
// atomically adds every doublet's value into its segment, so duplicate indices
// accumulate without a prior sort or dedup.
// NOTE(review): if `dst`'s type depends on the template parameter T, a
// conforming compiler requires the `template` disambiguator:
// `dst.template segment<1>(index)`.
bool clear_dense_vector)
468 using namespace muda;
470 to.resize(from.segment_count());
472 if(clear_dense_vector)
476 .kernel_name(__FUNCTION__)
477 .apply(from.doublet_count(),
478 [src = from.viewer().name(
"src_sparse_vector"),
479 dst = to.viewer().name(
"dst_dense_vector")] __device__(
int i)
mutable
481 auto&& [index, value] = src(i);
482 dst.segment<1>(index).atomic_add(value);
Definition device_bcoo_matrix.h:18
Definition device_bcoo_vector.h:28
Definition device_csr_matrix.h:16
Definition device_dense_matrix.h:16
Definition device_dense_vector.h:16
Definition device_doublet_vector.h:16
Definition device_merge_sort.h:12
Definition device_radix_sort.h:18
Definition device_reduce.h:12
Definition device_run_length_encode.h:12
Definition device_scan.h:21
Definition device_segmented_reduce.h:12
Definition device_triplet_matrix.h:14
a frequently used parallel for loop, DynamicBlockDim and GridStrideLoop strategy are provided,...
Definition parallel_for.h:116