SIRIUS 7.5.0
Electronic structure library and applications
linalg_base.hpp
Go to the documentation of this file.
1// Copyright (c) 2013-2016 Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
20/** \file linalg_base.hpp
21 *
22 * \brief Basic interface to linear algebra functions.
23 */
24
25#ifndef __LINALG_BASE_HPP__
26#define __LINALG_BASE_HPP__
27
#include <algorithm>
#include <cctype>
#include <map>
#include <stdexcept>
#include <string>

#include "blas_lapack.h"
#include "scalapack.h"
#include <mpi.h>
34
35namespace sirius {
36
37/// Interface to linear algebra BLAS/LAPACK functions.
38namespace la {
39
/// Frequently used scalar constants (1, 2, -1 and 0) converted to type \p T.
///
/// Every accessor returns a reference to a function-local static object. The object is initialised once and
/// stays alive for the rest of the program, so the returned reference (and its address) stays valid.
/// NOTE(review): presumably the references exist so that the address can be handed to BLAS-style
/// routines that take alpha/beta by pointer; confirm with callers.
template <typename T>
struct constant
{
    /// Return a reference to the value 1 of type T.
    static T const& one() noexcept
    {
        static const T a = 1;
        return a;
    }

    /// Return a reference to the value 2 of type T.
    static T const& two() noexcept
    {
        static const T a = 2;
        return a;
    }

    /// Return a reference to the value -1 of type T.
    static T const& m_one() noexcept
    {
        static const T a = -1;
        return a;
    }

    /// Return a reference to the value 0 of type T.
    static T const& zero() noexcept
    {
        static const T a = 0;
        return a;
    }
};
67
/// Type of linear algebra backend library.
enum class lib_t
{
    /// None
    none,
    /// CPU BLAS
    blas,
    /// CPU LAPACK
    lapack,
    /// CPU ScaLAPACK
    scalapack,
    /// GPU BLAS (cuBlas or ROCblas)
    gpublas,
    /// cuBlasXt (cuBlas with CPU pointers and large matrices support)
    cublasxt,
    /// MAGMA with CPU pointers
    magma,
    /// SPLA library. Can take CPU and device pointers
    spla
};
88
89inline auto get_lib_t(std::string name__)
90{
91 std::transform(name__.begin(), name__.end(), name__.begin(), ::tolower);
92
93 static const std::map<std::string, lib_t> map_to_type = {
94 {"blas", lib_t::blas},
95 {"lapack", lib_t::lapack},
96 {"scalapack", lib_t::scalapack},
97 {"cublas", lib_t::gpublas},
98 {"gpublas", lib_t::gpublas},
99 {"cublasxt", lib_t::cublasxt},
100 {"magma", lib_t::magma},
101 };
102
103 if (map_to_type.count(name__) == 0) {
104 std::stringstream s;
105 s << "wrong label of linear algebra type: " << name__;
106 throw std::runtime_error(s.str());
107 }
108
109 return map_to_type.at(name__);
110}
111
112inline std::string to_string(lib_t la__)
113{
114 switch (la__) {
115 case lib_t::none: {
116 return "none";
117 break;
118 }
119 case lib_t::blas: {
120 return "blas";
121 break;
122 }
123 case lib_t::lapack: {
124 return "lapack";
125 break;
126 }
127 case lib_t::scalapack: {
128 return "scalapack";
129 break;
130 }
131 case lib_t::gpublas: {
132 return "gpublas";
133 break;
134 }
135 case lib_t::cublasxt: {
136 return "cublasxt";
137 break;
138 }
139 case lib_t::magma: {
140 return "magma";
141 break;
142 }
143 case lib_t::spla: {
144 return "spla";
145 break;
146 }
147 }
148 return ""; // make compiler happy
149}
150
151extern "C" {
152
153ftn_int FORTRAN(ilaenv)(ftn_int* ispec, ftn_char name, ftn_char opts, ftn_int* n1, ftn_int* n2, ftn_int* n3,
154 ftn_int* n4, ftn_len name_len, ftn_len opts_len);
155
156ftn_double FORTRAN(dlamch)(ftn_char cmach, ftn_len cmach_len);
157
158#ifdef SIRIUS_SCALAPACK
159int Csys2blacs_handle(MPI_Comm SysCtxt);
160
161MPI_Comm Cblacs2sys_handle(int BlacsCtxt);
162
163void Cblacs_gridinit(int* ConTxt, const char* order, int nprow, int npcol);
164
165void Cblacs_gridmap(int* ConTxt, int* usermap, int ldup, int nprow0, int npcol0);
166
167void Cblacs_gridinfo(int ConTxt, int* nprow, int* npcol, int* myrow, int* mycol);
168
169void Cfree_blacs_system_handle(int ISysCtxt);
170
171void Cblacs_barrier(int ConTxt, const char* scope);
172
173void Cblacs_gridexit(int ConTxt);
174
175void FORTRAN(psgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_single const* aplha,
176 ftn_single const* A, ftn_int* ia, ftn_int* ja, ftn_int const* desca, ftn_single const* B, ftn_int* ib,
177 ftn_int* jb, ftn_int const* descb, ftn_single const* beta, ftn_single* C, ftn_int* ic, ftn_int* jc,
178 ftn_int const* descc, ftn_len transa_len, ftn_len transb_len);
179
180void FORTRAN(pdgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_double const* aplha,
181 ftn_double const* A, ftn_int* ia, ftn_int* ja, ftn_int const* desca, ftn_double const* B, ftn_int* ib,
182 ftn_int* jb, ftn_int const* descb, ftn_double const* beta, ftn_double* C, ftn_int* ic, ftn_int* jc,
183 ftn_int const* descc, ftn_len transa_len, ftn_len transb_len);
184
185void FORTRAN(pcgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_complex const* aplha,
186 ftn_complex const* A, ftn_int* ia, ftn_int* ja, ftn_int const* desca, ftn_complex const* B,
187 ftn_int* ib, ftn_int* jb, ftn_int const* descb, ftn_complex const* beta, ftn_complex* C,
188 ftn_int* ic, ftn_int* jc, ftn_int const* descc, ftn_len transa_len, ftn_len transb_len);
189
190void FORTRAN(pzgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_double_complex const* aplha,
191 ftn_double_complex const* A, ftn_int* ia, ftn_int* ja, ftn_int const* desca, ftn_double_complex const* B,
192 ftn_int* ib, ftn_int* jb, ftn_int const* descb, ftn_double_complex const* beta, ftn_double_complex* C,
193 ftn_int* ic, ftn_int* jc, ftn_int const* descc, ftn_len transa_len, ftn_len transb_len);
194
195void FORTRAN(descinit)(ftn_int const* desc, ftn_int* m, ftn_int* n, ftn_int* mb, ftn_int* nb, ftn_int* irsrc,
196 ftn_int* icsrc, ftn_int* ictxt, ftn_int* lld, ftn_int* info);
197
198void FORTRAN(pctranc)(ftn_int* m, ftn_int* n, ftn_complex* alpha, ftn_complex* a, ftn_int* ia,
199 ftn_int* ja, ftn_int const* desca, ftn_complex* beta, ftn_complex* c, ftn_int* ic,
200 ftn_int* jc, ftn_int const* descc);
201
202void FORTRAN(pztranc)(ftn_int* m, ftn_int* n, ftn_double_complex* alpha, ftn_double_complex* a, ftn_int* ia,
203 ftn_int* ja, ftn_int const* desca, ftn_double_complex* beta, ftn_double_complex* c, ftn_int* ic,
204 ftn_int* jc, ftn_int const* descc);
205
206void FORTRAN(pztranu)(ftn_int* m, ftn_int* n, ftn_double_complex* alpha, ftn_double_complex* a, ftn_int* ia,
207 ftn_int* ja, ftn_int const* desca, ftn_double_complex* beta, ftn_double_complex* c, ftn_int* ic,
208 ftn_int* jc, ftn_int const* descc);
209
210void FORTRAN(pstran)(ftn_int* m, ftn_int* n, ftn_single* alpha, ftn_single* a, ftn_int* ia, ftn_int* ja,
211 ftn_int const* desca, ftn_single* beta, ftn_single* c, ftn_int* ic, ftn_int* jc,
212 ftn_int const* descc);
213
214void FORTRAN(pdtran)(ftn_int* m, ftn_int* n, ftn_double* alpha, ftn_double* a, ftn_int* ia, ftn_int* ja,
215 ftn_int const* desca, ftn_double* beta, ftn_double* c, ftn_int* ic, ftn_int* jc,
216 ftn_int const* descc);
217
218ftn_int FORTRAN(numroc)(ftn_int* n, ftn_int* nb, ftn_int* iproc, ftn_int* isrcproc, ftn_int* nprocs);
219
220ftn_int FORTRAN(indxl2g)(ftn_int* indxloc, ftn_int* nb, ftn_int* iproc, ftn_int* isrcproc, ftn_int* nprocs);
221
222ftn_len FORTRAN(iceil)(ftn_int* inum, ftn_int* idenom);
223
224void FORTRAN(pzgemr2d)(ftn_int* m, ftn_int* n, ftn_double_complex* a, ftn_int* ia, ftn_int* ja, ftn_int const* desca,
225 ftn_double_complex* b, ftn_int* ib, ftn_int* jb, ftn_int const* descb, ftn_int* gcontext);
226#endif
227}
228
229/// Base class for linear algebra interface.
231{
232 public:
233 static ftn_int ilaenv(ftn_int ispec, std::string const& name, std::string const& opts, ftn_int n1, ftn_int n2,
234 ftn_int n3, ftn_int n4)
235 {
236 return FORTRAN(ilaenv)(&ispec, name.c_str(), opts.c_str(), &n1, &n2, &n3, &n4, (ftn_len)name.length(),
237 (ftn_len)opts.length());
238 }
239
240 static ftn_double dlamch(char cmach)
241 {
242 return FORTRAN(dlamch)(&cmach, (ftn_len)1);
243 }
244
245#ifdef SIRIUS_SCALAPACK
246 static ftn_int numroc(ftn_int n, ftn_int nb, ftn_int iproc, ftn_int isrcproc, ftn_int nprocs)
247 {
248 return FORTRAN(numroc)(&n, &nb, &iproc, &isrcproc, &nprocs);
249 }
250
251 /// Create BLACS handler.
252 static int create_blacs_handler(MPI_Comm comm)
253 {
254 return Csys2blacs_handle(comm);
255 }
256
257 /// Free BLACS handler.
258 static void free_blacs_handler(int blacs_handler)
259 {
260 Cfree_blacs_system_handle(blacs_handler);
261 }
262
263 /// Create BLACS context for the grid of MPI ranks
264 static void gridmap(int* blacs_context, int* map, int ld, int nrow, int ncol)
265 {
266 Cblacs_gridmap(blacs_context, map, ld, nrow, ncol);
267 }
268
269 /// Destroy BLACS context.
270 static void gridexit(int blacs_context)
271 {
272 Cblacs_gridexit(blacs_context);
273 }
274
275 static void gridinfo(int blacs_context, int* nrow, int* ncol, int* irow, int* icol)
276 {
277 Cblacs_gridinfo(blacs_context, nrow, ncol, irow, icol);
278 }
279
280 static void descinit(ftn_int* desc, ftn_int m, ftn_int n, ftn_int mb, ftn_int nb, ftn_int irsrc, ftn_int icsrc,
281 ftn_int ictxt, ftn_int lld)
282 {
283 ftn_int info;
284 ftn_int lld1 = std::max(1, lld);
285
286 FORTRAN(descinit)(desc, &m, &n, &mb, &nb, &irsrc, &icsrc, &ictxt, &lld1, &info);
287
288 if (info) {
289 std::printf("error in descinit()\n");
290 std::printf("m=%i n=%i mb=%i nb=%i irsrc=%i icsrc=%i lld=%i\n", m, n, mb, nb, irsrc, icsrc, lld);
291 exit(-1);
292 }
293 }
294
295 static int pjlaenv(int32_t ictxt, int32_t ispec, const std::string& name, const std::string& opts, int32_t n1,
296 int32_t n2, int32_t n3, int32_t n4)
297 {
298 return FORTRAN(pjlaenv)(&ictxt, &ispec, name.c_str(), opts.c_str(), &n1, &n2, &n3, &n4, (int32_t)name.length(),
299 (int32_t)opts.length());
300 }
301
302 static int32_t indxl2g(int32_t indxloc, int32_t nb, int32_t iproc, int32_t isrcproc, int32_t nprocs)
303 {
304 return FORTRAN(indxl2g)(&indxloc, &nb, &iproc, &isrcproc, &nprocs);
305 }
306
307 static int32_t iceil(int32_t inum, int32_t idenom)
308 {
309 return FORTRAN(iceil)(&inum, &idenom);
310 }
311#endif
312};
313
314} // namespace
315
316}
317
318#endif // __LINALG_BASE_HPP__
Interface to some BLAS/LAPACK functions.
Base class for linear algebra interface.
static void gridmap(int *blacs_context, int *map, int ld, int nrow, int ncol)
Create BLACS context for the grid of MPI ranks.
static void free_blacs_handler(int blacs_handler)
Free BLACS handler.
static int create_blacs_handler(MPI_Comm comm)
Create BLACS handler.
static void gridexit(int blacs_context)
Destroy BLACS context.
lib_t
Type of linear algebra backend library.
Definition: linalg_base.hpp:70
@ cublasxt
cuBlasXt (cuBlas with CPU pointers and large matrices support)
@ scalapack
CPU ScaLAPACK.
@ spla
SPLA library. Can take CPU and device pointers.
@ lapack
CPU LAPACK.
@ magma
MAGMA with CPU pointers.
@ gpublas
GPU BLAS (cuBlas or ROCblas)
@ magma
MAGMA with CPU pointers.
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > transform(::spla::Context &spla_ctx__, sddk::memory_t mem__, la::dmatrix< F > const &M__, int irow0__, int jcol0__, real_type< F > alpha__, Wave_functions< T > const &wf_in__, spin_index s_in__, band_range br_in__, real_type< F > beta__, Wave_functions< T > &wf_out__, spin_index s_out__, band_range br_out__)
Apply linear transformation to the wave-functions.
Namespace of the SIRIUS library.
Definition: sirius.f90:5
Interface to some ScaLAPACK functions.