File device_spmv.h
File List > cub > device > device_spmv.h
Go to the documentation of this file
#pragma once
#include <muda/cub/device/cub_wrapper.h>
#include "details/cub_wrapper_macro_def.inl"
#ifndef __INTELLISENSE__
#include <cub/device/device_spmv.cuh>
#endif
namespace muda
{
//ref: https://nvlabs.github.io/cub/structcub_1_1_device_spmv.html
class DeviceSpmv : public CubWrapper<DeviceSpmv>
{
using Base = CubWrapper<DeviceSpmv>;
public:
using Base::Base;
template <typename ValueT>
DeviceSpmv& CsrMV(const ValueT* d_values,
const int* d_row_offsets,
const int* d_column_indices,
const ValueT* d_vector_x,
ValueT* d_vector_y,
int num_rows,
int num_cols,
int num_nonzeros)
{
MUDA_CUB_WRAPPER_IMPL(cub::DeviceSpmv::CsrMV(d_temp_storage,
temp_storage_bytes,
d_values,
d_row_offsets,
d_column_indices,
d_vector_x,
d_vector_y,
num_rows,
num_cols,
num_nonzeros,
_stream,
false));
}
// Origin:
template <typename ValueT>
DeviceSpmv& CsrMV(void* d_temp_storage,
size_t& temp_storage_bytes,
const ValueT* d_values,
const int* d_row_offsets,
const int* d_column_indices,
const ValueT* d_vector_x,
ValueT* d_vector_y,
int num_rows,
int num_cols,
int num_nonzeros)
{
MUDA_CUB_WRAPPER_FOR_COMPUTE_GRAPH_IMPL(cub::DeviceSpmv::CsrMV(d_temp_storage,
temp_storage_bytes,
d_values,
d_row_offsets,
d_column_indices,
d_vector_x,
d_vector_y,
num_rows,
num_cols,
num_nonzeros,
_stream,
false));
}
};
} // namespace muda
#include "details/cub_wrapper_macro_undef.inl"