25#ifndef __LINALG_BASE_HPP__
26#define __LINALG_BASE_HPP__
/// NOTE(review): this extract shows only the signatures of four static accessors returning `T const&`
/// (one, two, m_one, zero); most of their bodies, and the enclosing class header, are not visible.
/// The enclosing scope is presumably a template over T — TODO confirm against the full file.
/// Reference to a constant of type T — presumably 1 (body not visible).
43 static T
const& one()
noexcept
/// Reference to a constant of type T — presumably 2 (body not visible).
49 static T
const& two()
noexcept
/// Reference to a constant of type T. The `static const T a = -1;` line that follows is presumably
/// part of this body, so the returned value is presumably -1.
55 static T
const& m_one()
noexcept
57 static const T a = -1;
/// Reference to a constant of type T — presumably 0 (body not visible).
61 static T
const& zero()
noexcept
/// Convert a textual label into a lib_t value (the linear-algebra backend type).
/// The label is taken by value, lower-cased in place, then looked up in a function-local static
/// string -> lib_t table (its entries are not visible in this extract).
/// Throws std::runtime_error("wrong label of linear algebra type: <lower-cased label>") when the
/// label is not in the table.
89inline auto get_lib_t(std::string name__)
// NOTE(review): ::tolower has undefined behavior for negative char values (e.g. non-ASCII bytes where
// char is signed); consider [](unsigned char c) { return static_cast<char>(std::tolower(c)); }.
91 std::transform(name__.begin(), name__.end(), name__.begin(), ::tolower);
93 static const std::map<std::string, lib_t> map_to_type = {
103 if (map_to_type.count(name__) == 0) {
// NOTE(review): the declaration of the stream `s` is not visible here (presumably a std::stringstream).
105 s <<
"wrong label of linear algebra type: " << name__;
106 throw std::runtime_error(s.str());
// NOTE(review): count() followed by at() performs two lookups; a single find() would suffice.
109 return map_to_type.at(name__);
/// Convert a lib_t value to a string label. The body is not visible in this extract; it is
/// presumably the inverse of get_lib_t — TODO confirm.
112inline std::string to_string(
lib_t la__)
/// LAPACK utility prototypes, called through the FORTRAN() macro (presumably the symbol
/// name-mangling macro). The trailing ftn_len arguments carry the lengths of the preceding
/// CHARACTER arguments, which Fortran passes implicitly; the C++ wrappers supply them explicitly.
/// ILAENV: returns a problem-dependent tuning parameter, selected by `ispec`, for routine `name`.
153ftn_int FORTRAN(ilaenv)(ftn_int* ispec, ftn_char name, ftn_char opts, ftn_int* n1, ftn_int* n2, ftn_int* n3,
154 ftn_int* n4, ftn_len name_len, ftn_len opts_len);
/// DLAMCH: returns a double-precision machine parameter selected by the single character `cmach`.
156ftn_double FORTRAN(dlamch)(ftn_char cmach, ftn_len cmach_len);
158#ifdef SIRIUS_SCALAPACK
/// C BLACS interface prototypes (they follow the SIRIUS_SCALAPACK guard).
/// Convert an MPI communicator into a BLACS system handle.
159int Csys2blacs_handle(MPI_Comm SysCtxt);
/// Return the MPI communicator associated with a BLACS handle/context.
161MPI_Comm Cblacs2sys_handle(
int BlacsCtxt);
/// Create an nprow x npcol process grid in *ConTxt; `order` selects row- or column-major rank placement.
163void Cblacs_gridinit(
int* ConTxt,
const char* order,
int nprow,
int npcol);
/// Create an nprow0 x npcol0 process grid in *ConTxt from an explicit rank map `usermap`
/// with leading dimension `ldup`.
165void Cblacs_gridmap(
int* ConTxt,
int* usermap,
int ldup,
int nprow0,
int npcol0);
/// Query the grid dimensions (nprow, npcol) and the calling process coordinates (myrow, mycol) of ConTxt.
167void Cblacs_gridinfo(
int ConTxt,
int* nprow,
int* npcol,
int* myrow,
int* mycol);
/// Release a system handle previously obtained from Csys2blacs_handle.
169void Cfree_blacs_system_handle(
int ISysCtxt);
/// Barrier over the processes of context ConTxt selected by `scope`.
171void Cblacs_barrier(
int ConTxt,
const char* scope);
/// Release the BLACS context ConTxt.
173void Cblacs_gridexit(
int ConTxt);
/// PBLAS P?GEMM prototypes (real single, real double, complex single, complex double):
/// sub(C) := alpha * op(sub(A)) * op(sub(B)) + beta * sub(C), with op selected by transa / transb.
/// (ia, ja), (ib, jb) and (ic, jc) are the global offsets of the submatrices, and desca / descb / descc
/// are their ScaLAPACK array descriptors. transa_len / transb_len are the hidden Fortran CHARACTER lengths.
/// NOTE(review): the parameter name `aplha` is a misspelling of `alpha`; it has no ABI effect.
/// Real single precision.
175void FORTRAN(psgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_single
const* aplha,
176 ftn_single
const* A, ftn_int* ia, ftn_int* ja, ftn_int
const* desca, ftn_single
const* B, ftn_int* ib,
177 ftn_int* jb, ftn_int
const* descb, ftn_single
const* beta, ftn_single* C, ftn_int* ic, ftn_int* jc,
178 ftn_int
const* descc, ftn_len transa_len, ftn_len transb_len);
/// Real double precision.
180void FORTRAN(pdgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_double
const* aplha,
181 ftn_double
const* A, ftn_int* ia, ftn_int* ja, ftn_int
const* desca, ftn_double
const* B, ftn_int* ib,
182 ftn_int* jb, ftn_int
const* descb, ftn_double
const* beta, ftn_double* C, ftn_int* ic, ftn_int* jc,
183 ftn_int
const* descc, ftn_len transa_len, ftn_len transb_len);
/// Complex single precision.
185void FORTRAN(pcgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_complex
const* aplha,
186 ftn_complex
const* A, ftn_int* ia, ftn_int* ja, ftn_int
const* desca, ftn_complex
const* B,
187 ftn_int* ib, ftn_int* jb, ftn_int
const* descb, ftn_complex
const* beta, ftn_complex* C,
188 ftn_int* ic, ftn_int* jc, ftn_int
const* descc, ftn_len transa_len, ftn_len transb_len);
/// Complex double precision.
190void FORTRAN(pzgemm)(ftn_char transa, ftn_char transb, ftn_int* m, ftn_int* n, ftn_int* k, ftn_double_complex
const* aplha,
191 ftn_double_complex
const* A, ftn_int* ia, ftn_int* ja, ftn_int
const* desca, ftn_double_complex
const* B,
192 ftn_int* ib, ftn_int* jb, ftn_int
const* descb, ftn_double_complex
const* beta, ftn_double_complex* C,
193 ftn_int* ic, ftn_int* jc, ftn_int
const* descc, ftn_len transa_len, ftn_len transb_len);
195void FORTRAN(descinit)(ftn_int
const* desc, ftn_int* m, ftn_int* n, ftn_int* mb, ftn_int* nb, ftn_int* irsrc,
196 ftn_int* icsrc, ftn_int* ictxt, ftn_int* lld, ftn_int* info);
/// PBLAS transpose prototypes: sub(C) := beta * sub(C) + alpha * op(sub(A)), where op is the
/// conjugate transpose (pctranc, pztranc) or the plain transpose (pztranu, pstran, pdtran).
/// desca / descc are ScaLAPACK array descriptors.
/// NOTE(review): alpha, a and beta are inputs but are declared without const, so callers must pass
/// non-const pointers.
/// Complex single precision, conjugate transpose.
198void FORTRAN(pctranc)(ftn_int* m, ftn_int* n, ftn_complex* alpha, ftn_complex* a, ftn_int* ia,
199 ftn_int* ja, ftn_int
const* desca, ftn_complex* beta, ftn_complex* c, ftn_int* ic,
200 ftn_int* jc, ftn_int
const* descc);
/// Complex double precision, conjugate transpose.
202void FORTRAN(pztranc)(ftn_int* m, ftn_int* n, ftn_double_complex* alpha, ftn_double_complex* a, ftn_int* ia,
203 ftn_int* ja, ftn_int
const* desca, ftn_double_complex* beta, ftn_double_complex* c, ftn_int* ic,
204 ftn_int* jc, ftn_int
const* descc);
/// Complex double precision, plain (unconjugated) transpose.
206void FORTRAN(pztranu)(ftn_int* m, ftn_int* n, ftn_double_complex* alpha, ftn_double_complex* a, ftn_int* ia,
207 ftn_int* ja, ftn_int
const* desca, ftn_double_complex* beta, ftn_double_complex* c, ftn_int* ic,
208 ftn_int* jc, ftn_int
const* descc);
/// Real single precision transpose.
210void FORTRAN(pstran)(ftn_int* m, ftn_int* n, ftn_single* alpha, ftn_single* a, ftn_int* ia, ftn_int* ja,
211 ftn_int
const* desca, ftn_single* beta, ftn_single* c, ftn_int* ic, ftn_int* jc,
212 ftn_int
const* descc);
/// Real double precision transpose.
214void FORTRAN(pdtran)(ftn_int* m, ftn_int* n, ftn_double* alpha, ftn_double* a, ftn_int* ia, ftn_int* ja,
215 ftn_int
const* desca, ftn_double* beta, ftn_double* c, ftn_int* ic, ftn_int* jc,
216 ftn_int
const* descc);
/// ScaLAPACK NUMROC: number of rows (or columns) of an n-element dimension distributed in blocks of
/// nb that are owned by process `iproc`, given that process `isrcproc` owns the first block and
/// there are `nprocs` processes along that dimension.
218ftn_int FORTRAN(numroc)(ftn_int* n, ftn_int* nb, ftn_int* iproc, ftn_int* isrcproc, ftn_int* nprocs);
/// ScaLAPACK INDXL2G: convert the local index `indxloc` of process `iproc` into a global index.
220ftn_int FORTRAN(indxl2g)(ftn_int* indxloc, ftn_int* nb, ftn_int* iproc, ftn_int* isrcproc, ftn_int* nprocs);
222ftn_len FORTRAN(iceil)(ftn_int* inum, ftn_int* idenom);
/// ScaLAPACK P?GEMR2D (double complex): copy the m x n submatrix of `a` starting at (ia, ja),
/// described by `desca`, into `b` starting at (ib, jb), described by `descb`. The two descriptors
/// may belong to different process grids. `gcontext` is presumably a BLACS context that contains
/// the processes of both grids — TODO confirm.
224void FORTRAN(pzgemr2d)(ftn_int* m, ftn_int* n, ftn_double_complex* a, ftn_int* ia, ftn_int* ja, ftn_int
const* desca,
225 ftn_double_complex* b, ftn_int* ib, ftn_int* jb, ftn_int
const* descb, ftn_int* gcontext);
/// Thin wrapper over LAPACK ILAENV. It takes the integer arguments by value and the routine name and
/// options as std::string. It forwards their addresses and c_str() pointers to FORTRAN(ilaenv),
/// passing the string lengths as the hidden ftn_len arguments.
233 static ftn_int ilaenv(ftn_int ispec, std::string
const& name, std::string
const& opts, ftn_int n1, ftn_int n2,
234 ftn_int n3, ftn_int n4)
236 return FORTRAN(ilaenv)(&ispec, name.c_str(), opts.c_str(), &n1, &n2, &n3, &n4, (ftn_len)name.length(),
237 (ftn_len)opts.length());
/// Thin wrapper over LAPACK DLAMCH: query the machine parameter selected by `cmach`.
/// A single character is passed, so the hidden CHARACTER length is the constant 1.
240 static ftn_double dlamch(
char cmach)
242 return FORTRAN(dlamch)(&cmach, (ftn_len)1);
245#ifdef SIRIUS_SCALAPACK
/// Thin wrapper over ScaLAPACK NUMROC. It takes the arguments by value and forwards their addresses,
/// as the Fortran routine expects them by reference.
246 static ftn_int numroc(ftn_int n, ftn_int nb, ftn_int iproc, ftn_int isrcproc, ftn_int nprocs)
248 return FORTRAN(numroc)(&n, &nb, &iproc, &isrcproc, &nprocs);
// NOTE(review): orphan body line; its enclosing signature is not visible at this point in the extract.
// Presumably the body of `static int create_blacs_handler(MPI_Comm comm)` ("Create BLACS handler."):
// converts the MPI communicator into a BLACS system handle via Csys2blacs_handle.
254 return Csys2blacs_handle(comm);
// NOTE(review): orphan body line; presumably the body of `static void free_blacs_handler(int blacs_handler)`
// ("Free BLACS handler."), releasing a handle obtained from Csys2blacs_handle.
260 Cfree_blacs_system_handle(blacs_handler);
/// Create a BLACS context for a grid of MPI ranks. This forwards to Cblacs_gridmap, which builds an
/// nrow x ncol grid in *blacs_context from the rank map `map` with leading dimension `ld`.
264 static void gridmap(
int* blacs_context,
int* map,
int ld,
int nrow,
int ncol)
266 Cblacs_gridmap(blacs_context, map, ld, nrow, ncol);
// NOTE(review): orphan body line; presumably the body of `static void gridexit(int blacs_context)`
// ("Destroy BLACS context."), whose signature is not visible at this point in the extract.
272 Cblacs_gridexit(blacs_context);
/// Query the process grid of `blacs_context`. The grid dimensions are written to *nrow / *ncol and the
/// calling process coordinates to *irow / *icol (a direct forward to Cblacs_gridinfo).
275 static void gridinfo(
int blacs_context,
int* nrow,
int* ncol,
int* irow,
int* icol)
277 Cblacs_gridinfo(blacs_context, nrow, ncol, irow, icol);
/// Initialize the ScaLAPACK array descriptor `desc` by forwarding to FORTRAN(descinit).
/// The local leading dimension is clamped to at least 1 before the call. This is presumably because
/// DESCINIT rejects LLD < 1 (e.g. on a rank that owns no rows) — TODO confirm intent.
/// NOTE(review): the declaration of `info` and the condition guarding the two printf calls are not
/// visible in this extract; presumably the messages are printed when DESCINIT reports info != 0.
280 static void descinit(ftn_int* desc, ftn_int m, ftn_int n, ftn_int mb, ftn_int nb, ftn_int irsrc, ftn_int icsrc,
281 ftn_int ictxt, ftn_int lld)
// std::max(1, lld) deduces a single template type, so this compiles only if ftn_int is `int`.
284 ftn_int lld1 = std::max(1, lld);
286 FORTRAN(descinit)(desc, &m, &n, &mb, &nb, &irsrc, &icsrc, &ictxt, &lld1, &info);
// The diagnostics report the caller's original `lld`, not the clamped `lld1` that was actually passed.
289 std::printf(
"error in descinit()\n");
290 std::printf(
"m=%i n=%i mb=%i nb=%i irsrc=%i icsrc=%i lld=%i\n", m, n, mb, nb, irsrc, icsrc, lld);
/// Thin wrapper over ScaLAPACK PJLAENV: query a tuning parameter for context `ictxt`. It forwards the
/// arguments by address together with the name/option string lengths.
/// NOTE(review): the FORTRAN(pjlaenv) prototype is not among the declarations visible in this extract.
/// NOTE(review): unlike the ilaenv wrapper, which casts the hidden string lengths to ftn_len, this
/// wrapper passes them as int32_t. Consider ftn_int / ftn_len here for consistency with ilaenv.
295 static int pjlaenv(int32_t ictxt, int32_t ispec,
const std::string& name,
const std::string& opts, int32_t n1,
296 int32_t n2, int32_t n3, int32_t n4)
298 return FORTRAN(pjlaenv)(&ictxt, &ispec, name.c_str(), opts.c_str(), &n1, &n2, &n3, &n4, (int32_t)name.length(),
299 (int32_t)opts.length());
/// Thin wrapper over ScaLAPACK INDXL2G: convert the local index `indxloc` of process `iproc` into a
/// global index.
/// NOTE(review): the arguments are int32_t, but their addresses go to a prototype declared with ftn_int*.
/// This compiles only if ftn_int is int32_t.
302 static int32_t indxl2g(int32_t indxloc, int32_t nb, int32_t iproc, int32_t isrcproc, int32_t nprocs)
304 return FORTRAN(indxl2g)(&indxloc, &nb, &iproc, &isrcproc, &nprocs);
/// Thin wrapper over ScaLAPACK ICEIL: ceiling of the integer division inum / idenom.
/// NOTE(review): the arguments are int32_t, but their addresses go to a prototype declared with ftn_int*.
/// This compiles only if ftn_int is int32_t.
307 static int32_t iceil(int32_t inum, int32_t idenom)
309 return FORTRAN(iceil)(&inum, &idenom);
Interface to some BLAS/LAPACK functions.
Base class for linear algebra interface.
static void gridmap(int *blacs_context, int *map, int ld, int nrow, int ncol)
Create BLACS context for the grid of MPI ranks.
static void free_blacs_handler(int blacs_handler)
Free BLACS handler.
static int create_blacs_handler(MPI_Comm comm)
Create BLACS handler.
static void gridexit(int blacs_context)
Destroy BLACS context.
lib_t
Type of linear algebra backend library.
@ cublasxt
cuBLASXt (cuBLAS with CPU pointers and support for large matrices)
@ scalapack
CPU ScaLAPACK.
@ spla
SPLA library; accepts both CPU and device pointers.
@ magma
MAGMA with CPU pointers.
@ gpublas
GPU BLAS (cuBLAS or rocBLAS)
@ magma
MAGMA with CPU pointers.
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > transform(::spla::Context &spla_ctx__, sddk::memory_t mem__, la::dmatrix< F > const &M__, int irow0__, int jcol0__, real_type< F > alpha__, Wave_functions< T > const &wf_in__, spin_index s_in__, band_range br_in__, real_type< F > beta__, Wave_functions< T > &wf_out__, spin_index s_out__, band_range br_out__)
Apply a linear transformation to the wave functions.
Namespace of the SIRIUS library.
Interface to some ScaLAPACK functions.