SIRIUS 7.5.0
Electronic structure library and applications
acc_blas.hpp
Go to the documentation of this file.
1// Copyright (c) 2013-2017 Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
20/** \file acc_blas.hpp
21 *
22 * \brief BLAS functions for execution on GPUs.
23 */
24
25#ifndef __ACC_BLAS_HPP__
26#define __ACC_BLAS_HPP__
27
28#include <unistd.h>
29#include <vector>
30#include "acc_blas_api.hpp"
31#include "acc.hpp"
32
33namespace sirius {
34
35namespace acc {
36
37/// User facing interface to GPU blas functions.
38namespace blas {
39
40#ifdef SIRIUS_CUDA
41inline const char*
42error_message(acc::blas_api::status_t status)
43{
44 switch (status) {
45 case CUBLAS_STATUS_NOT_INITIALIZED: {
46 return "the library was not initialized";
47 break;
48 }
49 case CUBLAS_STATUS_INVALID_VALUE: {
50 return "the parameters m,n,k<0";
51 break;
52 }
53 case CUBLAS_STATUS_ARCH_MISMATCH: {
54 return "the device does not support double-precision";
55 break;
56 }
57 case CUBLAS_STATUS_EXECUTION_FAILED: {
58 return "the function failed to launch on the GPU";
59 break;
60 }
61 default: {
62 return "gpublas status unknown";
63 }
64 }
65}
66#else
67inline const char*
68error_message(acc::blas_api::status_t status)
69{
70 return rocblas_status_to_string(status);
71}
72#endif
73
74inline acc::blas_api::operation_t
75get_gpublasOperation_t(char c)
76{
77 switch (c) {
78 case 'n':
79 case 'N': {
80 return acc::blas_api::operation::None;
81 }
82 case 't':
83 case 'T': {
84 return acc::blas_api::operation::Transpose;
85 }
86 case 'c':
87 case 'C': {
88 return acc::blas_api::operation::ConjugateTranspose;
89 }
90 default: {
91 throw std::runtime_error("get_gpublasOperation_t(): wrong operation");
92 }
93 }
94 return acc::blas_api::operation::None; // make compiler happy
95}
96
97inline acc::blas_api::side_mode_t
98get_gpublasSideMode_t(char c)
99{
100 switch (c) {
101 case 'l':
102 case 'L': {
103 return acc::blas_api::side::Left;
104 }
105 case 'r':
106 case 'R': {
107 return acc::blas_api::side::Right;
108 }
109 default: {
110 throw std::runtime_error("get_gpublasSideMode_t(): wrong side");
111 }
112 }
113 return acc::blas_api::side::Left; // make compiler happy
114}
115
116inline acc::blas_api::fill_mode_t
117get_gpublasFillMode_t(char c)
118{
119 switch (c) {
120 case 'u':
121 case 'U': {
122 return acc::blas_api::fill::Upper;
123 }
124 case 'l':
125 case 'L': {
126 return acc::blas_api::fill::Lower;
127 }
128 default: {
129 throw std::runtime_error("get_gpublasFillMode_t(): wrong mode");
130 }
131 }
132 return acc::blas_api::fill::Upper; // make compiler happy
133}
134
135inline acc::blas_api::diagonal_t
136get_gpublasDiagonal_t(char c)
137{
138 switch (c) {
139 case 'n':
140 case 'N': {
141 return acc::blas_api::diagonal::NonUnit;
142 }
143 case 'u':
144 case 'U': {
145 return acc::blas_api::diagonal::Unit;
146 }
147 default: {
148 throw std::runtime_error("get_gpublasDiagonal_t(): wrong diagonal type");
149 }
150 }
151 return acc::blas_api::diagonal::NonUnit; // make compiler happy
152}
153
/// Call a GPU BLAS function and report a failure on the standard output.
/** The call is spelled as two macro arguments: the function name and the parenthesised argument list, e.g.
 *  \code CALL_GPU_BLAS(acc::blas_api::destroy, (handle)); \endcode
 *
 *  If the returned status differs from acc::blas_api::status::Success, the macro prints the host name, the
 *  location of the failing call, the text returned by error_message() for the status, and finally calls
 *  acc::stack_backtrace(). The macro itself does not throw or return early.
 *
 *  The host name buffer is zero-initialised and one byte is held back so that it stays NUL-terminated even if
 *  gethostname() fails or truncates the name.
 *
 *  The expansion declares the local names \c status and \c nm; do not use these names inside \c args__.
 */
#define CALL_GPU_BLAS(func__, args__)                                                                                  \
    {                                                                                                                  \
        acc::blas_api::status_t status;                                                                                \
        if ((status = func__ args__) != acc::blas_api::status::Success) {                                              \
            char nm[1024] = {0};                                                                                       \
            gethostname(nm, sizeof(nm) - 1);                                                                           \
            std::printf("hostname: %s\n", nm);                                                                         \
            std::printf("Error in %s at line %i of file %s\n", #func__, __LINE__, __FILE__);                           \
            std::printf("Error message: %s\n", error_message(status));                                                 \
            acc::stack_backtrace();                                                                                    \
        }                                                                                                              \
    }
166
167/// Return a reference to the stored default (null) stream handle.
168acc::blas_api::handle_t& null_stream_handle();
169
170/// Return a reference to the stored vector of gpublas handles associated with the acc streams.
171std::vector<acc::blas_api::handle_t>& stream_handles();
172
173inline void
174create_stream_handles()
175{
176 // acc::set_device();
177 CALL_GPU_BLAS(acc::blas_api::create, (&null_stream_handle()));
178
179 stream_handles() = std::vector<acc::blas_api::handle_t>(acc::num_streams());
180 for (int i = 0; i < acc::num_streams(); i++) {
181 CALL_GPU_BLAS(acc::blas_api::create, (&stream_handles()[i]));
182
183 CALL_GPU_BLAS(acc::blas_api::set_stream, (stream_handles()[i], acc::stream(acc::stream_id(i))));
184 }
185}
186
187inline void
188destroy_stream_handles()
189{
190 // acc::set_device();
191 CALL_GPU_BLAS(acc::blas_api::destroy, (null_stream_handle()));
192 for (int i = 0; i < acc::num_streams(); i++) {
193 CALL_GPU_BLAS(acc::blas_api::destroy, (stream_handles()[i]));
194 }
195}
196
197inline acc::blas_api::handle_t
198stream_handle(int id__)
199{
200 return (id__ == -1) ? null_stream_handle() : stream_handles()[id__];
201}
202
203inline void
204zgemv(char transa, int32_t m, int32_t n, acc_complex_double_t* alpha, acc_complex_double_t* a, int32_t lda,
205 acc_complex_double_t* x, int32_t incx, acc_complex_double_t* beta, acc_complex_double_t* y, int32_t incy,
206 int stream_id)
207{
208 // acc::set_device();
209 CALL_GPU_BLAS(acc::blas_api::zgemv, (stream_handle(stream_id), get_gpublasOperation_t(transa), m, n,
210 reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha),
211 reinterpret_cast<const acc::blas_api::complex_double_t*>(a), lda,
212 reinterpret_cast<const acc::blas_api::complex_double_t*>(x), incx,
213 reinterpret_cast<const acc::blas_api::complex_double_t*>(beta),
214 reinterpret_cast<acc::blas_api::complex_double_t*>(y), incy));
215}
216
217inline void
218cgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, acc_complex_float_t const* alpha,
219 acc_complex_float_t const* a, int32_t lda, acc_complex_float_t const* b, int32_t ldb,
220 acc_complex_float_t const* beta, acc_complex_float_t* c, int32_t ldc, int stream_id)
221{
222 // acc::set_device();
223 CALL_GPU_BLAS(acc::blas_api::cgemm,
224 (stream_handle(stream_id), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m, n, k,
225 reinterpret_cast<const acc::blas_api::complex_float_t*>(alpha),
226 reinterpret_cast<const acc::blas_api::complex_float_t*>(a), lda,
227 reinterpret_cast<const acc::blas_api::complex_float_t*>(b), ldb,
228 reinterpret_cast<const acc::blas_api::complex_float_t*>(beta),
229 reinterpret_cast<acc::blas_api::complex_float_t*>(c), ldc));
230}
231
232inline void
233zgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, acc_complex_double_t const* alpha,
234 acc_complex_double_t const* a, int32_t lda, acc_complex_double_t const* b, int32_t ldb,
235 acc_complex_double_t const* beta, acc_complex_double_t* c, int32_t ldc, int stream_id)
236{
237 // acc::set_device();
238 CALL_GPU_BLAS(acc::blas_api::zgemm,
239 (stream_handle(stream_id), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m, n, k,
240 reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha),
241 reinterpret_cast<const acc::blas_api::complex_double_t*>(a), lda,
242 reinterpret_cast<const acc::blas_api::complex_double_t*>(b), ldb,
243 reinterpret_cast<const acc::blas_api::complex_double_t*>(beta),
244 reinterpret_cast<acc::blas_api::complex_double_t*>(c), ldc));
245}
246
247inline void
248sgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, float const* alpha, float const* a, int32_t lda,
249 float const* b, int32_t ldb, float const* beta, float* c, int32_t ldc, int stream_id)
250{
251 // acc::set_device();
252 CALL_GPU_BLAS(acc::blas_api::sgemm, (stream_handle(stream_id), get_gpublasOperation_t(transa),
253 get_gpublasOperation_t(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
254}
255
256inline void
257dgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, double const* alpha, double const* a, int32_t lda,
258 double const* b, int32_t ldb, double const* beta, double* c, int32_t ldc, int stream_id)
259{
260 // acc::set_device();
261 CALL_GPU_BLAS(acc::blas_api::dgemm, (stream_handle(stream_id), get_gpublasOperation_t(transa),
262 get_gpublasOperation_t(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
263}
264
265inline void
266strmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, float const* alpha__, float const* A__,
267 int lda__, float* B__, int ldb__, int stream_id)
268{
269 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
270 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
271 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
272 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
273 // acc::set_device();
274#ifdef SIRIUS_CUDA
275 CALL_GPU_BLAS(acc::blas_api::strmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__,
276 lda__, B__, ldb__, B__, ldb__));
277#else
278 // rocblas trmm function does not take three matrices
279 CALL_GPU_BLAS(acc::blas_api::strmm,
280 (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__));
281#endif
282}
283
284inline void
285dtrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, double const* alpha__, double const* A__,
286 int lda__, double* B__, int ldb__, int stream_id)
287{
288 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
289 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
290 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
291 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
292 // acc::set_device();
293#ifdef SIRIUS_CUDA
294 CALL_GPU_BLAS(acc::blas_api::dtrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__,
295 lda__, B__, ldb__, B__, ldb__));
296#else
297 // rocblas trmm function does not take three matrices
298 CALL_GPU_BLAS(acc::blas_api::dtrmm,
299 (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__));
300#endif
301}
302
303inline void
304ctrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, acc_complex_float_t const* alpha__,
305 acc_complex_float_t const* A__, int lda__, acc_complex_float_t* B__, int ldb__, int stream_id)
306{
307 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
308 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
309 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
310 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
311 // acc::set_device();
312#ifdef SIRIUS_CUDA
313 CALL_GPU_BLAS(acc::blas_api::ctrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
314 reinterpret_cast<const acc::blas_api::complex_float_t*>(alpha__),
315 reinterpret_cast<const acc::blas_api::complex_float_t*>(A__), lda__,
316 reinterpret_cast<acc::blas_api::complex_float_t*>(B__), ldb__,
317 reinterpret_cast<acc::blas_api::complex_float_t*>(B__), ldb__));
318#else
319 // rocblas trmm function does not take three matrices
320 CALL_GPU_BLAS(acc::blas_api::ctrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
321 reinterpret_cast<const acc::blas_api::complex_float_t*>(alpha__),
322 reinterpret_cast<const acc::blas_api::complex_float_t*>(A__), lda__,
323 reinterpret_cast<acc::blas_api::complex_float_t*>(B__), ldb__));
324#endif
325}
326
327inline void
328ztrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, acc_complex_double_t const* alpha__,
329 acc_complex_double_t const* A__, int lda__, acc_complex_double_t* B__, int ldb__, int stream_id)
330{
331 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
332 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
333 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
334 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
335 // acc::set_device();
336#ifdef SIRIUS_CUDA
337 CALL_GPU_BLAS(acc::blas_api::ztrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
338 reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha__),
339 reinterpret_cast<const acc::blas_api::complex_double_t*>(A__), lda__,
340 reinterpret_cast<acc::blas_api::complex_double_t*>(B__), ldb__,
341 reinterpret_cast<acc::blas_api::complex_double_t*>(B__), ldb__));
342#else
343 // rocblas trmm function does not take three matrices
344 CALL_GPU_BLAS(acc::blas_api::ztrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
345 reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha__),
346 reinterpret_cast<const acc::blas_api::complex_double_t*>(A__), lda__,
347 reinterpret_cast<acc::blas_api::complex_double_t*>(B__), ldb__));
348#endif
349}
350
351inline void
352sger(int m, int n, float const* alpha, float const* x, int incx, float const* y, int incy, float* A, int lda,
353 int stream_id)
354{
355 // acc::set_device();
356 CALL_GPU_BLAS(acc::blas_api::sger, (stream_handle(stream_id), m, n, alpha, x, incx, y, incy, A, lda));
357}
358
359inline void
360dger(int m, int n, double const* alpha, double const* x, int incx, double const* y, int incy, double* A, int lda,
361 int stream_id)
362{
363 // acc::set_device();
364 CALL_GPU_BLAS(acc::blas_api::dger, (stream_handle(stream_id), m, n, alpha, x, incx, y, incy, A, lda));
365}
366
367inline void
368cgeru(int m, int n, acc_complex_float_t const* alpha, acc_complex_float_t const* x, int incx,
369 acc_complex_float_t const* y, int incy, acc_complex_float_t* A, int lda, int stream_id)
370{
371 // acc::set_device();
372 CALL_GPU_BLAS(acc::blas_api::cgeru,
373 (stream_handle(stream_id), m, n, reinterpret_cast<const acc::blas_api::complex_float_t*>(alpha),
374 reinterpret_cast<const acc::blas_api::complex_float_t*>(x), incx,
375 reinterpret_cast<const acc::blas_api::complex_float_t*>(y), incy,
376 reinterpret_cast<acc::blas_api::complex_float_t*>(A), lda));
377}
378
379inline void
380zgeru(int m, int n, acc_complex_double_t const* alpha, acc_complex_double_t const* x, int incx,
381 acc_complex_double_t const* y, int incy, acc_complex_double_t* A, int lda, int stream_id)
382{
383 // acc::set_device();
384 CALL_GPU_BLAS(acc::blas_api::zgeru,
385 (stream_handle(stream_id), m, n, reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha),
386 reinterpret_cast<const acc::blas_api::complex_double_t*>(x), incx,
387 reinterpret_cast<const acc::blas_api::complex_double_t*>(y), incy,
388 reinterpret_cast<acc::blas_api::complex_double_t*>(A), lda));
389}
390
391inline void
392zaxpy(int n__, acc_complex_double_t const* alpha__, acc_complex_double_t const* x__, int incx__,
393 acc_complex_double_t* y__, int incy__)
394{
395 // acc::set_device();
396 CALL_GPU_BLAS(acc::blas_api::zaxpy,
397 (null_stream_handle(), n__, reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha__),
398 reinterpret_cast<const acc::blas_api::complex_double_t*>(x__), incx__,
399 reinterpret_cast<acc::blas_api::complex_double_t*>(y__), incy__));
400}
401
402inline void
403dscal(int n__, double const* alpha__, double * x__, int incx__)
404{
405 // acc::set_device();
406 CALL_GPU_BLAS(acc::blas_api::dscal,
407 (null_stream_handle(), n__, alpha__, x__, incx__));
408}
409
410inline void
411sscal(int n__, float const* alpha__, float * x__, int incx__)
412{
413 // acc::set_device();
414 CALL_GPU_BLAS(acc::blas_api::sscal,
415 (null_stream_handle(), n__, alpha__, x__, incx__));
416}
417
418#if defined(SIRIUS_CUDA)
419/// Interface to cuBlasXt functions
420namespace xt {
421
/// Return a reference to the cuBlasXt handle used by the functions of this namespace.
422cublasXtHandle_t& cublasxt_handle();
423
424inline void
425create_handle()
426{
427 int device_id[1];
428 device_id[0] = acc::get_device_id();
429 CALL_GPU_BLAS(cublasXtCreate, (&cublasxt_handle()));
430 CALL_GPU_BLAS(cublasXtDeviceSelect, (cublasxt_handle(), 1, device_id));
431 CALL_GPU_BLAS(cublasXtSetBlockDim, (cublasxt_handle(), 4096));
432}
433
434inline void
435destroy_handle()
436{
437 CALL_GPU_BLAS(cublasXtDestroy, (cublasxt_handle()));
438}
439
440inline void
441cgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, acc_complex_float_t const* alpha,
442 acc_complex_float_t const* a, int32_t lda, acc_complex_float_t const* b, int32_t ldb,
443 acc_complex_float_t const* beta, acc_complex_float_t* c, int32_t ldc)
444{
445 // acc::set_device();
446 CALL_GPU_BLAS(cublasXtCgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
447 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
448}
449
450inline void
451zgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, acc_complex_double_t const* alpha,
452 acc_complex_double_t const* a, int32_t lda, acc_complex_double_t const* b, int32_t ldb,
453 acc_complex_double_t const* beta, acc_complex_double_t* c, int32_t ldc)
454{
455 // acc::set_device();
456 CALL_GPU_BLAS(cublasXtZgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
457 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
458}
459
460inline void
461sgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, float const* alpha, float const* a, int32_t lda,
462 float const* b, int32_t ldb, float const* beta, float* c, int32_t ldc)
463{
464 // acc::set_device();
465 CALL_GPU_BLAS(cublasXtSgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
466 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
467}
468
469inline void
470dgemm(char transa, char transb, int32_t m, int32_t n, int32_t k, double const* alpha, double const* a, int32_t lda,
471 double const* b, int32_t ldb, double const* beta, double* c, int32_t ldc)
472{
473 // acc::set_device();
474 CALL_GPU_BLAS(cublasXtDgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
475 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
476}
477
478inline void
479strmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, float const* alpha__, float const* A__,
480 int lda__, float* B__, int ldb__)
481{
482 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
483 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
484 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
485 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
486 // acc::set_device();
487 CALL_GPU_BLAS(cublasXtStrmm,
488 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
489}
490
491inline void
492dtrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, double const* alpha__, double const* A__,
493 int lda__, double* B__, int ldb__)
494{
495 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
496 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
497 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
498 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
499 // acc::set_device();
500 CALL_GPU_BLAS(cublasXtDtrmm,
501 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
502}
503
504inline void
505ctrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, acc_complex_float_t const* alpha__,
506 acc_complex_float_t const* A__, int lda__, acc_complex_float_t* B__, int ldb__)
507{
508 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
509 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
510 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
511 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
512 // acc::set_device();
513 CALL_GPU_BLAS(cublasXtCtrmm,
514 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
515}
516
517inline void
518ztrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, acc_complex_double_t const* alpha__,
519 acc_complex_double_t const* A__, int lda__, acc_complex_double_t* B__, int ldb__)
520{
521 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
522 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
523 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
524 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
525 // acc::set_device();
526 CALL_GPU_BLAS(cublasXtZtrmm,
527 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
528}
529
530} // namespace xt
531#endif
532
533} // namespace blas
534
535} // namespace acc
536
537} // namespace sirius
538
539#endif
Interface to accelerators API.
Interface to cuBLAS / rocblas related functions.
acc::blas_api::handle_t & null_stream_handle()
Store the default (null) stream handler.
Definition: acc_blas.cpp:11
std::vector< acc::blas_api::handle_t > & stream_handles()
Store the gpublas handlers associated with acc streams.
Definition: acc_blas.cpp:18
int get_device_id()
Get current device ID.
Definition: acc.hpp:191
acc_stream_t stream(stream_id sid__)
Return a single device stream.
Definition: acc.hpp:202
int num_streams()
Get number of streams.
Definition: acc.hpp:209
Namespace of the SIRIUS library.
Definition: sirius.f90:5