SIRIUS 7.5.0
Electronic structure library and applications
rocsolver.hpp
Go to the documentation of this file.
1// Copyright (c) 2023 Simon Pintarelli, Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
20/** \file rocsolver.hpp
21 *
22 * \brief Contains implementation of rocsolver wrappers
23 */
24
25#ifndef __ROCSOLVER_HPP__
26#define __ROCSOLVER_HPP__
27
28#include <rocsolver/rocsolver.h>
29#include <rocblas/rocblas.h>
30#include <unistd.h>
31#include "acc.hpp"
32#include "acc_blas_api.hpp"
33#include "core/rte/rte.hpp"
34
35namespace sirius {
36
37namespace acc {
38
39/// Interface to ROCM eigensolver.
40namespace rocsolver {
41
/// Call a rocSOLVER routine and report a non-success status.
///
/// \param func__  name of the rocSOLVER function to call.
/// \param args__  parenthesized argument list, e.g. `(handle, n, A, lda)`; it is pasted directly after `func__`.
///
/// On a status other than `rocblas_status_success` the host name, the failing function name, the source location
/// and the textual status are printed, and then `acc::stack_backtrace()` is called.
///
/// The host-name buffer is zero-initialized and explicitly NUL-terminated: POSIX leaves it unspecified whether a
/// truncated host name is terminated, and a failed `gethostname` may leave the buffer untouched.
#define CALL_ROCSOLVER(func__, args__)                                                                                 \
    {                                                                                                                  \
        rocblas_status status = func__ args__;                                                                         \
        if (status != rocblas_status::rocblas_status_success) {                                                        \
            char nm[1024] = {};                                                                                        \
            if (gethostname(nm, sizeof(nm)) != 0) {                                                                    \
                nm[0] = '\0';                                                                                          \
            }                                                                                                          \
            nm[sizeof(nm) - 1] = '\0';                                                                                 \
            printf("hostname: %s\n", nm);                                                                              \
            printf("Error in %s at line %i of file %s: %s\n", #func__, __LINE__, __FILE__,                             \
                   rocblas_status_to_string(status));                                                                  \
            acc::stack_backtrace();                                                                                    \
        }                                                                                                              \
    }
54
55acc::blas_api::handle_t& rocsolver_handle();
56
57inline rocblas_operation
58get_rocblas_operation(char trans)
59{
60 rocblas_operation op{rocblas_operation::rocblas_operation_none};
61 switch (trans) {
62 case 'n':
63 case 'N':
64 op = rocblas_operation::rocblas_operation_none;
65 break;
66 case 't':
67 case 'T':
68 op = rocblas_operation::rocblas_operation_transpose;
69 break;
70 case 'h':
71 case 'H':
72 op = rocblas_operation::rocblas_operation_conjugate_transpose;
73 break;
74 default:
75 RTE_THROW("invalid tranpose op.")
76 }
77
78 return op;
79}
80
81/// _sy_mmetric or _he_rmitian STANDARD eigenvalue problem | double
82template <class T>
83std::enable_if_t<std::is_same<T, double>::value>
84syheevd(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, T* A, int lda, T* D, T* E,
85 int* info)
86{
87 CALL_ROCSOLVER(rocsolver_dsyevd, (handle, evect, uplo, n, A, lda, D, E, info));
88}
89
90/// _sy_mmetric or _he_rmitian STANDARD eigenvalue problem | float
91template <class T>
92std::enable_if_t<std::is_same<T, float>::value>
93syheevd(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, T* A, int lda, T* D, T* E,
94 int* info)
95{
96 CALL_ROCSOLVER(rocsolver_ssyevd, (handle, evect, uplo, n, A, lda, D, E, info));
97}
98
99/// _sy_mmetric or _he_rmitian STANDARD eigenvalue problem | complex double
100template <class T>
101std::enable_if_t<std::is_same<T, double>::value>
102syheevd(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, std::complex<T>* A, int lda,
103 T* D, T* E, int* info)
104{
105 CALL_ROCSOLVER(rocsolver_zheevd,
106 (handle, evect, uplo, n, reinterpret_cast<rocblas_double_complex*>(A), lda, D, E, info));
107}
108
109/// _sy_mmetric or _he_rmitian STANDARD eigenvalue problem | complex float
110template <class T>
111std::enable_if_t<std::is_same<T, float>::value>
112syheevd(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, std::complex<T>* A, int lda,
113 T* D, T* E, int* info)
114{
115 CALL_ROCSOLVER(rocsolver_cheevd,
116 (handle, evect, uplo, n, reinterpret_cast<rocblas_float_complex*>(A), lda, D, E, info));
117}
118
119/// _sy_mmetric or _he_rmitian GENERALIZED eigenvalue problem | double
120template <class T>
121std::enable_if_t<std::is_same<T, double>::value>
122syhegvd(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
123 T* A, int lda, T* B, int ldb, T* D, T* E, int* info)
124{
125 CALL_ROCSOLVER(rocsolver_dsygvd, (handle, itype, evect, uplo, n, A, lda, B, ldb, D, E, info));
126}
127
128/// _sy_mmetric or _he_rmitian GENERALIZED eigenvalue problem | float
129template <class T>
130std::enable_if_t<std::is_same<T, float>::value>
131syhegvd(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
132 T* A, int lda, T* B, int ldb, T* D, T* E, int* info)
133{
134 CALL_ROCSOLVER(rocsolver_ssygvd, (handle, itype, evect, uplo, n, A, lda, B, ldb, D, E, info));
135}
136
137/// _sy_mmetric or _he_rmitian GENERALIZED eigenvalue problem | complex float
138template <class T>
139std::enable_if_t<std::is_same<T, float>::value>
140syhegvd(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
141 std::complex<T>* A, int lda, std::complex<T>* B, int ldb, T* D, T* E, int* info)
142{
143 CALL_ROCSOLVER(rocsolver_chegvd, (handle, itype, evect, uplo, n, reinterpret_cast<rocblas_float_complex*>(A), lda,
144 reinterpret_cast<rocblas_float_complex*>(B), ldb, D, E, info));
145}
146
147/// _sy_mmetric or _he_rmitian GENERALIZED eigenvalue problem | complex double
148template <class T>
149std::enable_if_t<std::is_same<T, double>::value>
150syhegvd(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
151 std::complex<T>* A, int lda, std::complex<T>* B, int ldb, T* D, T* E, int* info)
152{
153 CALL_ROCSOLVER(rocsolver_zhegvd, (handle, itype, evect, uplo, n, reinterpret_cast<rocblas_double_complex*>(A), lda,
154 reinterpret_cast<rocblas_double_complex*>(B), ldb, D, E, info));
155}
156
157#if (ROCSOLVER_VERSION_MAJOR > 3) || ((ROCSOLVER_VERSION_MAJOR == 3) && (ROCSOLVER_VERSION_MINOR >= 19))
158/// x versions
159/// -----------------------------------------------------------------------------------------------------------------
160template <class T>
161std::enable_if_t<std::is_same<T, double>::value>
162syheevx(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, T* A, int lda, int il, int iu,
163 double abstol, int* nev, T* D, T* Z, int ldz, int* ifail, int* info)
164{
165 double vl, vu{0}; // ingored if erange = erange_index
166 rocsolver_dsyevx(handle, evect, rocblas_erange::rocblas_erange_index, uplo, n, A, lda, vl, vu, il, iu, abstol, nev,
167 D, Z, ldz, ifail, info);
168}
169
170template <class T>
171std::enable_if_t<std::is_same<T, float>::value>
172syheevx(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, T* A, int lda, int il, int iu,
173 double abstol, int* nev, T* D, T* Z, int ldz, int* ifail, int* info)
174{
175 double vl, vu{0}; // ingored if erange = erange_index
176 rocsolver_ssyevx(handle, evect, rocblas_erange::rocblas_erange_index, uplo, n, A, lda, vl, vu, il, iu, abstol, nev,
177 D, Z, ldz, ifail, info);
178}
179
180/// Hermitian | complex double
181template <class T>
182std::enable_if_t<std::is_same<T, double>::value>
183syheevx(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, std::complex<double>* A,
184 int lda, int il, int iu, double abstol, int* nev, T* D, std::complex<double>* Z, int ldz, int* ifail, int* info)
185{
186 double vl, vu{0}; // ingored if erange = erange_index
187 rocsolver_zheevx(handle, evect, rocblas_erange::rocblas_erange_index, uplo, n,
188 reinterpret_cast<rocblas_double_complex*>(A), lda, vl, vu, il, iu, abstol, nev, D,
189 reinterpret_cast<rocblas_double_complex*>(Z), ldz, ifail, info);
190}
191
192template <class T>
193std::enable_if_t<std::is_same<T, float>::value>
194syheevx(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, std::complex<float>* A,
195 int lda, int il, int iu, double abstol, int* nev, T* D, std::complex<float>* Z, int ldz, int* ifail, int* info)
196{
197 double vl, vu{0}; // ingored if erange = erange_index
198 rocsolver_cheevx(handle, evect, rocblas_erange::rocblas_erange_index, uplo, n,
199 reinterpret_cast<rocblas_float_complex*>(A), lda, vl, vu, il, iu, abstol, nev, D,
200 reinterpret_cast<rocblas_float_complex*>(Z), ldz, ifail, info);
201}
202
203/// x versions
204/// -----------------------------------------------------------------------------------------------------------------
205template <class T>
206std::enable_if_t<std::is_same<T, double>::value>
207syhegvx(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
208 T* A, int lda, T* B, int ldb, int il, int iu, double abstol, int* nev, T* D, T* Z, int ldz, int* ifail,
209 int* info)
210{
211 double vl, vu{0}; // ingored if erange = erange_index
212 rocsolver_dsygvx(handle, itype, evect, rocblas_erange::rocblas_erange_index, uplo, n, A, lda, B, ldb, vl, vu, il,
213 iu, abstol, nev, D, Z, ldz, ifail, info);
214}
215
216template <class T>
217std::enable_if_t<std::is_same<T, float>::value>
218syhegvx(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
219 T* A, int lda, T* B, int ldb, int il, int iu, double abstol, int* nev, T* D, T* Z, int ldz, int* ifail,
220 int* info)
221{
222 double vl, vu{0}; // ingored if erange = erange_index
223 rocsolver_ssygvx(handle, itype, evect, rocblas_erange::rocblas_erange_index, uplo, n, A, lda, B, ldb, vl, vu, il,
224 iu, abstol, nev, D, Z, ldz, ifail, info);
225}
226
227/// Hermitian | complex double
228template <class T>
229std::enable_if_t<std::is_same<T, double>::value>
230syhegvx(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
231 std::complex<double>* A, int lda, std::complex<double>* B, int ldb, int il, int iu, double abstol, int* nev,
232 T* D, std::complex<double>* Z, int ldz, int* ifail, int* info)
233{
234 double vl, vu{0}; // ingored if erange = erange_index
235 rocsolver_zhegvx(handle, itype, evect, rocblas_erange::rocblas_erange_index, uplo, n,
236 reinterpret_cast<rocblas_double_complex*>(A), lda, reinterpret_cast<rocblas_double_complex*>(B),
237 ldb, vl, vu, il, iu, abstol, nev, D, reinterpret_cast<rocblas_double_complex*>(Z), ldz, ifail,
238 info);
239}
240
241template <class T>
242std::enable_if_t<std::is_same<T, float>::value>
243syhegvx(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n,
244 std::complex<float>* A, int lda, std::complex<float>* B, int ldb, int il, int iu, double abstol, int* nev, T* D,
245 std::complex<float>* Z, int ldz, int* ifail, int* info)
246{
247 double vl, vu{0}; // ingored if erange = erange_index
248 rocsolver_chegvx(handle, itype, evect, rocblas_erange::rocblas_erange_index, uplo, n,
249 reinterpret_cast<rocblas_float_complex*>(A), lda, reinterpret_cast<rocblas_float_complex*>(B), ldb,
250 vl, vu, il, iu, abstol, nev, D, reinterpret_cast<rocblas_float_complex*>(Z), ldz, ifail, info);
251}
252#endif // rocsolver >=5.3.0
253
254/// Linear Solvers
255void
256zgetrs(rocblas_handle handle, char trans, int n, int nrhs, acc_complex_double_t* A, int lda, const int* devIpiv,
257 acc_complex_double_t* B, int ldb);
258
259void
260zgetrf(rocblas_handle handle, int m, int n, acc_complex_double_t* A, int* devIpiv, int lda, int* devInfo);
261
262} // namespace rocsolver
263
264} // namespace acc
265
266} // namespace sirius
267
268#endif
Interface to accelerators API.
Interface to cuBLAS / rocblas related functions.
void zgetrs(rocblas_handle handle, char trans, int n, int nrhs, acc_complex_double_t *A, int lda, const int *devIpiv, acc_complex_double_t *B, int ldb)
Linear Solvers.
std::enable_if_t< std::is_same< T, double >::value > syhegvd(rocblas_handle handle, const rocblas_eform itype, const rocblas_evect evect, const rocblas_fill uplo, int n, T *A, int lda, T *B, int ldb, T *D, T *E, int *info)
_sy_mmetric or _he_rmitian GENERALIZED eigenvalue problem | double
Definition: rocsolver.hpp:122
std::enable_if_t< std::is_same< T, double >::value > syheevd(rocblas_handle handle, const rocblas_evect evect, const rocblas_fill uplo, int n, T *A, int lda, T *D, T *E, int *info)
_sy_mmetric or _he_rmitian STANDARD eigenvalue problem | double
Definition: rocsolver.hpp:84
Namespace of the SIRIUS library.
Definition: sirius.f90:5
Error and warning handling during run-time execution.