MUDA
Loading...
Searching...
No Matches
evd.h
1#pragma once
2#include <muda/muda_def.h>
3#include <muda/ext/eigen/eigen_core_cxx20.h>
4#include <Eigen/Eigenvalues>
5namespace muda
6{
7namespace eigen
8{
9 template <typename T, int N>
10 MUDA_GENERIC void evd(const Eigen::Matrix<T, N, N>& M,
11 Eigen::Vector<T, N>& eigen_values,
12 Eigen::Matrix<T, N, N>& eigen_vectors)
13 {
14 Eigen::SelfAdjointEigenSolver<Eigen::Matrix<T, N, N>> eigen_solver;
15 // NOTE:
16 // On CUDA, if N <= 3, compute() is not supported.
17 // So, we use computeDirect() instead.
18 if constexpr(N <= 3)
19 eigen_solver.computeDirect(M);
20 else
21 eigen_solver.compute(M);
22 eigen_values = eigen_solver.eigenvalues();
23 eigen_vectors = eigen_solver.eigenvectors();
24 }
25} // namespace eigen
26} // namespace muda