MUDA
Loading...
Searching...
No Matches
matrix_format_converter_impl.inl
1#include <muda/cub/device/device_merge_sort.h>
2#include <muda/cub/device/device_radix_sort.h>
3#include <muda/cub/device/device_run_length_encode.h>
4#include <muda/cub/device/device_scan.h>
5#include <muda/cub/device/device_segmented_reduce.h>
6#include <muda/cub/device/device_reduce.h>
7#include <muda/launch.h>
8
9namespace muda::details
10{
11// using T = float;
template <typename T>
void MatrixFormatConverter<T, 1>::convert(const DeviceTripletMatrix<T, 1>& from,
                                          DeviceBCOOMatrix<T, 1>& to)
{
    // Shape the destination like the source and reserve one slot per
    // (possibly duplicated) input triplet before deduplication runs.
    to.reshape(from.rows(), from.cols());
    to.resize_triplets(from.triplet_count());

    if(to.triplet_count() == 0)
        return;  // empty matrix: nothing to sort or merge

    // Pipeline: (1) order triplets by (row, col), (2) collapse duplicate
    // coordinates, (3) sum the values that shared a coordinate.
    merge_sort_indices_and_values(from, to);
    make_unique_indices(from, to);
    make_unique_values(from, to);
}
24
template <typename T>
void MatrixFormatConverter<T, 1>::merge_sort_indices_and_values(
    const DeviceTripletMatrix<T, 1>& from, DeviceBCOOMatrix<T, 1>& to)
{
    using namespace muda;

    auto in_rows   = from.row_indices();
    auto in_cols   = from.col_indices();
    auto in_values = from.values();
    auto count     = in_rows.size();

    loose_resize(sort_index, count);
    loose_resize(ij_pairs, count);

    // Pack each (row, col) coordinate into one int2 so a single merge sort
    // can order both components together.
    ParallelFor(256)
        .kernel_name("set ij pairs")
        .apply(count,
               [row_indices = in_rows.cviewer().name("row_indices"),
                col_indices = in_cols.cviewer().name("col_indices"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   ij_pairs(i).x = row_indices(i);
                   ij_pairs(i).y = col_indices(i);
               });

    // Identity permutation; SortPairs rearranges it into the permutation
    // that sorted the keys, so values can be gathered afterwards.
    ParallelFor(256)
        .kernel_name("iota") //
        .apply(count,
               [sort_index = sort_index.viewer().name("sort_index")] __device__(int i) mutable
               { sort_index(i) = i; });

    // Lexicographic key order: by row first, then by column.
    DeviceMergeSort().SortPairs(ij_pairs.data(),
                                sort_index.data(),
                                ij_pairs.size(),
                                [] __device__(const int2& a, const int2& b)
                                {
                                    if(a.x != b.x)
                                        return a.x < b.x;
                                    return a.y < b.y;
                                });

    // Unpack the now-sorted coordinates into the destination index arrays.

    auto out_rows = to.row_indices();
    auto out_cols = to.col_indices();

    ParallelFor(256)
        .kernel_name("set col row indices")
        .apply(out_rows.size(),
               [row_indices = out_rows.viewer().name("row_indices"),
                col_indices = out_cols.viewer().name("col_indices"),
                ij_pairs = ij_pairs.viewer().name("ij_pairs")] __device__(int i) mutable
               {
                   row_indices(i) = ij_pairs(i).x;
                   col_indices(i) = ij_pairs(i).y;
               });

    // Apply the same permutation to the values: gather source values through
    // sort_index into the scratch buffer consumed by make_unique_values.

    loose_resize(unique_values, from.m_values.size());

    ParallelFor(256)
        .kernel_name("set block values")
        .apply(in_values.size(),
               [src_values = in_values.cviewer().name("blocks"),
                sort_index = sort_index.cviewer().name("sort_index"),
                dst_values = unique_values.viewer().name("values")] __device__(int i) mutable
               { dst_values(i) = src_values(sort_index(i)); });
}
90
91template <typename T>
94{
95 using namespace muda;
96
97 auto& row_indices = to.m_row_indices;
98 auto& col_indices = to.m_col_indices;
99
100 loose_resize(unique_ij_pairs, ij_pairs.size());
101 loose_resize(unique_counts, ij_pairs.size());
102
103 DeviceRunLengthEncode().Encode(ij_pairs.data(),
104 unique_ij_pairs.data(),
105 unique_counts.data(),
106 count.data(),
107 ij_pairs.size());
108
109 int h_count = count;
110
111 unique_ij_pairs.resize(h_count);
112 unique_counts.resize(h_count);
113
114 loose_resize(offsets, unique_counts.size());
115
116 DeviceScan().ExclusiveSum(
117 unique_counts.data(), offsets.data(), unique_counts.size());
118
119
121 .kernel_name("make unique indices")
122 .apply(unique_counts.size(),
123 [unique_ij_pairs = unique_ij_pairs.viewer().name("unique_ij_pairs"),
124 row_indices = row_indices.viewer().name("row_indices"),
125 col_indices = col_indices.viewer().name("col_indices")] __device__(int i) mutable
126 {
127 row_indices(i) = unique_ij_pairs(i).x;
128 col_indices(i) = unique_ij_pairs(i).y;
129 });
130
131 row_indices.resize(h_count);
132 col_indices.resize(h_count);
133}
134
135template <typename T>
138{
139 using namespace muda;
140
141 auto& row_indices = to.m_row_indices;
142 auto& values = to.m_values;
143 values.resize(row_indices.size());
144 // first we add the offsets to counts, to get the offset_ends
145
146 ParallelFor(256)
147 .kernel_name("calculate offset_ends")
148 .apply(unique_counts.size(),
149 [offset = offsets.cviewer().name("offset"),
150 counts = unique_counts.viewer().name("counts")] __device__(int i) mutable
151 { counts(i) += offset(i); });
152
153 auto& begin_offset = offsets;
154 auto& end_offset = unique_counts; // already contains the offset_ends
155
156 // then we do a segmented reduce to get the unique blocks
157 DeviceSegmentedReduce().Sum(unique_values.data(),
158 values.data(),
159 values.size(),
160 offsets.data(),
161 end_offset.data());
162}
163
164template <typename T>
167{
168 auto src_row_indices = from.row_indices();
169 auto src_col_indices = from.col_indices();
170 auto src_blocks = from.values();
171
172 loose_resize(ij_hash_input, src_row_indices.size());
173 loose_resize(sort_index_input, src_row_indices.size());
174
175 loose_resize(ij_hash, src_row_indices.size());
176 loose_resize(sort_index, src_row_indices.size());
177 ij_pairs.resize(src_row_indices.size());
178
179
180 // hash ij
181 ParallelFor(256)
182 .kernel_name(__FUNCTION__)
183 .apply(src_row_indices.size(),
184 [row_indices = src_row_indices.cviewer().name("row_indices"),
185 col_indices = src_col_indices.cviewer().name("col_indices"),
186 ij_hash = ij_hash_input.viewer().name("ij_hash"),
187 sort_index = sort_index_input.viewer().name("sort_index")] __device__(int i) mutable
188 {
189 ij_hash(i) =
190 (uint64_t{row_indices(i)} << 32) + uint64_t{col_indices(i)};
191 sort_index(i) = i;
192 });
193
194 DeviceRadixSort().SortPairs(ij_hash_input.data(),
195 ij_hash.data(),
196 sort_index_input.data(),
197 sort_index.data(),
198 ij_hash.size());
199
200 // sort the block values
201
202 {
203 loose_resize(values_sorted, from.values().size());
204 ParallelFor(256)
205 .kernel_name(__FUNCTION__)
206 .apply(src_blocks.size(),
207 [src_blocks = src_blocks.cviewer().name("blocks"),
208 sort_index = sort_index.cviewer().name("sort_index"),
209 dst_blocks = values_sorted.viewer().name("block_values")] __device__(int i) mutable
210 { dst_blocks(i) = src_blocks(sort_index(i)); });
211 }
212}
213
214template <typename T>
217{
218 // alias to reuse the memory
219 auto& unique_ij_hash = ij_hash_input;
220
221 muda::DeviceReduce().ReduceByKey(
222 ij_hash.data(),
223 unique_ij_hash.data(),
224 values_sorted.data(),
225 to.values().data(),
226 count.data(),
227 [] CUB_RUNTIME_FUNCTION(const T& l, const T& r) -> T { return l + r; },
228 ij_hash.size());
229
230 int h_count = count;
231
232 to.resize_triplets(h_count);
233
234 // set ij_hash back to row_indices and col_indices
236 .kernel_name("set col row indices")
237 .apply(to.row_indices().size(),
238 [ij_hash = unique_ij_hash.viewer().name("ij_hash"),
239 row_indices = to.row_indices().viewer().name("row_indices"),
240 col_indices = to.col_indices().viewer().name("col_indices")] __device__(int i) mutable
241 {
242 auto hash = ij_hash(i);
243 auto row_index = int{hash >> 32};
244 auto col_index = int{hash & 0xFFFFFFFF};
245 row_indices(i) = row_index;
246 col_indices(i) = col_index;
247 });
248}
249
250
251template <typename T>
254 bool clear_dense_matrix)
255{
256 using namespace muda;
257 auto size = from.rows();
258 to.reshape(size, size);
259
260 if(clear_dense_matrix)
261 to.fill(0);
262
263 ParallelFor(256)
264 .kernel_name(__FUNCTION__)
265 .apply(from.values().size(),
266 [values = from.cviewer().name("src_sparse_matrix"),
267 dst = to.viewer().name("dst_dense_matrix")] __device__(int i) mutable
268 {
269 auto value = values(i);
270 auto row = value.row_index;
271 auto col = value.col_index;
272 dst(row, col) += value.value;
273 });
274}
275
276template <typename T>
279{
280 calculate_block_offsets(from, to);
281 to.m_col_indices = from.m_col_indices;
282 to.m_values = from.m_values;
283}
284
285template <typename T>
287{
288 calculate_block_offsets(from, to);
289 to.m_col_indices = std::move(from.m_col_indices);
290 to.m_values = std::move(from.m_values);
291}
292
293template <typename T>
296{
297 using namespace muda;
298 to.reshape(from.rows(), from.cols());
299
300 auto& dst_row_offsets = to.m_row_offsets;
301
302 // alias the offsets to the col_counts_per_row(reuse)
303 auto& col_counts_per_row = offsets;
304 col_counts_per_row.resize(to.m_row_offsets.size());
305 col_counts_per_row.fill(0);
306
307 loose_resize(unique_indices, from.non_zeros());
308 loose_resize(unique_counts, from.non_zeros());
309
310 // run length encode the row
311 DeviceRunLengthEncode().Encode(from.m_row_indices.data(),
312 unique_indices.data(),
313 unique_counts.data(),
314 count.data(),
315 from.non_zeros());
316 int h_count = count;
317
318 unique_indices.resize(h_count);
319 unique_counts.resize(h_count);
320
321 ParallelFor(256)
322 .kernel_name(__FUNCTION__)
323 .apply(unique_counts.size(),
324 [unique_indices = unique_indices.cviewer().name("offset"),
325 counts = unique_counts.viewer().name("counts"),
326 col_counts_per_row = col_counts_per_row.viewer().name(
327 "col_counts_per_row")] __device__(int i) mutable
328 {
329 auto row = unique_indices(i);
330 col_counts_per_row(row) = counts(i);
331 });
332
333 // calculate the offsets
334 DeviceScan().ExclusiveSum(col_counts_per_row.data(),
335 dst_row_offsets.data(),
336 col_counts_per_row.size());
337}
338
339template <typename T>
342{
343 to.reshape(from.size());
344 to.resize_doublet(from.doublet_count());
345
346 merge_sort_indices_and_values(from, to);
347 make_unique_indices(from, to);
348 make_unique_values(from, to);
349}
350
351template <typename T>
354{
355 using namespace muda;
356
357 auto& indices = sort_index;
358 auto& values = temp_values;
359
360 indices = from.m_indices;
361 values = from.m_values;
362
363 DeviceMergeSort().SortPairs(indices.data(),
364 values.data(),
365 indices.size(),
366 [] __device__(const int& a, const int& b)
367 { return a < b; });
368}
369
370template <typename T>
373{
374 using namespace muda;
375
376 auto& indices = to.m_indices;
377 auto& values = to.m_values;
378
379 auto& unique_indices = to.m_indices;
380 unique_indices.resize(indices.size());
381 loose_resize(unique_counts, indices.size());
382
383 DeviceRunLengthEncode().Encode(indices.data(),
384 unique_indices.data(),
385 unique_counts.data(),
386 count.data(),
387 indices.size());
388
389 int h_count = count;
390
391 unique_indices.resize(h_count);
392 unique_counts.resize(h_count);
393 loose_resize(offsets, unique_counts.size());
394
395 DeviceScan().ExclusiveSum(
396 unique_counts.data(), offsets.data(), unique_counts.size());
397
398 // calculate the offset_ends, and set to the unique_counts
399 auto& begin_offset = offsets;
400 auto& end_offset = unique_counts;
401
402 ParallelFor(256)
403 .kernel_name("calculate offset_ends")
404 .apply(unique_counts.size(),
405 [offset = offsets.cviewer().name("offset"),
406 counts = unique_counts.viewer().name("counts")] __device__(int i) mutable
407 { counts(i) += offset(i); });
408}
409
410template <typename T>
413{
414 using namespace muda;
415
416 auto& begin_offset = offsets;
417 auto& end_offset = unique_counts;
418
419 auto& unique_values = to.m_values;
420 unique_values.resize(unique_indices.size());
421
422 DeviceSegmentedReduce().Sum(temp_values.data(),
423 unique_values.data(),
424 unique_values.size(),
425 begin_offset.data(),
426 end_offset.data());
427}
428
429template <typename T>
432 bool clear_dense_vector)
433{
434 to.resize(from.size());
435 set_unique_values_to_dense_vector(from, to, clear_dense_vector);
436}
437
438template <typename T>
440 const DeviceDoubletVector<T, 1>& from, DeviceDenseVector<T>& to, bool clear_dense_vector)
441{
442 using namespace muda;
443
444 if(clear_dense_vector)
445 to.fill(0);
446
447 auto& unique_values = from.m_values;
448 auto& unique_indices = from.m_indices;
449
450 ParallelFor(256)
451 .kernel_name("set unique values to dense vector")
452 .apply(unique_values.size(),
453 [unique_values = unique_values.cviewer().name("unique_values"),
454 unique_indices = unique_indices.cviewer().name("unique_indices"),
455 dst = to.viewer().name("dst_dense_vector")] __device__(int i) mutable
456 {
457 auto index = unique_indices(i);
458 dst(index) += unique_values(i);
459 });
460}
461
462// using T = float;
463template <typename T>
466 bool clear_dense_vector)
467{
468 using namespace muda;
469
470 to.resize(from.segment_count());
471
472 if(clear_dense_vector)
473 to.fill(0);
474
475 ParallelFor(256)
476 .kernel_name(__FUNCTION__)
477 .apply(from.doublet_count(),
478 [src = from.viewer().name("src_sparse_vector"),
479 dst = to.viewer().name("dst_dense_vector")] __device__(int i) mutable
480 {
481 auto&& [index, value] = src(i);
482 dst.segment<1>(index).atomic_add(value);
483 });
484}
485} // namespace muda::details
Definition device_bcoo_matrix.h:18
Definition device_bcoo_vector.h:28
Definition device_csr_matrix.h:16
Definition device_dense_matrix.h:16
Definition device_dense_vector.h:16
Definition device_doublet_vector.h:16
Definition device_merge_sort.h:12
Definition device_radix_sort.h:18
Definition device_reduce.h:12
Definition device_run_length_encode.h:12
Definition device_scan.h:21
Definition device_segmented_reduce.h:12
Definition device_triplet_matrix.h:14
Definition matrix_format_converter.h:48
a frequently used parallel for loop, DynamicBlockDim and GridStrideLoop strategy are provided,...
Definition parallel_for.h:116