SIRIUS 7.5.0
Electronic structure library and applications
linalg.hpp
Go to the documentation of this file.
1// Copyright (c) 2013-2020 Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
20/** \file linalg.hpp
21 *
22 * \brief Linear algebra interface.
23 */
24
25#ifndef __LINALG_HPP__
26#define __LINALG_HPP__
27
#include <stdint.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>
#include "core/acc/acc.hpp"
#if defined(SIRIUS_GPU)
#include "core/acc/acc_blas.hpp"
#include "core/acc/acc_lapack.hpp"
#endif
#if defined(SIRIUS_MAGMA)
#include "core/acc/magma.hpp"
#endif
#if defined(SIRIUS_GPU) and defined(SIRIUS_CUDA)
#include "core/acc/cusolver.hpp"
#endif
#include "blas_lapack.h"
#include "SDDK/memory.hpp"
#include "dmatrix.hpp"
#include "linalg_spla.hpp"
44
45namespace sirius {
46
47namespace la {
48
49namespace _local {
50/// check if device id has been set properly
51inline bool is_set_device_id()
52{
54}
55}
56
/// Builds an error message naming the calling function and the unsupported back-end.
/** Must be a macro: it expands __func__ and the member la_ at the throw site. */
#define linalg_msg_wrong_type "[" + std::string(__func__) + "] wrong type of linear algebra library: " + to_string(la_)

/// Error message thrown when a ScaLAPACK path is requested in a build without ScaLAPACK.
const std::string linalg_msg_no_scalapack = "not compiled with ScaLAPACK";
60
61class wrap
62{
63 private:
64 lib_t la_;
65 public:
66 wrap(lib_t la__)
67 : la_(la__)
68 {
69 }
70
71 /*
72 BLAS Level 1
73 */
74
75 /// vector addition
76 template <typename T>
77 inline void axpy(int n, T const* alpha, T const* x, int incx, T* y, int incy);
78
79 /*
80 matrix - matrix multiplication
81 */
82
83 /// General matrix-matrix multiplication.
84 /** Compute C = alpha * op(A) * op(B) + beta * op(C) with raw pointers. */
85 template <typename T>
86 inline void gemm(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, T const* alpha, T const* A, ftn_int lda,
87 T const* B, ftn_int ldb, T const* beta, T* C, ftn_int ldc, acc::stream_id sid = acc::stream_id(-1)) const;
88
89 /// Distributed general matrix-matrix multiplication.
90 /** Compute C = alpha * op(A) * op(B) + beta * op(C) for distributed matrices. */
91 template <typename T>
92 inline void gemm(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, T const* alpha,
93 dmatrix<T> const& A, ftn_int ia, ftn_int ja, dmatrix<T> const& B,
94 ftn_int ib, ftn_int jb, T const* beta, dmatrix<T>& C, ftn_int ic, ftn_int jc);
95
96 /// Hermitian matrix times a general matrix or vice versa.
97 /** Perform one of the matrix-matrix operations \n
98 * C = alpha * A * B + beta * C (side = 'L') \n
99 * C = alpha * B * A + beta * C (side = 'R'), \n
100 * where A is a hermitian matrix with upper (uplo = 'U') of lower (uplo = 'L') triangular part defined.
101 */
102 template<typename T>
103 inline void hemm(char side, char uplo, ftn_int m, ftn_int n, T const* alpha, T const* A, ftn_len lda,
104 T const* B, ftn_len ldb, T const* beta, T* C, ftn_len ldc);
105
106 template <typename T>
107 inline void trmm(char side, char uplo, char transa, ftn_int m, ftn_int n, T const* aplha, T const* A, ftn_int lda,
108 T* B, ftn_int ldb, acc::stream_id sid = acc::stream_id(-1)) const;
109
110 /*
111 rank2 update
112 */
113
114 template<typename T>
115 inline void ger(ftn_int m, ftn_int n, T const* alpha, T const* x, ftn_int incx, T const* y, ftn_int incy, T* A, ftn_int lda,
116 acc::stream_id sid = acc::stream_id(-1)) const;
117
118 /*
119 matrix factorization
120 */
121
122 /// Cholesky factorization
123 template <typename T>
124 inline int potrf(ftn_int n, T* A, ftn_int lda, ftn_int const* desca = nullptr) const;
125
126 /// LU factorization of general matrix.
127 template <typename T>
128 inline int getrf(ftn_int m, ftn_int n, T* A, ftn_int lda, ftn_int* ipiv) const;
129
130 /// LU factorization of general matrix.
131 template <typename T>
132 inline int getrf(ftn_int m, ftn_int n, dmatrix<T>& A, ftn_int ia, ftn_int ja, ftn_int* ipiv) const;
133
134 template <typename T>
135 inline int getrs(char trans, ftn_int n, ftn_int nrhs, T const* A, ftn_int lda, ftn_int* ipiv, T* B,
136 ftn_int ldb) const;
137
138 /// U*D*U^H factorization of hermitian or symmetric matrix.
139 template <typename T>
140 inline int sytrf(ftn_int n, T* A, ftn_int lda, ftn_int* ipiv) const;
141
142 /// solve Ax=b in place of b where A is factorized with sytrf.
143 template <typename T>
144 inline int sytrs(ftn_int n, ftn_int nrhs, T* A, ftn_int lda, ftn_int* ipiv, T* b, ftn_int ldb) const;
145
146 /*
147 matrix inversion
148 */
149
150 /// Inversion of a triangular matrix.
151 template <typename T>
152 inline int trtri(ftn_int n, T* A, ftn_int lda, ftn_int const* desca = nullptr) const;
153
154 template <typename T>
155 inline int getri(ftn_int n, T* A, ftn_int lda, ftn_int* ipiv) const;
156
157 /// Inversion of factorized symmetric triangular matrix.
158 template <typename T>
159 inline int sytri(ftn_int n, T* A, ftn_int lda, ftn_int* ipiv) const;
160
161 /// Invert a general matrix.
162 template <typename T>
163 inline void geinv(ftn_int n, sddk::matrix<T>& A) const
164 {
165 std::vector<int> ipiv(n);
166 int info = this->getrf(n, n, A.at(sddk::memory_t::host), A.ld(), &ipiv[0]);
167 if (info) {
168 std::printf("getrf returned %i\n", info);
169 exit(-1);
170 }
171
172 info = this->getri(n, A.at(sddk::memory_t::host), A.ld(), &ipiv[0]);
173 if (info) {
174 std::printf("getri returned %i\n", info);
175 exit(-1);
176 }
177 }
178
179 template <typename T>
180 inline void syinv(ftn_int n, sddk::matrix<T>& A) const
181 {
182 std::vector<int> ipiv(n);
183 int info = this->sytrf(n, A.at(sddk::memory_t::host), A.ld(), &ipiv[0]);
184 if (info) {
185 std::printf("sytrf returned %i\n", info);
186 exit(-1);
187 }
188
189 info = this->sytri(n, A.at(sddk::memory_t::host), A.ld(), &ipiv[0]);
190 if (info) {
191 std::printf("sytri returned %i\n", info);
192 exit(-1);
193 }
194 }
195
196 template <typename T>
197 inline bool sysolve(ftn_int n, sddk::matrix<T> &A, sddk::mdarray<T, 1> &b) const
198 {
199 std::vector<int> ipiv(n);
200 int info = this->sytrf(n, A.at(sddk::memory_t::host), A.ld(), ipiv.data());
201 if (info) return false;
202
203 info = this->sytrs(n, 1, A.at(sddk::memory_t::host), A.ld(), ipiv.data(), b.at(sddk::memory_t::host), b.ld());
204
205 return !info;
206 }
207
208 /*
209 solution of a linear system
210 */
211
212 /// Compute the solution to system of linear equations A * X = B for general tri-diagonal matrix.
213 template <typename T>
214 inline int gtsv(ftn_int n, ftn_int nrhs, T* dl, T* d, T* du, T* b, ftn_int ldb) const;
215
216 /// Compute the solution to system of linear equations A * X = B for general matrix.
217 template <typename T>
218 inline int gesv(ftn_int n, ftn_int nrhs, T* A, ftn_int lda, T* B, ftn_int ldb) const;
219
220 /*
221 matrix transposition
222 */
223
224 /// Conjugate transpose matrix
225 /** \param [in] m Number of rows of the target sub-matrix.
226 \param [in] n Number of columns of the target sub-matrix.
227 \param [in] A Input matrix
228 \param [in] ia Starting row index of sub-matrix inside A
229 \param [in] ja Starting column index of sub-matrix inside A
230 \param [out] C Output matrix
231 \param [in] ic Starting row index of sub-matrix inside C
232 \param [in] jc Starting column index of sub-matrix inside C
233 */
234 template <typename T>
235 inline void tranc(ftn_int m, ftn_int n, dmatrix<T>& A, ftn_int ia, ftn_int ja, dmatrix<T>& C,
236 ftn_int ic, ftn_int jc) const;
237
238 /// Transpose matrix without conjugation.
239 template <typename T>
240 inline void tranu(ftn_int m, ftn_int n, dmatrix<T>& A, ftn_int ia, ftn_int ja, dmatrix<T>& C,
241 ftn_int ic, ftn_int jc) const;
242
243 // Constructing a Given's rotation
244 template <typename T>
245 inline std::tuple<ftn_double, ftn_double, ftn_double> lartg(T f, T g) const;
246
247 template <typename T>
248 inline void geqrf(ftn_int m, ftn_int n, dmatrix<T>& A, ftn_int ia, ftn_int ja);
249};
250
251template<>
252inline void
253wrap::geqrf<ftn_double_complex>(ftn_int m, ftn_int n, dmatrix<ftn_double_complex>& A, ftn_int ia, ftn_int ja)
254{
255 switch (la_) {
256 case lib_t::scalapack: {
257#if defined(SIRIUS_SCALAPACK)
258 ia++; ja++;
259 ftn_int lwork = -1;
260 ftn_double_complex z;
261 ftn_int info;
262 FORTRAN(pzgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), &z, &z, &lwork,
263 &info);
264 lwork = static_cast<int>(z.real() + 1);
265 std::vector<ftn_double_complex> work(lwork);
266 std::vector<ftn_double_complex> tau(std::max(m, n));
267 FORTRAN(pzgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), tau.data(),
268 work.data(), &lwork, &info);
269#else
270 throw std::runtime_error(linalg_msg_no_scalapack);
271#endif
272 break;
273 }
274 case lib_t::lapack: {
275 if (A.comm().size() != 1) {
276 throw std::runtime_error("[geqrf] can't use lapack for distributed matrix; use scalapck instead");
277 }
278 ftn_int lwork = -1;
279 ftn_double_complex z;
280 ftn_int info;
281 ftn_int lda = A.ld();
282 FORTRAN(zgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, &z, &z, &lwork, &info);
283 lwork = static_cast<int>(z.real() + 1);
284 std::vector<ftn_double_complex> work(lwork);
285 std::vector<ftn_double_complex> tau(std::max(m, n));
286 FORTRAN(zgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, tau.data(), work.data(), &lwork, &info);
287 break;
288 }
289 default: {
290 throw std::runtime_error(linalg_msg_wrong_type);
291 break;
292 }
293 }
294}
295
296template<>
297inline void
298wrap::geqrf<ftn_double>(ftn_int m, ftn_int n, dmatrix<ftn_double>& A, ftn_int ia, ftn_int ja)
299{
300 switch (la_) {
301 case lib_t::scalapack: {
302#if defined(SIRIUS_SCALAPACK)
303 ia++; ja++;
304 ftn_int lwork = -1;
305 ftn_double z;
306 ftn_int info;
307 FORTRAN(pdgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), &z, &z, &lwork,
308 &info);
309 lwork = static_cast<int>(z + 1);
310 std::vector<ftn_double> work(lwork);
311 std::vector<ftn_double> tau(std::max(m, n));
312 FORTRAN(pdgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), tau.data(),
313 work.data(), &lwork, &info);
314#else
315 throw std::runtime_error(linalg_msg_no_scalapack);
316#endif
317 break;
318 }
319 case lib_t::lapack: {
320 if (A.comm().size() != 1) {
321 throw std::runtime_error("[geqrf] can't use lapack for distributed matrix; use scalapck instead");
322 }
323 ftn_int lwork = -1;
324 ftn_double z;
325 ftn_int info;
326 ftn_int lda = A.ld();
327 FORTRAN(dgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, &z, &z, &lwork, &info);
328 lwork = static_cast<int>(z + 1);
329 std::vector<ftn_double> work(lwork);
330 std::vector<ftn_double> tau(std::max(m, n));
331 FORTRAN(dgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, tau.data(), work.data(), &lwork, &info);
332 break;
333 }
334 default: {
335 throw std::runtime_error(linalg_msg_wrong_type);
336 break;
337 }
338 }
339}
340
341template<>
342inline void
343wrap::geqrf<ftn_complex>(ftn_int m, ftn_int n, dmatrix<ftn_complex>& A, ftn_int ia, ftn_int ja)
344{
345 switch (la_) {
346 case lib_t::scalapack: {
347#if defined(SIRIUS_SCALAPACK)
348 ia++; ja++;
349 ftn_int lwork = -1;
350 ftn_complex z;
351 ftn_int info;
352 FORTRAN(pcgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), &z, &z, &lwork,
353 &info);
354 lwork = static_cast<int>(z.real() + 1);
355 std::vector<ftn_complex> work(lwork);
356 std::vector<ftn_complex> tau(std::max(m, n));
357 FORTRAN(pcgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), tau.data(),
358 work.data(), &lwork, &info);
359#else
360 throw std::runtime_error(linalg_msg_no_scalapack);
361#endif
362 break;
363 }
364 case lib_t::lapack: {
365 if (A.comm().size() != 1) {
366 throw std::runtime_error("[geqrf] can't use lapack for distributed matrix; use scalapck instead");
367 }
368 ftn_int lwork = -1;
369 ftn_complex z;
370 ftn_int info;
371 ftn_int lda = A.ld();
372 FORTRAN(cgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, &z, &z, &lwork, &info);
373 lwork = static_cast<int>(z.real() + 1);
374 std::vector<ftn_complex> work(lwork);
375 std::vector<ftn_complex> tau(std::max(m, n));
376 FORTRAN(cgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, tau.data(), work.data(), &lwork, &info);
377 break;
378 }
379 default: {
380 throw std::runtime_error(linalg_msg_wrong_type);
381 break;
382 }
383 }
384}
385
386template<>
387inline void
388wrap::geqrf<ftn_single>(ftn_int m, ftn_int n, dmatrix<ftn_single>& A, ftn_int ia, ftn_int ja)
389{
390 switch (la_) {
391 case lib_t::scalapack: {
392#if defined(SIRIUS_SCALAPACK)
393 ia++; ja++;
394 ftn_int lwork = -1;
395 ftn_single z;
396 ftn_int info;
397 FORTRAN(psgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), &z, &z, &lwork,
398 &info);
399 lwork = static_cast<int>(z + 1);
400 std::vector<ftn_single> work(lwork);
401 std::vector<ftn_single> tau(std::max(m, n));
402 FORTRAN(psgeqrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), tau.data(),
403 work.data(), &lwork, &info);
404#else
405 throw std::runtime_error(linalg_msg_no_scalapack);
406#endif
407 break;
408 }
409 case lib_t::lapack: {
410 if (A.comm().size() != 1) {
411 throw std::runtime_error("[geqrf] can't use lapack for distributed matrix; use scalapck instead");
412 }
413 ftn_int lwork = -1;
414 ftn_single z;
415 ftn_int info;
416 ftn_int lda = A.ld();
417 FORTRAN(sgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, &z, &z, &lwork, &info);
418 lwork = static_cast<int>(z + 1);
419 std::vector<ftn_single> work(lwork);
420 std::vector<ftn_single> tau(std::max(m, n));
421 FORTRAN(sgeqrf)(&m, &n, A.at(sddk::memory_t::host, ia, ja), &lda, tau.data(), work.data(), &lwork, &info);
422 break;
423 }
424 default: {
425 throw std::runtime_error(linalg_msg_wrong_type);
426 break;
427 }
428 }
429}
430
/// Double-complex vector addition: y = alpha * x + y (ZAXPY).
/** Dispatches to the host Fortran BLAS or to the accelerator BLAS back-end.
 *  NOTE(review): for the gpublas path the pointers presumably refer to device
 *  memory — confirm with the callers. */
template <>
inline void
wrap::axpy(int n, ftn_double_complex const* alpha, ftn_double_complex const* x, int incx, ftn_double_complex* y,
           int incy)
{
    assert(n > 0);
    assert(incx > 0);
    assert(incy > 0);

    switch (la_) {
        case lib_t::blas: {
            FORTRAN(zaxpy)(&n, alpha, x, &incx, y, &incy);
            break;
        }
#if defined(SIRIUS_GPU)
        case lib_t::gpublas: {
            /* reinterpret std::complex-compatible storage as the accelerator complex type */
            acc::blas::zaxpy(n, reinterpret_cast<const acc_complex_double_t*>(alpha),
                             reinterpret_cast<const acc_complex_double_t*>(x), incx,
                             reinterpret_cast<acc_complex_double_t*>(y), incy);
            break;
        }
#endif
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
459
/// Single-precision real matrix-matrix multiplication: C = alpha * op(A) * op(B) + beta * C.
/** Dispatches to SGEMM (host BLAS), GPU BLAS, cublasXt or SPLA depending on the back-end. */
template <>
inline void
wrap::gemm<ftn_single>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, ftn_single const* alpha,
                       ftn_single const* A, ftn_int lda, ftn_single const* B, ftn_int ldb, ftn_single const* beta,
                       ftn_single* C, ftn_int ldc, acc::stream_id sid) const
{
    assert(lda > 0);
    assert(ldb > 0);
    assert(ldc > 0);
    assert(m > 0);
    assert(n > 0);
    assert(k > 0);
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(sgemm)
            (&transa, &transb, &m, &n, &k, const_cast<float*>(alpha), const_cast<float*>(A), &lda,
             const_cast<float*>(B), &ldb, const_cast<float*>(beta), C, &ldc, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* executed on the stream identified by sid */
            acc::blas::sgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::sgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;
        }
        case lib_t::spla: {
            splablas::sgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
505
/// Double-precision real matrix-matrix multiplication: C = alpha * op(A) * op(B) + beta * C.
/** Dispatches to DGEMM (host BLAS), GPU BLAS, cublasXt or SPLA depending on the back-end. */
template <>
inline void wrap::gemm<ftn_double>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, ftn_double const* alpha,
                                   ftn_double const* A, ftn_int lda, ftn_double const* B, ftn_int ldb,
                                   ftn_double const* beta, ftn_double* C, ftn_int ldc, acc::stream_id sid) const
{
    assert(lda > 0);
    assert(ldb > 0);
    assert(ldc > 0);
    assert(m > 0);
    assert(n > 0);
    assert(k > 0);
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(dgemm)(&transa, &transb, &m, &n, &k, const_cast<double*>(alpha), const_cast<double*>(A), &lda,
                           const_cast<double*>(B), &ldb, const_cast<double*>(beta), C, &ldc, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* executed on the stream identified by sid */
            acc::blas::dgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::dgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;

        }
        case lib_t::spla: {
            splablas::dgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
550
/// Single-precision complex matrix-matrix multiplication: C = alpha * op(A) * op(B) + beta * C.
/** Dispatches to CGEMM (host BLAS), GPU BLAS, cublasXt or SPLA depending on the back-end. */
template <>
inline void wrap::gemm<ftn_complex>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, ftn_complex const* alpha,
                                    ftn_complex const* A, ftn_int lda, ftn_complex const* B, ftn_int ldb, ftn_complex const *beta,
                                    ftn_complex* C, ftn_int ldc, acc::stream_id sid) const
{
    assert(lda > 0);
    assert(ldb > 0);
    assert(ldc > 0);
    assert(m > 0);
    assert(n > 0);
    assert(k > 0);
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(cgemm)
            (&transa, &transb, &m, &n, &k, const_cast<ftn_complex*>(alpha), const_cast<ftn_complex*>(A), &lda,
             const_cast<ftn_complex*>(B), &ldb, const_cast<ftn_complex*>(beta), C, &ldc, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* reinterpret std::complex-compatible storage as the accelerator complex type */
            acc::blas::cgemm(transa, transb, m, n, k, reinterpret_cast<acc_complex_float_t const*>(alpha),
                             reinterpret_cast<acc_complex_float_t const*>(A), lda,
                             reinterpret_cast<acc_complex_float_t const*>(B), ldb,
                             reinterpret_cast<acc_complex_float_t const*>(beta),
                             reinterpret_cast<acc_complex_float_t*>(C), ldc, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::cgemm(transa, transb, m, n, k, reinterpret_cast<acc_complex_float_t const*>(alpha),
                                 reinterpret_cast<acc_complex_float_t const*>(A), lda,
                                 reinterpret_cast<acc_complex_float_t const*>(B), ldb,
                                 reinterpret_cast<acc_complex_float_t const*>(beta),
                                 reinterpret_cast<acc_complex_float_t*>(C), ldc);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;
        }
        case lib_t::spla: {
            splablas::cgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
603
/// Double-precision complex matrix-matrix multiplication: C = alpha * op(A) * op(B) + beta * C.
/** Dispatches to ZGEMM (host BLAS), GPU BLAS, cublasXt or SPLA depending on the back-end. */
template <>
inline void wrap::gemm<ftn_double_complex>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k,
                                           ftn_double_complex const* alpha, ftn_double_complex const* A, ftn_int lda,
                                           ftn_double_complex const* B, ftn_int ldb, ftn_double_complex const *beta,
                                           ftn_double_complex* C, ftn_int ldc, acc::stream_id sid) const
{
    assert(lda > 0);
    assert(ldb > 0);
    assert(ldc > 0);
    assert(m > 0);
    assert(n > 0);
    assert(k > 0);
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(zgemm)(&transa, &transb, &m, &n, &k, const_cast<ftn_double_complex*>(alpha),
                           const_cast<ftn_double_complex*>(A), &lda, const_cast<ftn_double_complex*>(B), &ldb,
                           const_cast<ftn_double_complex*>(beta), C, &ldc, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* reinterpret std::complex-compatible storage as the accelerator complex type */
            acc::blas::zgemm(transa, transb, m, n, k, reinterpret_cast<acc_complex_double_t const*>(alpha),
                reinterpret_cast<acc_complex_double_t const*>(A), lda, reinterpret_cast<acc_complex_double_t const*>(B),
                ldb, reinterpret_cast<acc_complex_double_t const*>(beta),
                reinterpret_cast<acc_complex_double_t*>(C), ldc, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;

        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::zgemm(transa, transb, m, n, k, reinterpret_cast<acc_complex_double_t const*>(alpha),
                reinterpret_cast<acc_complex_double_t const*>(A), lda,
                reinterpret_cast<acc_complex_double_t const*>(B), ldb,
                reinterpret_cast<acc_complex_double_t const*>(beta),
                reinterpret_cast<acc_complex_double_t*>(C), ldc);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;

        }
        case lib_t::spla: {
            splablas::zgemm(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
658
/// Distributed single-precision real matrix-matrix multiplication (ScaLAPACK PSGEMM).
/** C(ic:ic+m, jc:jc+n) = alpha * op(A(ia.., ja..)) * op(B(ib.., jb..)) + beta * C(..). */
template<>
inline void
wrap::gemm<ftn_single>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, ftn_single const* alpha,
                       dmatrix<ftn_single> const& A, ftn_int ia, ftn_int ja, dmatrix<ftn_single> const& B,
                       ftn_int ib, ftn_int jb, ftn_single const* beta, dmatrix<ftn_single>& C, ftn_int ic, ftn_int jc)
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(A.ld() != 0);
            assert(B.ld() != 0);
            assert(C.ld() != 0);

            /* ScaLAPACK expects 1-based global indices */
            ia++; ja++;
            ib++; jb++;
            ic++; jc++;
            FORTRAN(psgemm)(&transa, &transb, &m, &n, &k, alpha, A.at(sddk::memory_t::host), &ia, &ja, A.descriptor(),
                B.at(sddk::memory_t::host), &ib, &jb, B.descriptor(), beta, C.at(sddk::memory_t::host), &ic, &jc, C.descriptor(),
                (ftn_len)1, (ftn_len)1);
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
689
/// Distributed double-precision real matrix-matrix multiplication (ScaLAPACK PDGEMM).
/** C(ic:ic+m, jc:jc+n) = alpha * op(A(ia.., ja..)) * op(B(ib.., jb..)) + beta * C(..). */
template<>
inline void
wrap::gemm<ftn_double>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, ftn_double const* alpha,
                       dmatrix<ftn_double> const& A, ftn_int ia, ftn_int ja, dmatrix<ftn_double> const& B,
                       ftn_int ib, ftn_int jb, ftn_double const* beta, dmatrix<ftn_double>& C, ftn_int ic, ftn_int jc)
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(A.ld() != 0);
            assert(B.ld() != 0);
            assert(C.ld() != 0);

            /* ScaLAPACK expects 1-based global indices */
            ia++; ja++;
            ib++; jb++;
            ic++; jc++;
            FORTRAN(pdgemm)(&transa, &transb, &m, &n, &k, alpha, A.at(sddk::memory_t::host), &ia, &ja, A.descriptor(),
                B.at(sddk::memory_t::host), &ib, &jb, B.descriptor(), beta, C.at(sddk::memory_t::host), &ic, &jc, C.descriptor(),
                (ftn_len)1, (ftn_len)1);
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
720
/// Distributed single-precision complex matrix-matrix multiplication (ScaLAPACK PCGEMM).
/** C(ic:ic+m, jc:jc+n) = alpha * op(A(ia.., ja..)) * op(B(ib.., jb..)) + beta * C(..). */
template<>
inline void
wrap::gemm<ftn_complex>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, ftn_complex const* alpha,
                        dmatrix<ftn_complex> const& A, ftn_int ia, ftn_int ja, dmatrix<ftn_complex> const& B,
                        ftn_int ib, ftn_int jb, ftn_complex const* beta, dmatrix<ftn_complex>& C, ftn_int ic, ftn_int jc)
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(A.ld() != 0);
            assert(B.ld() != 0);
            assert(C.ld() != 0);

            /* ScaLAPACK expects 1-based global indices */
            ia++; ja++;
            ib++; jb++;
            ic++; jc++;
            FORTRAN(pcgemm)(&transa, &transb, &m, &n, &k, alpha, A.at(sddk::memory_t::host), &ia, &ja, A.descriptor(),
                B.at(sddk::memory_t::host), &ib, &jb, B.descriptor(), beta, C.at(sddk::memory_t::host), &ic, &jc, C.descriptor(),
                (ftn_len)1, (ftn_len)1);
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
751
/// Distributed double-precision complex matrix-matrix multiplication (ScaLAPACK PZGEMM).
/** C(ic:ic+m, jc:jc+n) = alpha * op(A(ia.., ja..)) * op(B(ib.., jb..)) + beta * C(..). */
template<>
inline void
wrap::gemm<ftn_double_complex>(char transa, char transb, ftn_int m, ftn_int n, ftn_int k,
                               ftn_double_complex const* alpha, dmatrix<ftn_double_complex> const& A,
                               ftn_int ia, ftn_int ja, dmatrix<ftn_double_complex> const& B,
                               ftn_int ib, ftn_int jb, ftn_double_complex const* beta,
                               dmatrix<ftn_double_complex>& C, ftn_int ic, ftn_int jc)
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(A.ld() != 0);
            assert(B.ld() != 0);
            assert(C.ld() != 0);

            /* ScaLAPACK expects 1-based global indices */
            ia++; ja++;
            ib++; jb++;
            ic++; jc++;
            FORTRAN(pzgemm)(&transa, &transb, &m, &n, &k, alpha, A.at(sddk::memory_t::host), &ia, &ja, A.descriptor(),
                B.at(sddk::memory_t::host), &ib, &jb, B.descriptor(), beta, C.at(sddk::memory_t::host), &ic, &jc, C.descriptor(),
                (ftn_len)1, (ftn_len)1);
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
784
/// Single-precision complex hermitian matrix times general matrix (CHEMM).
/** C = alpha * A * B + beta * C (side = 'L') or C = alpha * B * A + beta * C (side = 'R'),
 *  with the hermitian A defined by its upper ('U') or lower ('L') triangle.
 *  Only the host BLAS back-end is supported. */
template<>
inline void
wrap::hemm<ftn_complex>(char side, char uplo, ftn_int m, ftn_int n, ftn_complex const* alpha, ftn_complex const* A,
                        ftn_len lda, ftn_complex const* B, ftn_len ldb, ftn_complex const* beta, ftn_complex* C, ftn_len ldc)
{
    assert(lda > 0);
    assert(ldb > 0);
    assert(ldc > 0);
    assert(m > 0);
    assert(n > 0);
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(chemm)
            (&side, &uplo, &m, &n, const_cast<ftn_complex*>(alpha), const_cast<ftn_complex*>(A), &lda,
             const_cast<ftn_complex*>(B), &ldb, const_cast<ftn_complex*>(beta), C, &ldc, (ftn_len)1, (ftn_len)1);
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
808
/// Double-precision complex hermitian matrix times general matrix (ZHEMM).
/** C = alpha * A * B + beta * C (side = 'L') or C = alpha * B * A + beta * C (side = 'R'),
 *  with the hermitian A defined by its upper ('U') or lower ('L') triangle.
 *  Only the host BLAS back-end is supported. */
template<>
inline void
wrap::hemm<ftn_double_complex>(char side, char uplo, ftn_int m, ftn_int n, ftn_double_complex const* alpha,
                               ftn_double_complex const* A, ftn_len lda, ftn_double_complex const* B, ftn_len ldb,
                               ftn_double_complex const* beta, ftn_double_complex* C, ftn_len ldc)
{
    assert(lda > 0);
    assert(ldb > 0);
    assert(ldc > 0);
    assert(m > 0);
    assert(n > 0);
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(zhemm)(&side, &uplo, &m, &n, const_cast<ftn_double_complex*>(alpha),
                           const_cast<ftn_double_complex*>(A), &lda, const_cast<ftn_double_complex*>(B), &ldb,
                           const_cast<ftn_double_complex*>(beta), C, &ldc, (ftn_len)1, (ftn_len)1);
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
833
/// Single-precision real rank-1 update: A = alpha * x * y^T + A (SGER).
template<>
inline void wrap::ger<ftn_single>(ftn_int m, ftn_int n, ftn_single const* alpha, ftn_single const* x, ftn_int incx,
                                  ftn_single const* y, ftn_int incy, ftn_single* A, ftn_int lda, acc::stream_id sid) const
{
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(sger)(&m, &n, const_cast<ftn_single*>(alpha), const_cast<ftn_single*>(x), &incx,
                          const_cast<ftn_single*>(y), &incy, A, &lda);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* executed on the stream identified by sid */
            acc::blas::sger(m, n, alpha, x, incx, y, incy, A, lda, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
            /* ger is not part of the cublasXt API */
            throw std::runtime_error("(s,c)ger is not implemented in cublasxt");
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
862
/// Double-precision real rank-1 update: A = alpha * x * y^T + A (DGER).
template<>
inline void wrap::ger<ftn_double>(ftn_int m, ftn_int n, ftn_double const* alpha, ftn_double const* x, ftn_int incx,
                                  ftn_double const* y, ftn_int incy, ftn_double* A, ftn_int lda, acc::stream_id sid) const
{
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(dger)(&m, &n, const_cast<ftn_double*>(alpha), const_cast<ftn_double*>(x), &incx,
                          const_cast<ftn_double*>(y), &incy, A, &lda);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* executed on the stream identified by sid */
            acc::blas::dger(m, n, alpha, x, incx, y, incy, A, lda, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
            /* ger is not part of the cublasXt API */
            throw std::runtime_error("(d,z)ger is not implemented in cublasxt");
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
891
/// Double-precision real in-place triangular matrix multiplication (DTRMM).
/** B = alpha * op(A) * B (side = 'L') or B = alpha * B * op(A) (side = 'R');
 *  the diagonal of A is always treated as non-unit ("N"). */
template <>
inline void wrap::trmm<ftn_double>(char side, char uplo, char transa, ftn_int m, ftn_int n, ftn_double const* alpha,
                                   ftn_double const* A, ftn_int lda, ftn_double* B, ftn_int ldb, acc::stream_id sid) const
{
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(dtrmm)(&side, &uplo, &transa, "N", &m, &n, const_cast<ftn_double*>(alpha),
                           const_cast<ftn_double*>(A), &lda, B, &ldb, (ftn_len)1, (ftn_len)1, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* executed on the stream identified by sid */
            acc::blas::dtrmm(side, uplo, transa, 'N', m, n, alpha, A, lda, B, ldb, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::dtrmm(side, uplo, transa, 'N', m, n, alpha, A, lda, B, ldb);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
924
/// Single-precision real in-place triangular matrix multiplication (STRMM).
/** B = alpha * op(A) * B (side = 'L') or B = alpha * B * op(A) (side = 'R');
 *  the diagonal of A is always treated as non-unit ("N"). */
template <>
inline void wrap::trmm<ftn_single>(char side, char uplo, char transa, ftn_int m, ftn_int n, ftn_single const* alpha,
                                   ftn_single const* A, ftn_int lda, ftn_single* B, ftn_int ldb, acc::stream_id sid) const
{
    switch (la_) {
        case lib_t::blas: {
            /* const_cast: the Fortran prototype takes non-const pointers */
            FORTRAN(strmm)(&side, &uplo, &transa, "N", &m, &n, const_cast<ftn_single*>(alpha),
                           const_cast<ftn_single*>(A), &lda, B, &ldb, (ftn_len)1, (ftn_len)1, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* executed on the stream identified by sid */
            acc::blas::strmm(side, uplo, transa, 'N', m, n, alpha, A, lda, B, ldb, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::strmm(side, uplo, transa, 'N', m, n, alpha, A, lda, B, ldb);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
957
/* Triangular matrix - matrix product, double complex:
 * B <= alpha * op(A) * B or B <= alpha * B * op(A), depending on 'side'.
 * The 'diag' argument is hard-coded to "N" (non-unit diagonal). */
template <>
inline void wrap::trmm<ftn_double_complex>(char side, char uplo, char transa, ftn_int m, ftn_int n,
                                           ftn_double_complex const* alpha, ftn_double_complex const* A,
                                           ftn_int lda, ftn_double_complex* B, ftn_int ldb, acc::stream_id sid) const
{
    switch (la_) {
        case lib_t::blas: {
            /* trailing (ftn_len)1 arguments are the hidden Fortran string lengths */
            FORTRAN(ztrmm)(&side, &uplo, &transa, "N", &m, &n, const_cast<ftn_double_complex*>(alpha),
                           const_cast<ftn_double_complex*>(A), &lda, B, &ldb, (ftn_len)1, (ftn_len)1,
                           (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* reinterpret_cast between the host complex type and the accelerator complex
               type; assumes the two are layout-compatible (re,im pairs) */
            acc::blas::ztrmm(side, uplo, transa, 'N', m, n, reinterpret_cast<acc_complex_double_t const*>(alpha),
                             reinterpret_cast<acc_complex_double_t const*>(A), lda,
                             reinterpret_cast<acc_complex_double_t*>(B), ldb, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::ztrmm(side, uplo, transa, 'N', m, n, reinterpret_cast<acc_complex_double_t const*>(alpha),
                reinterpret_cast<acc_complex_double_t const*>(A), lda, reinterpret_cast<acc_complex_double_t*>(B), ldb);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
995
/* Triangular matrix - matrix product, single complex:
 * B <= alpha * op(A) * B or B <= alpha * B * op(A), depending on 'side'.
 * The 'diag' argument is hard-coded to "N" (non-unit diagonal). */
template <>
inline void wrap::trmm<ftn_complex>(char side, char uplo, char transa, ftn_int m, ftn_int n,
                                    ftn_complex const* alpha, ftn_complex const* A,
                                    ftn_int lda, ftn_complex* B, ftn_int ldb, acc::stream_id sid) const
{
    switch (la_) {
        case lib_t::blas: {
            /* trailing (ftn_len)1 arguments are the hidden Fortran string lengths */
            FORTRAN(ctrmm)
            (&side, &uplo, &transa, "N", &m, &n, const_cast<ftn_complex*>(alpha), const_cast<ftn_complex*>(A), &lda, B,
             &ldb, (ftn_len)1, (ftn_len)1, (ftn_len)1, (ftn_len)1);
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU)
            /* reinterpret_cast relies on layout compatibility of host and device complex types */
            acc::blas::ctrmm(side, uplo, transa, 'N', m, n, reinterpret_cast<acc_complex_float_t const*>(alpha),
                             reinterpret_cast<acc_complex_float_t const*>(A), lda,
                             reinterpret_cast<acc_complex_float_t*>(B), ldb, sid());
#else
            throw std::runtime_error("not compiled with GPU blas support!");
#endif
            break;
        }
        case lib_t::cublasxt: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            acc::blas::xt::ctrmm(side, uplo, transa, 'N', m, n, reinterpret_cast<acc_complex_float_t const*>(alpha),
                                 reinterpret_cast<acc_complex_float_t const*>(A), lda,
                                 reinterpret_cast<acc_complex_float_t*>(B), ldb);
#else
            throw std::runtime_error("not compiled with cublasxt");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1034
/* Cholesky factorization of the upper triangle of a real symmetric positive
 * definite matrix. Returns the backend's info code (0 on success); desca is the
 * ScaLAPACK descriptor, used only by the scalapack backend. */
template<>
inline int wrap::potrf<ftn_double>(ftn_int n, ftn_double* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(dpotrf)("U", &n, A, &lda, &info, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            return magma::dpotrf('U', n, A, lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pdpotrf)("U", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): no status is returned here, so execution falls through to the
               trailing `return -1` even when the factorization succeeds — confirm that
               cusolver::potrf reports errors internally and that callers do not treat -1
               from this backend as a failure. */
            acc::cusolver::potrf<ftn_double>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1081
/* Cholesky factorization of the upper triangle of a single-precision symmetric
 * positive definite matrix. Returns the backend's info code (0 on success). */
template<>
inline int wrap::potrf<ftn_single>(ftn_int n, ftn_single* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(spotrf)("U", &n, A, &lda, &info, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            return magma::spotrf('U', n, A, lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pspotrf)("U", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::potrf<ftn_single>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1128
/* Cholesky factorization of the upper triangle of a hermitian positive definite
 * matrix (double complex). Returns the backend's info code (0 on success). */
template<>
inline int wrap::potrf<ftn_double_complex>(ftn_int n, ftn_double_complex* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(zpotrf)("U", &n, A, &lda, &info, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pzpotrf)("U", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            /* reinterpret_cast assumes std::complex<double> and magmaDoubleComplex are layout-compatible */
            return magma::zpotrf('U', n, reinterpret_cast<magmaDoubleComplex*>(A), lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::potrf<ftn_double_complex>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1175
/* Cholesky factorization of the upper triangle of a hermitian positive definite
 * matrix (single complex). Returns the backend's info code (0 on success). */
template<>
inline int wrap::potrf<ftn_complex>(ftn_int n, ftn_complex* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(cpotrf)("U", &n, A, &lda, &info, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pcpotrf)("U", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            /* reinterpret_cast assumes std::complex<float> and magmaFloatComplex are layout-compatible */
            return magma::cpotrf('U', n, reinterpret_cast<magmaFloatComplex*>(A), lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::potrf<ftn_complex>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1222
/* Inversion of an upper-triangular, non-unit-diagonal real matrix in place.
 * Returns the backend's info code (0 on success). */
template<>
inline int wrap::trtri<ftn_double>(ftn_int n, ftn_double* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(dtrtri)("U", "N", &n, A, &lda, &info, (ftn_len)1, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pdtrtri)("U", "N", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            return magma::dtrtri('U', n, A, lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::trtri<ftn_double>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1269
/* Inversion of an upper-triangular, non-unit-diagonal single-precision matrix in
 * place. Returns the backend's info code (0 on success). */
template<>
inline int wrap::trtri<ftn_single>(ftn_int n, ftn_single* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(strtri)("U", "N", &n, A, &lda, &info, (ftn_len)1, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pstrtri)("U", "N", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            return magma::strtri('U', n, A, lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::trtri<ftn_single>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1316
/* Inversion of an upper-triangular, non-unit-diagonal double-complex matrix in
 * place. Returns the backend's info code (0 on success). */
template<>
inline int wrap::trtri<ftn_double_complex>(ftn_int n, ftn_double_complex* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(ztrtri)("U", "N", &n, A, &lda, &info, (ftn_len)1, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pztrtri)("U", "N", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            /* reinterpret_cast assumes std::complex<double> and magmaDoubleComplex are layout-compatible */
            return magma::ztrtri('U', n, reinterpret_cast<magmaDoubleComplex*>(A), lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::trtri<ftn_double_complex>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1363
/* Inversion of an upper-triangular, non-unit-diagonal single-complex matrix in
 * place. Returns the backend's info code (0 on success). */
template<>
inline int wrap::trtri<ftn_complex>(ftn_int n, ftn_complex* A, ftn_int lda, ftn_int const* desca) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(ctrtri)("U", "N", &n, A, &lda, &info, (ftn_len)1, (ftn_len)1);
            return info;
            break;
        }
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            assert(desca != nullptr);
            /* ScaLAPACK expects 1-based indices of the submatrix origin */
            ftn_int ia{1};
            ftn_int ja{1};
            ftn_int info;
            FORTRAN(pctrtri)("U", "N", &n, A, &ia, &ja, const_cast<ftn_int*>(desca), &info, (ftn_len)1, (ftn_len)1);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        case lib_t::magma: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_MAGMA)
            /* reinterpret_cast assumes std::complex<float> and magmaFloatComplex are layout-compatible */
            return magma::ctrtri('U', n, reinterpret_cast<magmaFloatComplex*>(A), lda);
#else
            throw std::runtime_error("not compiled with magma");
#endif
            break;
        }
        case lib_t::gpublas: {
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA)
            /* NOTE(review): falls through to the trailing `return -1` even on success —
               confirm callers do not treat -1 from the gpublas backend as a failure. */
            acc::cusolver::trtri<ftn_complex>(n, A, lda);
#else
            throw std::runtime_error("not compiled with CUDA");
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1410
1411template<>
1412inline int wrap::gtsv<ftn_double>(ftn_int n, ftn_int nrhs, ftn_double* dl, ftn_double* d, ftn_double* du,
1413 ftn_double* b, ftn_int ldb) const
1414{
1415 switch (la_) {
1416 case lib_t::lapack: {
1417 ftn_int info;
1418 FORTRAN(dgtsv)(&n, &nrhs, dl, d, du, b, &ldb, &info);
1419 return info;
1420 break;
1421 }
1422 default: {
1423 throw std::runtime_error(linalg_msg_wrong_type);
1424 break;
1425 }
1426 }
1427 return -1;
1428}
1429
1430template<>
1431inline int wrap::gtsv<ftn_double_complex>(ftn_int n, ftn_int nrhs, ftn_double_complex* dl, ftn_double_complex* d,
1432 ftn_double_complex* du, ftn_double_complex* b, ftn_int ldb) const
1433{
1434 switch (la_) {
1435 case lib_t::lapack: {
1436 ftn_int info;
1437 FORTRAN(zgtsv)(&n, &nrhs, dl, d, du, b, &ldb, &info);
1438 return info;
1439 break;
1440 }
1441 default: {
1442 throw std::runtime_error(linalg_msg_wrong_type);
1443 break;
1444 }
1445 }
1446 return -1;
1447}
1448
1449template<>
1450inline int wrap::gesv<ftn_double>(ftn_int n, ftn_int nrhs, ftn_double* A, ftn_int lda, ftn_double* B, ftn_int ldb) const
1451{
1452 switch (la_) {
1453 case lib_t::lapack: {
1454 ftn_int info;
1455 std::vector<ftn_int> ipiv(n);
1456 FORTRAN(dgesv)(&n, &nrhs, A, &lda, &ipiv[0], B, &ldb, &info);
1457 return info;
1458 break;
1459 }
1460 default: {
1461 throw std::runtime_error(linalg_msg_wrong_type);
1462 break;
1463 }
1464 }
1465 return -1;
1466}
1467
1468template<>
1469inline int wrap::gesv<ftn_double_complex>(ftn_int n, ftn_int nrhs, ftn_double_complex* A, ftn_int lda,
1470 ftn_double_complex* B, ftn_int ldb) const
1471{
1472 switch (la_) {
1473 case lib_t::lapack: {
1474 ftn_int info;
1475 std::vector<ftn_int> ipiv(n);
1476 FORTRAN(zgesv)(&n, &nrhs, A, &lda, &ipiv[0], B, &ldb, &info);
1477 return info;
1478 break;
1479 }
1480 default: {
1481 throw std::runtime_error(linalg_msg_wrong_type);
1482 break;
1483 }
1484 }
1485 return -1;
1486}
1487
1488// LU factorization, double
1489template<>
1490inline int wrap::getrf<ftn_double>(ftn_int m, ftn_int n, ftn_double* A, ftn_int lda, ftn_int* ipiv) const
1491{
1492 switch (la_) {
1493 case lib_t::lapack: {
1494 ftn_int info;
1495 FORTRAN(dgetrf)(&m, &n, A, &lda, ipiv, &info);
1496 return info;
1497 break;
1498 }
1499 default: {
1500 throw std::runtime_error(linalg_msg_wrong_type);
1501 break;
1502 }
1503 }
1504 return -1;
1505}
1506
1507// LU factorization, double_complex
1508template<>
1509inline int wrap::getrf<ftn_double_complex>(ftn_int m, ftn_int n, ftn_double_complex* A, ftn_int lda, ftn_int* ipiv) const
1510{
1511 switch (la_) {
1512 case lib_t::lapack: {
1513 ftn_int info;
1514 FORTRAN(zgetrf)(&m, &n, A, &lda, ipiv, &info);
1515 return info;
1516 break;
1517 }
1518 default: {
1519 throw std::runtime_error(linalg_msg_wrong_type);
1520 break;
1521 }
1522 }
1523 return -1;
1524}
1525
/* LU factorization of a distributed complex matrix (ScaLAPACK backend only).
 * ia/ja are zero-based global indices of the submatrix origin; they are
 * converted to the 1-based convention expected by ScaLAPACK. Pivot indices are
 * written to ipiv. Returns the ScaLAPACK info code (0 on success). */
template<>
inline int wrap::getrf<ftn_double_complex>(ftn_int m, ftn_int n, dmatrix<ftn_double_complex>& A,
                                           ftn_int ia, ftn_int ja, ftn_int* ipiv) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined (SIRIUS_SCALAPACK)
            ftn_int info;
            /* convert to 1-based indexing */
            ia++;
            ja++;
            FORTRAN(pzgetrf)(&m, &n, A.at(sddk::memory_t::host), &ia, &ja, const_cast<int*>(A.descriptor()), ipiv, &info);
            return info;
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1550
/* Solve op(A) * X = B using the LU factors and pivots produced by getrf().
 * trans selects op(A) ('N', 'T' or 'C'). Returns the backend's info code. */
template<>
inline int
wrap::getrs<ftn_double_complex>(char trans, ftn_int n, ftn_int nrhs, const ftn_double_complex* A, ftn_int lda,
                                ftn_int* ipiv, ftn_double_complex* B, ftn_int ldb) const
{
    switch (la_) {
        case lib_t::lapack: {
            ftn_int info;
            FORTRAN(zgetrs)(&trans, &n, &nrhs, const_cast<ftn_double_complex*>(A), &lda, ipiv, B, &ldb, &info);
            return info;
            break;
        }
#if defined(SIRIUS_GPU)
        /* the whole case label is compiled out in CPU-only builds, so requesting
           gpublas then falls through to the default error branch */
        case lib_t::gpublas: {
            return acc::lapack::getrs(trans, n, nrhs, reinterpret_cast<const acc_complex_double_t*>(A), lda, ipiv,
                                      reinterpret_cast<acc_complex_double_t*>(B), ldb);
            break;
        }
#endif
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
    return -1;
}
1577
/* C(ic:ic+m, jc:jc+n) <= conjugate-transpose of A(ia:ia+n, ja:ja+m) for
 * distributed single-complex matrices (ScaLAPACK backend only). Indices are
 * zero-based on entry and converted to ScaLAPACK's 1-based convention. */
template <>
inline void
wrap::tranc<ftn_complex>(ftn_int m, ftn_int n, dmatrix<ftn_complex>& A, ftn_int ia, ftn_int ja, dmatrix<ftn_complex>& C,
                         ftn_int ic, ftn_int jc) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            /* convert to 1-based indexing */
            ia++; ja++;
            ic++; jc++;

            /* pass nullptr when this rank stores no local elements */
            auto A_ptr = (A.num_rows_local() * A.num_cols_local() > 0) ? A.at(sddk::memory_t::host) : nullptr;
            auto C_ptr = (C.num_rows_local() * C.num_cols_local() > 0) ? C.at(sddk::memory_t::host) : nullptr;

            /* C <= 1 * A^H + 0 * C */
            FORTRAN(pctranc)(&m, &n, const_cast<ftn_complex*>(&constant<ftn_complex>::one()),
                             A_ptr, &ia, &ja, A.descriptor(),
                             const_cast<ftn_complex*>(&constant<ftn_complex>::zero()),
                             C_ptr, &ic, &jc, C.descriptor());
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1607
/* C(ic:ic+m, jc:jc+n) <= transpose (without conjugation) of A for distributed
 * double-complex matrices (ScaLAPACK backend only). Indices are zero-based on
 * entry and converted to ScaLAPACK's 1-based convention. */
template<>
inline void wrap::tranu<ftn_double_complex>(ftn_int m, ftn_int n, dmatrix<ftn_double_complex>& A,
    ftn_int ia, ftn_int ja, dmatrix<ftn_double_complex>& C, ftn_int ic, ftn_int jc) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            /* convert to 1-based indexing */
            ia++; ja++;
            ic++; jc++;

            /* pass nullptr when this rank stores no local elements */
            auto A_ptr = (A.num_rows_local() * A.num_cols_local() > 0) ? A.at(sddk::memory_t::host) : nullptr;
            auto C_ptr = (C.num_rows_local() * C.num_cols_local() > 0) ? C.at(sddk::memory_t::host) : nullptr;

            /* C <= 1 * A^T + 0 * C */
            FORTRAN(pztranu)(&m, &n, const_cast<ftn_double_complex*>(&constant<ftn_double_complex>::one()),
                             A_ptr, &ia, &ja, A.descriptor(),
                             const_cast<ftn_double_complex*>(&constant<ftn_double_complex>::zero()),
                             C_ptr, &ic, &jc, C.descriptor());
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1636
/* C(ic:ic+m, jc:jc+n) <= conjugate-transpose of A for distributed
 * double-complex matrices (ScaLAPACK backend only). Indices are zero-based on
 * entry and converted to ScaLAPACK's 1-based convention. */
template<>
inline void wrap::tranc<ftn_double_complex>(ftn_int m, ftn_int n, dmatrix<ftn_double_complex>& A,
    ftn_int ia, ftn_int ja, dmatrix<ftn_double_complex>& C, ftn_int ic, ftn_int jc) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            /* convert to 1-based indexing */
            ia++; ja++;
            ic++; jc++;

            /* pass nullptr when this rank stores no local elements */
            auto A_ptr = (A.num_rows_local() * A.num_cols_local() > 0) ? A.at(sddk::memory_t::host) : nullptr;
            auto C_ptr = (C.num_rows_local() * C.num_cols_local() > 0) ? C.at(sddk::memory_t::host) : nullptr;

            /* C <= 1 * A^H + 0 * C */
            FORTRAN(pztranc)(&m, &n, const_cast<ftn_double_complex*>(&constant<ftn_double_complex>::one()),
                             A_ptr, &ia, &ja, A.descriptor(),
                             const_cast<ftn_double_complex*>(&constant<ftn_double_complex>::zero()),
                             C_ptr, &ic, &jc, C.descriptor());
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1665
/* C(ic:ic+m, jc:jc+n) <= transpose of A for distributed single-precision real
 * matrices (ScaLAPACK backend only). For real data the conjugate transpose
 * equals the plain transpose, hence pstran is used. Indices are zero-based on
 * entry and converted to ScaLAPACK's 1-based convention. */
template <>
inline void wrap::tranc<ftn_single>(ftn_int m, ftn_int n, dmatrix<ftn_single>& A, ftn_int ia, ftn_int ja,
                                    dmatrix<ftn_single>& C, ftn_int ic, ftn_int jc) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            /* convert to 1-based indexing */
            ia++; ja++;
            ic++; jc++;

            /* pass nullptr when this rank stores no local elements */
            auto A_ptr = (A.num_rows_local() * A.num_cols_local() > 0) ? A.at(sddk::memory_t::host) : nullptr;
            auto C_ptr = (C.num_rows_local() * C.num_cols_local() > 0) ? C.at(sddk::memory_t::host) : nullptr;

            /* C <= 1 * A^T + 0 * C */
            FORTRAN(pstran)(&m, &n, const_cast<ftn_single*>(&constant<ftn_single>::one()), A_ptr,
                            &ia, &ja, A.descriptor(), const_cast<ftn_single*>(&constant<ftn_single>::zero()),
                            C_ptr, &ic, &jc, C.descriptor());
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1693
/* C(ic:ic+m, jc:jc+n) <= transpose of A for distributed double-precision real
 * matrices (ScaLAPACK backend only). Indices are zero-based on entry and
 * converted to ScaLAPACK's 1-based convention. */
template <>
inline void wrap::tranu<ftn_double>(ftn_int m, ftn_int n, dmatrix<ftn_double>& A, ftn_int ia, ftn_int ja,
                                    dmatrix<ftn_double>& C, ftn_int ic, ftn_int jc) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            /* convert to 1-based indexing */
            ia++; ja++;
            ic++; jc++;

            /* pass nullptr when this rank stores no local elements */
            auto A_ptr = (A.num_rows_local() * A.num_cols_local() > 0) ? A.at(sddk::memory_t::host) : nullptr;
            auto C_ptr = (C.num_rows_local() * C.num_cols_local() > 0) ? C.at(sddk::memory_t::host) : nullptr;

            /* C <= 1 * A^T + 0 * C */
            FORTRAN(pdtran)(&m, &n, const_cast<ftn_double*>(&constant<ftn_double>::one()), A_ptr,
                            &ia, &ja, A.descriptor(), const_cast<ftn_double*>(&constant<ftn_double>::zero()),
                            C_ptr, &ic, &jc, C.descriptor());
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1721
/* C(ic:ic+m, jc:jc+n) <= transpose of A for distributed double-precision real
 * matrices (ScaLAPACK backend only). For real data the conjugate transpose
 * equals the plain transpose, hence pdtran is used (same as tranu<ftn_double>).
 * Indices are zero-based on entry and converted to the 1-based convention. */
template <>
inline void wrap::tranc<ftn_double>(ftn_int m, ftn_int n, dmatrix<ftn_double>& A, ftn_int ia, ftn_int ja,
                                    dmatrix<ftn_double>& C, ftn_int ic, ftn_int jc) const
{
    switch (la_) {
        case lib_t::scalapack: {
#if defined(SIRIUS_SCALAPACK)
            /* convert to 1-based indexing */
            ia++; ja++;
            ic++; jc++;

            /* pass nullptr when this rank stores no local elements */
            auto A_ptr = (A.num_rows_local() * A.num_cols_local() > 0) ? A.at(sddk::memory_t::host) : nullptr;
            auto C_ptr = (C.num_rows_local() * C.num_cols_local() > 0) ? C.at(sddk::memory_t::host) : nullptr;

            /* C <= 1 * A^T + 0 * C */
            FORTRAN(pdtran)(&m, &n, const_cast<ftn_double*>(&constant<ftn_double>::one()), A_ptr,
                            &ia, &ja, A.descriptor(), const_cast<ftn_double*>(&constant<ftn_double>::zero()),
                            C_ptr, &ic, &jc, C.descriptor());
#else
            throw std::runtime_error(linalg_msg_no_scalapack);
#endif
            break;
        }
        default: {
            throw std::runtime_error(linalg_msg_wrong_type);
            break;
        }
    }
}
1749
1750// Inversion of LU factorized matrix, double
1751template<>
1752inline int wrap::getri<ftn_double>(ftn_int n, ftn_double* A, ftn_int lda, ftn_int* ipiv) const
1753{
1754 switch (la_) {
1755 case lib_t::lapack: {
1756 ftn_int nb = linalg_base::ilaenv(1, "dgetri", "U", n, -1, -1, -1);
1757 ftn_int lwork = n * nb;
1758 std::vector<ftn_double> work(lwork);
1759
1760 int32_t info;
1761 FORTRAN(dgetri)(&n, A, &lda, ipiv, &work[0], &lwork, &info);
1762 return info;
1763 break;
1764 }
1765 default: {
1766 throw std::runtime_error(linalg_msg_wrong_type);
1767 break;
1768 }
1769 }
1770 return -1;
1771}
1772
1773// Inversion of LU factorized matrix, double_complex
1774template<>
1775inline int wrap::getri<ftn_double_complex>(ftn_int n, ftn_double_complex* A, ftn_int lda, ftn_int* ipiv) const
1776{
1777 switch (la_) {
1778 case lib_t::lapack: {
1779 ftn_int nb = linalg_base::ilaenv(1, "zgetri", "U", n, -1, -1, -1);
1780 ftn_int lwork = n * nb;
1781 std::vector<ftn_double_complex> work(lwork);
1782
1783 int32_t info;
1784 FORTRAN(zgetri)(&n, A, &lda, ipiv, &work[0], &lwork, &info);
1785 return info;
1786 break;
1787 }
1788 default: {
1789 throw std::runtime_error(linalg_msg_wrong_type);
1790 break;
1791 }
1792 }
1793 return -1;
1794}
1795
1796template<>
1797inline int wrap::sytrf<ftn_double_complex>(ftn_int n, ftn_double_complex* A, ftn_int lda, ftn_int* ipiv) const
1798{
1799 switch (la_) {
1800 case lib_t::lapack: {
1801 ftn_int nb = linalg_base::ilaenv(1, "zhetrf", "U", n, -1, -1, -1);
1802 ftn_int lwork = n * nb;
1803 std::vector<ftn_double_complex> work(lwork);
1804
1805 ftn_int info;
1806 FORTRAN(zhetrf)("U", &n, A, &lda, ipiv, &work[0], &lwork, &info, (ftn_len)1);
1807 return info;
1808 break;
1809 }
1810 default: {
1811 throw std::runtime_error(linalg_msg_wrong_type);
1812 break;
1813 }
1814 }
1815 return -1;
1816}
1817
1818template<>
1819inline int wrap::sytrf<ftn_double>(ftn_int n, ftn_double* A, ftn_int lda, ftn_int* ipiv) const
1820{
1821 switch (la_) {
1822 case lib_t::lapack: {
1823 ftn_int nb = linalg_base::ilaenv(1, "dsytrf", "U", n, -1, -1, -1);
1824 ftn_int lwork = n * nb;
1825 std::vector<ftn_double> work(lwork);
1826
1827 ftn_int info;
1828 FORTRAN(dsytrf)("U", &n, A, &lda, ipiv, &work[0], &lwork, &info, (ftn_len)1);
1829 return info;
1830 break;
1831 }
1832 default: {
1833 throw std::runtime_error(linalg_msg_wrong_type);
1834 break;
1835 }
1836 }
1837 return -1;
1838}
1839
1840template<>
1841inline int wrap::sytri<ftn_double>(ftn_int n, ftn_double* A, ftn_int lda, ftn_int* ipiv) const
1842{
1843 switch (la_) {
1844 case lib_t::lapack: {
1845 std::vector<ftn_double> work(n);
1846 ftn_int info;
1847 FORTRAN(dsytri)("U", &n, A, &lda, ipiv, &work[0], &info, (ftn_len)1);
1848 return info;
1849 break;
1850 }
1851 default: {
1852 throw std::runtime_error(linalg_msg_wrong_type);
1853 break;
1854 }
1855 }
1856 return -1;
1857}
1858
1859template<>
1860inline int wrap::sytrs<ftn_double>(ftn_int n, ftn_int nrhs, ftn_double* A, ftn_int lda, ftn_int* ipiv, ftn_double* b, ftn_int ldb) const
1861{
1862 switch (la_) {
1863 case lib_t::lapack: {
1864 ftn_int info;
1865 FORTRAN(dsytrs)("U", &n, &nrhs, A, &lda, ipiv, b, &ldb, &info, (ftn_len)1);
1866 return info;
1867 break;
1868 }
1869 default: {
1870 throw std::runtime_error(linalg_msg_wrong_type);
1871 break;
1872 }
1873 }
1874 return -1;
1875}
1876
1877template<>
1878inline int wrap::sytri<ftn_double_complex>(ftn_int n, ftn_double_complex* A, ftn_int lda, ftn_int* ipiv) const
1879{
1880 switch (la_) {
1881 case lib_t::lapack: {
1882 std::vector<ftn_double_complex> work(n);
1883 ftn_int info;
1884 FORTRAN(zhetri)("U", &n, A, &lda, ipiv, &work[0], &info, (ftn_len)1);
1885 return info;
1886 }
1887 default: {
1888 throw std::runtime_error(linalg_msg_wrong_type);
1889 break;
1890 }
1891 }
1892 return -1;
1893}
1894
1895template<>
1896inline std::tuple<ftn_double, ftn_double, ftn_double> wrap::lartg(ftn_double f, ftn_double g) const
1897{
1898 switch (la_) {
1899 case lib_t::lapack: {
1900 ftn_double cs, sn, r;
1901 FORTRAN(dlartg)(&f, &g, &cs, &sn, &r);
1902 return std::make_tuple(cs, sn, r);
1903 }
1904 default: {
1905 throw std::runtime_error(linalg_msg_wrong_type);
1906 break;
1907 }
1908 }
1909}
1910
1911template <typename T>
1912inline void check_hermitian(std::string const& name, sddk::matrix<T> const& mtrx, int n = -1)
1913{
1914 assert(mtrx.size(0) == mtrx.size(1));
1915
1916 double maxdiff = 0.0;
1917 int i0 = -1;
1918 int j0 = -1;
1919
1920 if (n == -1) {
1921 n = static_cast<int>(mtrx.size(0));
1922 }
1923
1924 for (int i = 0; i < n; i++) {
1925 for (int j = 0; j < n; j++) {
1926 double diff = std::abs(mtrx(i, j) - std::conj(mtrx(j, i)));
1927 if (diff > maxdiff) {
1928 maxdiff = diff;
1929 i0 = i;
1930 j0 = j;
1931 }
1932 }
1933 }
1934
1935 if (maxdiff > 1e-10) {
1936 std::stringstream s;
1937 s << name << " is not a symmetric or hermitian matrix" << std::endl
1938 << " maximum error: i, j : " << i0 << " " << j0 << " diff : " << maxdiff;
1939
1940 RTE_WARNING(s);
1941 }
1942}
1943
/* Return the maximum absolute deviation from hermiticity of the leading
 * n__ x n__ block of a (possibly distributed) matrix. In the distributed case
 * the conjugate transpose is formed via ScaLAPACK into a temporary and the
 * comparison is done element-wise on local blocks, followed by a max-reduction
 * over the BLACS grid. */
template <typename T>
inline real_type<T> check_hermitian(dmatrix<T>& mtrx__, int n__)
{
    real_type<T> max_diff{0};
    if (mtrx__.comm().size() != 1) {
        /* distributed case: tmp <= mtrx^H, then compare local elements */
        dmatrix<T> tmp(n__, n__, mtrx__.blacs_grid(), mtrx__.bs_row(), mtrx__.bs_col());
        wrap(lib_t::scalapack).tranc(n__, n__, mtrx__, 0, 0, tmp, 0, 0);
        for (int i = 0; i < tmp.num_cols_local(); i++) {
            for (int j = 0; j < tmp.num_rows_local(); j++) {
                max_diff = std::max(max_diff, std::abs(mtrx__(j, i) - tmp(j, i)));
            }
        }
        /* global maximum over all ranks of the BLACS grid */
        mtrx__.blacs_grid().comm().template allreduce<real_type<T>, mpi::op_t::max>(&max_diff, 1);
    } else {
        /* serial case: direct element-wise comparison with the conjugate transpose */
        for (int i = 0; i < n__; i++) {
            for (int j = 0; j < n__; j++) {
                max_diff = std::max(max_diff, std::abs(mtrx__(j, i) - std::conj(mtrx__(i, j))));
            }
        }
    }
    return max_diff;
}
1966
1967template <typename T>
1968inline double check_identity(dmatrix<T>& mtrx__, int n__)
1969{
1970 real_type<T> max_diff{0};
1971 for (int i = 0; i < mtrx__.num_cols_local(); i++) {
1972 int icol = mtrx__.icol(i);
1973 if (icol < n__) {
1974 for (int j = 0; j < mtrx__.num_rows_local(); j++) {
1975 int jrow = mtrx__.irow(j);
1976 if (jrow < n__) {
1977 if (icol == jrow) {
1978 max_diff = std::max(max_diff, std::abs(mtrx__(j, i) - static_cast<real_type<T>>(1.0)));
1979 } else {
1980 max_diff = std::max(max_diff, std::abs(mtrx__(j, i)));
1981 }
1982 }
1983 }
1984 }
1985 }
1986 mtrx__.comm().template allreduce<real_type<T>, mpi::op_t::max>(&max_diff, 1);
1987 return max_diff;
1988}
1989
1990template <typename T>
1991inline double check_diagonal(dmatrix<T>& mtrx__, int n__, sddk::mdarray<double, 1> const& diag__)
1992{
1993 double max_diff{0};
1994 for (int i = 0; i < mtrx__.num_cols_local(); i++) {
1995 int icol = mtrx__.icol(i);
1996 if (icol < n__) {
1997 for (int j = 0; j < mtrx__.num_rows_local(); j++) {
1998 int jrow = mtrx__.irow(j);
1999 if (jrow < n__) {
2000 if (icol == jrow) {
2001 max_diff = std::max(max_diff, std::abs(mtrx__(j, i) - diag__[icol]));
2002 } else {
2003 max_diff = std::max(max_diff, std::abs(mtrx__(j, i)));
2004 }
2005 }
2006 }
2007 }
2008 }
2009 mtrx__.comm().template allreduce<double, mpi::op_t::max>(&max_diff, 1);
2010 return max_diff;
2011}
2012
/** Perform one of the following operations:
 *  A <= U A U^{H} (kind = 0)
 *  A <= U^{H} A U (kind = 1)
 *
 *  \param kind__ selects the transformation (0 or 1, see above); anything else throws
 *  \param A__    matrix to transform in place
 *  \param U__    the (presumably unitary) transformation matrix
 *  \param n__    size of the leading block to transform
 *
 *  Dispatches to ScaLAPACK when A__ is distributed over more than one rank,
 *  otherwise to plain BLAS.
 */
template <typename T>
inline void unitary_similarity_transform(int kind__, dmatrix<T>& A__, dmatrix<T> const& U__, int n__)
{
    // TODO: use memory pool to allocate tmp matrix
    if (!(kind__ == 0 || kind__ == 1)) {
        RTE_THROW("wrong 'kind' parameter");
    }
    /* transposition flags for the two gemm calls: (c1, c2) = ('N','C') for kind 0
       and ('C','N') for kind 1 */
    char c1 = kind__ == 0 ? 'N' : 'C';
    char c2 = kind__ == 0 ? 'C' : 'N';
    if (A__.comm().size() != 1) {
        dmatrix<T> tmp(n__, n__, A__.blacs_grid(), A__.bs_row(), A__.bs_col());

        /* compute tmp <= U A or U^{H} A */
        wrap(lib_t::scalapack).gemm(c1, 'N', n__, n__, n__, &constant<T>::one(),
            U__, 0, 0, A__, 0, 0, &constant<T>::zero(), tmp, 0, 0);

        /* compute A <= tmp U^{H} or tmp U */
        wrap(lib_t::scalapack).gemm('N', c2, n__, n__, n__, &constant<T>::one(),
            tmp, 0, 0, U__, 0, 0, &constant<T>::zero(), A__, 0, 0);
    } else {
        dmatrix<T> tmp(n__, n__);

        /* compute tmp <= U A or U^{H} A */
        wrap(lib_t::blas).gemm(c1, 'N', n__, n__, n__, &constant<T>::one(),
            U__.at(sddk::memory_t::host), U__.ld(), A__.at(sddk::memory_t::host), A__.ld(), &constant<T>::zero(),
            tmp.at(sddk::memory_t::host), tmp.ld());

        /* compute A <= tmp U^{H} or tmp U */
        wrap(lib_t::blas).gemm('N', c2, n__, n__, n__, &constant<T>::one(),
            tmp.at(sddk::memory_t::host), tmp.ld(), U__.at(sddk::memory_t::host), U__.ld(), &constant<T>::zero(),
            A__.at(sddk::memory_t::host), A__.ld());
    }
}
2050
2051} // namespace
2052
2053} // namespace sirius
2054
2055#endif // __LINALG_HPP__
Interface to accelerators API.
Blas functions for execution on GPUs.
Interface to some BLAS/LAPACK functions.
Helper class to wrap stream id (integer number).
Definition: acc.hpp:132
Distributed matrix.
Definition: dmatrix.hpp:56
int bs_row() const
Row blocking factor.
Definition: dmatrix.hpp:301
int bs_col() const
Column blocking factor.
Definition: dmatrix.hpp:307
int gesv(ftn_int n, ftn_int nrhs, T *A, ftn_int lda, T *B, ftn_int ldb) const
Compute the solution to system of linear equations A * X = B for general matrix.
int sytri(ftn_int n, T *A, ftn_int lda, ftn_int *ipiv) const
Inversion of factorized symmetric triangular matrix.
void geinv(ftn_int n, sddk::matrix< T > &A) const
Invert a general matrix.
Definition: linalg.hpp:163
void gemm(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, T const *alpha, dmatrix< T > const &A, ftn_int ia, ftn_int ja, dmatrix< T > const &B, ftn_int ib, ftn_int jb, T const *beta, dmatrix< T > &C, ftn_int ic, ftn_int jc)
Distributed general matrix-matrix multiplication.
void tranc(ftn_int m, ftn_int n, dmatrix< T > &A, ftn_int ia, ftn_int ja, dmatrix< T > &C, ftn_int ic, ftn_int jc) const
Conjugate transpose matrix.
int sytrf(ftn_int n, T *A, ftn_int lda, ftn_int *ipiv) const
U*D*U^H factorization of hermitian or symmetric matrix.
void tranu(ftn_int m, ftn_int n, dmatrix< T > &A, ftn_int ia, ftn_int ja, dmatrix< T > &C, ftn_int ic, ftn_int jc) const
Transpose matrix without conjugation.
int trtri(ftn_int n, T *A, ftn_int lda, ftn_int const *desca=nullptr) const
Inversion of a triangular matrix.
int gtsv(ftn_int n, ftn_int nrhs, T *dl, T *d, T *du, T *b, ftn_int ldb) const
Compute the solution to system of linear equations A * X = B for general tri-diagonal matrix.
int sytrs(ftn_int n, ftn_int nrhs, T *A, ftn_int lda, ftn_int *ipiv, T *b, ftn_int ldb) const
solve Ax=b in place of b where A is factorized with sytrf.
void hemm(char side, char uplo, ftn_int m, ftn_int n, T const *alpha, T const *A, ftn_len lda, T const *B, ftn_len ldb, T const *beta, T *C, ftn_len ldc)
Hermitian matrix times a general matrix or vice versa.
int getrf(ftn_int m, ftn_int n, T *A, ftn_int lda, ftn_int *ipiv) const
LU factorization of general matrix.
void gemm(char transa, char transb, ftn_int m, ftn_int n, ftn_int k, T const *alpha, T const *A, ftn_int lda, T const *B, ftn_int ldb, T const *beta, T *C, ftn_int ldc, acc::stream_id sid=acc::stream_id(-1)) const
General matrix-matrix multiplication.
void axpy(int n, T const *alpha, T const *x, int incx, T *y, int incy)
vector addition
int potrf(ftn_int n, T *A, ftn_int lda, ftn_int const *desca=nullptr) const
Cholesky factorization.
int getrf(ftn_int m, ftn_int n, dmatrix< T > &A, ftn_int ia, ftn_int ja, ftn_int *ipiv) const
LU factorization of general matrix.
uint32_t ld() const
Return leading dimension size.
Definition: memory.hpp:1233
Interface to CUDA eigen-solver library.
Contains definition and implementation of distributed matrix class.
bool is_set_device_id()
check if device id has been set properly
Definition: linalg.hpp:51
Interface to SPLA library.
Interface to some of the MAGMA functions.
Memory management functions and classes.
void zgetrs(rocblas_handle handle, char trans, int n, int nrhs, acc_complex_double_t *A, int lda, const int *devIpiv, acc_complex_double_t *B, int ldb)
Linear Solvers.
int num_devices()
Get the number of devices.
Definition: acc.cpp:32
int get_device_id()
Get current device ID.
Definition: acc.hpp:191
lib_t
Type of linear algebra backend library.
Definition: linalg_base.hpp:70
@ cublasxt
cuBlasXt (cuBlas with CPU pointers and large matrices support)
@ scalapack
CPU ScaLAPACK.
@ spla
SPLA library. Can take CPU and device pointers.
@ lapack
CPU LAPACK.
@ magma
MAGMA with CPU pointers.
@ gpublas
GPU BLAS (cuBlas or ROCblas)
void unitary_similarity_transform(int kind__, dmatrix< T > &A__, dmatrix< T > const &U__, int n__)
Definition: linalg.hpp:2018
int get_device_id(int num_devices__)
Get GPU device id associated with the current rank.
Namespace of the SIRIUS library.
Definition: sirius.f90:5
@ du
Down-up block.
auto conj(double x__)
Return complex conjugate of a number. For a real value this is the number itself.
Definition: math_tools.hpp:165