25#ifndef __ACC_BLAS_HPP__
26#define __ACC_BLAS_HPP__
/// Translate a gpublas status code into a human-readable message.
/// This variant is keyed on the CUBLAS_STATUS_* codes; the visible cases are
/// NOT_INITIALIZED, INVALID_VALUE, ARCH_MISMATCH and EXECUTION_FAILED.
/// NOTE(review): the final return ("gpublas status unknown") appears to be the
/// fallback for every other status; the enclosing switch/default lines are not
/// visible in this listing -- confirm.
/// \param status  Status value returned by a gpublas call.
/// \return Message text describing the status.
42error_message(acc::blas_api::status_t status)
45 case CUBLAS_STATUS_NOT_INITIALIZED: {
46 return "the library was not initialized";
49 case CUBLAS_STATUS_INVALID_VALUE: {
50 return "the parameters m,n,k<0";
53 case CUBLAS_STATUS_ARCH_MISMATCH: {
54 return "the device does not support double-precision";
57 case CUBLAS_STATUS_EXECUTION_FAILED: {
58 return "the function failed to launch on the GPU";
62 return "gpublas status unknown";
/// Translate a gpublas status code into a message by delegating to
/// rocblas_status_to_string().
/// NOTE(review): presumably the alternative to the CUBLAS_STATUS_* variant,
/// selected by a preprocessor condition that is not visible here -- confirm.
/// \param status  Status value returned by a gpublas call.
68error_message(acc::blas_api::status_t status)
70 return rocblas_status_to_string(status);
/// Map a BLAS-style operation character to acc::blas_api::operation_t.
/// Possible results are operation::None, operation::Transpose and
/// operation::ConjugateTranspose; an unrecognised character throws
/// std::runtime_error("get_gpublasOperation_t(): wrong operation").
/// NOTE(review): the case labels are not visible in this listing -- presumably
/// 'N'/'T'/'C' in either letter case; confirm against the full source.
/// The last return follows the throw and looks unreachable; presumably it only
/// silences "missing return" warnings.
/// \param c  Operation selector character.
74inline acc::blas_api::operation_t
75get_gpublasOperation_t(
char c)
80 return acc::blas_api::operation::None;
84 return acc::blas_api::operation::Transpose;
88 return acc::blas_api::operation::ConjugateTranspose;
91 throw std::runtime_error(
"get_gpublasOperation_t(): wrong operation");
94 return acc::blas_api::operation::None;
/// Map a BLAS-style side character to acc::blas_api::side_mode_t.
/// Possible results are side::Left and side::Right; an unrecognised character
/// throws std::runtime_error("get_gpublasSideMode_t(): wrong side").
/// NOTE(review): the case labels are not visible in this listing -- presumably
/// 'L'/'R'; confirm. The last return follows the throw and looks unreachable.
/// \param c  Side selector character.
97inline acc::blas_api::side_mode_t
98get_gpublasSideMode_t(
char c)
103 return acc::blas_api::side::Left;
107 return acc::blas_api::side::Right;
110 throw std::runtime_error(
"get_gpublasSideMode_t(): wrong side");
113 return acc::blas_api::side::Left;
/// Map a BLAS-style fill-mode character to acc::blas_api::fill_mode_t.
/// Possible results are fill::Upper and fill::Lower; an unrecognised character
/// throws std::runtime_error("get_gpublasFillMode_t(): wrong mode").
/// NOTE(review): the case labels are not visible in this listing -- presumably
/// 'U'/'L'; confirm. The last return follows the throw and looks unreachable.
/// \param c  Fill-mode selector character.
116inline acc::blas_api::fill_mode_t
117get_gpublasFillMode_t(
char c)
122 return acc::blas_api::fill::Upper;
126 return acc::blas_api::fill::Lower;
129 throw std::runtime_error(
"get_gpublasFillMode_t(): wrong mode");
132 return acc::blas_api::fill::Upper;
/// Map a BLAS-style diagonal character to acc::blas_api::diagonal_t.
/// Possible results are diagonal::NonUnit and diagonal::Unit; an unrecognised
/// character throws std::runtime_error("get_gpublasDiagonal_t(): wrong diagonal type").
/// NOTE(review): the case labels are not visible in this listing -- presumably
/// 'N'/'U'; confirm. The last return follows the throw and looks unreachable.
/// \param c  Diagonal-type selector character.
135inline acc::blas_api::diagonal_t
136get_gpublasDiagonal_t(
char c)
141 return acc::blas_api::diagonal::NonUnit;
145 return acc::blas_api::diagonal::Unit;
148 throw std::runtime_error(
"get_gpublasDiagonal_t(): wrong diagonal type");
151 return acc::blas_api::diagonal::NonUnit;
/// Invoke a gpublas function and report a failure.
/// Expands to a call `func__ args__`; when the returned status differs from
/// acc::blas_api::status::Success it prints the host name, then the failing
/// function name with __LINE__ and __FILE__, and calls acc::stack_backtrace().
/// \param func__  Function to call (also stringised for the diagnostic).
/// \param args__  Parenthesised argument list, e.g. `(handle, m, n)`.
/// NOTE(review): the value returned by error_message(status) is discarded, so
/// the textual description of the status never reaches the output -- it looks
/// like it was meant to be printed.
/// NOTE(review): no abort or throw is visible after the backtrace, and the
/// declaration of `nm` falls on a line missing from this listing; confirm both
/// against the full source.
154#define CALL_GPU_BLAS(func__, args__) \
156 acc::blas_api::status_t status; \
157 if ((status = func__ args__) != acc::blas_api::status::Success) { \
158 error_message(status); \
160 gethostname(nm, 1024); \
161 std::printf("hostname: %s\n", nm); \
162 std::printf("Error in %s at line %i of file %s\n", #func__, __LINE__, __FILE__); \
163 acc::stack_backtrace(); \
/// NOTE(review): only the name is visible in this listing; presumably creates
/// the gpublas handles associated with the acc streams -- confirm against the
/// full source.
174create_stream_handles()
/// NOTE(review): only the name is visible in this listing; presumably releases
/// the handles made by create_stream_handles() -- confirm against the full
/// source.
188destroy_stream_handles()
/// Return the acc::blas_api::handle_t to use for the given stream id.
/// NOTE(review): the body is not visible in this listing; how id__ selects
/// between the null-stream handle and the per-stream handles must be checked
/// in the full source.
/// \param id__  Stream identifier.
197inline acc::blas_api::handle_t
198stream_handle(
int id__)
/// Double-complex matrix-vector product wrapper.
/// Calls acc::blas_api::zgemv through CALL_GPU_BLAS on the handle returned by
/// stream_handle(stream_id). transa is converted with get_gpublasOperation_t()
/// and every complex pointer is reinterpret_cast to
/// acc::blas_api::complex_double_t; y is the only pointer cast to non-const.
/// NOTE(review): unlike the gemm wrappers, alpha, a, x and beta are declared as
/// pointers to non-const even though they are cast to const for the call.
/// The declaration of stream_id falls on a line missing from this listing.
204zgemv(
char transa, int32_t m, int32_t n, acc_complex_double_t* alpha, acc_complex_double_t* a, int32_t lda,
205 acc_complex_double_t* x, int32_t incx, acc_complex_double_t* beta, acc_complex_double_t* y, int32_t incy,
209 CALL_GPU_BLAS(acc::blas_api::zgemv, (stream_handle(stream_id), get_gpublasOperation_t(transa), m, n,
210 reinterpret_cast<const acc::blas_api::complex_double_t*
>(alpha),
211 reinterpret_cast<const acc::blas_api::complex_double_t*
>(a), lda,
212 reinterpret_cast<const acc::blas_api::complex_double_t*
>(x), incx,
213 reinterpret_cast<const acc::blas_api::complex_double_t*
>(beta),
214 reinterpret_cast<acc::blas_api::complex_double_t*
>(y), incy));
/// Single-complex matrix-matrix product wrapper (per-stream handle).
/// Calls acc::blas_api::cgemm through CALL_GPU_BLAS on stream_handle(stream_id).
/// transa/transb are converted with get_gpublasOperation_t(); alpha, a, b and
/// beta are cast to const acc::blas_api::complex_float_t*, c to the non-const
/// pointer type.
/// \param stream_id  Selects the handle via stream_handle().
218cgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k, acc_complex_float_t
const* alpha,
219 acc_complex_float_t
const* a, int32_t lda, acc_complex_float_t
const* b, int32_t ldb,
220 acc_complex_float_t
const* beta, acc_complex_float_t* c, int32_t ldc,
int stream_id)
223 CALL_GPU_BLAS(acc::blas_api::cgemm,
224 (stream_handle(stream_id), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m, n, k,
225 reinterpret_cast<const acc::blas_api::complex_float_t*
>(alpha),
226 reinterpret_cast<const acc::blas_api::complex_float_t*
>(a), lda,
227 reinterpret_cast<const acc::blas_api::complex_float_t*
>(b), ldb,
228 reinterpret_cast<const acc::blas_api::complex_float_t*
>(beta),
229 reinterpret_cast<acc::blas_api::complex_float_t*
>(c), ldc));
/// Double-complex matrix-matrix product wrapper (per-stream handle).
/// Calls acc::blas_api::zgemm through CALL_GPU_BLAS on stream_handle(stream_id).
/// transa/transb are converted with get_gpublasOperation_t(); alpha, a, b and
/// beta are cast to const acc::blas_api::complex_double_t*, c to the non-const
/// pointer type.
/// \param stream_id  Selects the handle via stream_handle().
233zgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k, acc_complex_double_t
const* alpha,
234 acc_complex_double_t
const* a, int32_t lda, acc_complex_double_t
const* b, int32_t ldb,
235 acc_complex_double_t
const* beta, acc_complex_double_t* c, int32_t ldc,
int stream_id)
238 CALL_GPU_BLAS(acc::blas_api::zgemm,
239 (stream_handle(stream_id), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m, n, k,
240 reinterpret_cast<const acc::blas_api::complex_double_t*
>(alpha),
241 reinterpret_cast<const acc::blas_api::complex_double_t*
>(a), lda,
242 reinterpret_cast<const acc::blas_api::complex_double_t*
>(b), ldb,
243 reinterpret_cast<const acc::blas_api::complex_double_t*
>(beta),
244 reinterpret_cast<acc::blas_api::complex_double_t*
>(c), ldc));
/// Single-precision real matrix-matrix product wrapper (per-stream handle).
/// Calls acc::blas_api::sgemm through CALL_GPU_BLAS on stream_handle(stream_id);
/// transa/transb are converted with get_gpublasOperation_t() and all pointer
/// arguments are forwarded unchanged.
/// \param stream_id  Selects the handle via stream_handle().
248sgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k,
float const* alpha,
float const* a, int32_t lda,
249 float const* b, int32_t ldb,
float const* beta,
float* c, int32_t ldc,
int stream_id)
252 CALL_GPU_BLAS(acc::blas_api::sgemm, (stream_handle(stream_id), get_gpublasOperation_t(transa),
253 get_gpublasOperation_t(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
/// Double-precision real matrix-matrix product wrapper (per-stream handle).
/// Calls acc::blas_api::dgemm through CALL_GPU_BLAS on stream_handle(stream_id);
/// transa/transb are converted with get_gpublasOperation_t() and all pointer
/// arguments are forwarded unchanged.
/// \param stream_id  Selects the handle via stream_handle().
257dgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k,
double const* alpha,
double const* a, int32_t lda,
258 double const* b, int32_t ldb,
double const* beta,
double* c, int32_t ldc,
int stream_id)
261 CALL_GPU_BLAS(acc::blas_api::dgemm, (stream_handle(stream_id), get_gpublasOperation_t(transa),
262 get_gpublasOperation_t(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
/// Single-precision triangular matrix-matrix product wrapper (per-stream handle).
/// The four selector characters are converted with get_gpublasSideMode_t(),
/// get_gpublasFillMode_t(), get_gpublasOperation_t() and get_gpublasDiagonal_t()
/// before acc::blas_api::strmm is called on stream_handle(stream_id).
/// Two call forms are visible: the first passes B__/ldb__ twice, the second
/// passes them once.
/// NOTE(review): the preprocessor lines choosing between the two forms are not
/// visible in this listing -- presumably a backend-specific switch; confirm.
266strmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__,
float const* alpha__,
float const* A__,
267 int lda__,
float* B__,
int ldb__,
int stream_id)
269 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
270 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
271 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
272 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
275 CALL_GPU_BLAS(acc::blas_api::strmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__,
276 lda__, B__, ldb__, B__, ldb__));
279 CALL_GPU_BLAS(acc::blas_api::strmm,
280 (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__));
/// Double-precision triangular matrix-matrix product wrapper (per-stream handle).
/// The four selector characters are converted with get_gpublasSideMode_t(),
/// get_gpublasFillMode_t(), get_gpublasOperation_t() and get_gpublasDiagonal_t()
/// before acc::blas_api::dtrmm is called on stream_handle(stream_id).
/// Two call forms are visible: the first passes B__/ldb__ twice, the second
/// passes them once.
/// NOTE(review): the preprocessor lines choosing between the two forms are not
/// visible in this listing -- presumably a backend-specific switch; confirm.
285dtrmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__,
double const* alpha__,
double const* A__,
286 int lda__,
double* B__,
int ldb__,
int stream_id)
288 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
289 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
290 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
291 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
294 CALL_GPU_BLAS(acc::blas_api::dtrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__,
295 lda__, B__, ldb__, B__, ldb__));
298 CALL_GPU_BLAS(acc::blas_api::dtrmm,
299 (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__));
/// Single-complex triangular matrix-matrix product wrapper (per-stream handle).
/// The four selector characters are converted with get_gpublasSideMode_t(),
/// get_gpublasFillMode_t(), get_gpublasOperation_t() and get_gpublasDiagonal_t()
/// before acc::blas_api::ctrmm is called on stream_handle(stream_id); complex
/// pointers are reinterpret_cast to acc::blas_api::complex_float_t.
/// Two call forms are visible: the first passes B__/ldb__ twice, the second
/// passes them once.
/// NOTE(review): the preprocessor lines choosing between the two forms are not
/// visible in this listing -- presumably a backend-specific switch; confirm.
304ctrmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__, acc_complex_float_t
const* alpha__,
305 acc_complex_float_t
const* A__,
int lda__, acc_complex_float_t* B__,
int ldb__,
int stream_id)
307 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
308 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
309 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
310 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
313 CALL_GPU_BLAS(acc::blas_api::ctrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
314 reinterpret_cast<const acc::blas_api::complex_float_t*
>(alpha__),
315 reinterpret_cast<const acc::blas_api::complex_float_t*
>(A__), lda__,
316 reinterpret_cast<acc::blas_api::complex_float_t*
>(B__), ldb__,
317 reinterpret_cast<acc::blas_api::complex_float_t*
>(B__), ldb__));
320 CALL_GPU_BLAS(acc::blas_api::ctrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
321 reinterpret_cast<const acc::blas_api::complex_float_t*
>(alpha__),
322 reinterpret_cast<const acc::blas_api::complex_float_t*
>(A__), lda__,
323 reinterpret_cast<acc::blas_api::complex_float_t*
>(B__), ldb__));
/// Double-complex triangular matrix-matrix product wrapper (per-stream handle).
/// The four selector characters are converted with get_gpublasSideMode_t(),
/// get_gpublasFillMode_t(), get_gpublasOperation_t() and get_gpublasDiagonal_t()
/// before acc::blas_api::ztrmm is called on stream_handle(stream_id); complex
/// pointers are reinterpret_cast to acc::blas_api::complex_double_t.
/// Two call forms are visible: the first passes B__/ldb__ twice, the second
/// passes them once.
/// NOTE(review): the preprocessor lines choosing between the two forms are not
/// visible in this listing -- presumably a backend-specific switch; confirm.
328ztrmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__, acc_complex_double_t
const* alpha__,
329 acc_complex_double_t
const* A__,
int lda__, acc_complex_double_t* B__,
int ldb__,
int stream_id)
331 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
332 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
333 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
334 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
337 CALL_GPU_BLAS(acc::blas_api::ztrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
338 reinterpret_cast<const acc::blas_api::complex_double_t*
>(alpha__),
339 reinterpret_cast<const acc::blas_api::complex_double_t*
>(A__), lda__,
340 reinterpret_cast<acc::blas_api::complex_double_t*
>(B__), ldb__,
341 reinterpret_cast<acc::blas_api::complex_double_t*
>(B__), ldb__));
344 CALL_GPU_BLAS(acc::blas_api::ztrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
345 reinterpret_cast<const acc::blas_api::complex_double_t*
>(alpha__),
346 reinterpret_cast<const acc::blas_api::complex_double_t*
>(A__), lda__,
347 reinterpret_cast<acc::blas_api::complex_double_t*
>(B__), ldb__));
/// Single-precision rank-1 update wrapper.
/// Calls acc::blas_api::sger through CALL_GPU_BLAS on stream_handle(stream_id),
/// forwarding m, n, alpha, x/incx, y/incy and A/lda unchanged.
/// The declaration of stream_id falls on a line missing from this listing.
352sger(
int m,
int n,
float const* alpha,
float const* x,
int incx,
float const* y,
int incy,
float* A,
int lda,
356 CALL_GPU_BLAS(acc::blas_api::sger, (stream_handle(stream_id), m, n, alpha, x, incx, y, incy, A, lda));
/// Double-precision rank-1 update wrapper.
/// Calls acc::blas_api::dger through CALL_GPU_BLAS on stream_handle(stream_id),
/// forwarding m, n, alpha, x/incx, y/incy and A/lda unchanged.
/// The declaration of stream_id falls on a line missing from this listing.
360dger(
int m,
int n,
double const* alpha,
double const* x,
int incx,
double const* y,
int incy,
double* A,
int lda,
364 CALL_GPU_BLAS(acc::blas_api::dger, (stream_handle(stream_id), m, n, alpha, x, incx, y, incy, A, lda));
/// Single-complex rank-1 update wrapper.
/// Calls acc::blas_api::cgeru through CALL_GPU_BLAS on stream_handle(stream_id).
/// alpha, x and y are cast to const acc::blas_api::complex_float_t*, A to the
/// non-const pointer type.
/// \param stream_id  Selects the handle via stream_handle().
368cgeru(
int m,
int n, acc_complex_float_t
const* alpha, acc_complex_float_t
const* x,
int incx,
369 acc_complex_float_t
const* y,
int incy, acc_complex_float_t* A,
int lda,
int stream_id)
372 CALL_GPU_BLAS(acc::blas_api::cgeru,
373 (stream_handle(stream_id), m, n,
reinterpret_cast<const acc::blas_api::complex_float_t*
>(alpha),
374 reinterpret_cast<const acc::blas_api::complex_float_t*
>(x), incx,
375 reinterpret_cast<const acc::blas_api::complex_float_t*
>(y), incy,
376 reinterpret_cast<acc::blas_api::complex_float_t*
>(A), lda));
/// Double-complex rank-1 update wrapper.
/// Calls acc::blas_api::zgeru through CALL_GPU_BLAS on stream_handle(stream_id).
/// alpha, x and y are cast to const acc::blas_api::complex_double_t*, A to the
/// non-const pointer type.
/// \param stream_id  Selects the handle via stream_handle().
380zgeru(
int m,
int n, acc_complex_double_t
const* alpha, acc_complex_double_t
const* x,
int incx,
381 acc_complex_double_t
const* y,
int incy, acc_complex_double_t* A,
int lda,
int stream_id)
384 CALL_GPU_BLAS(acc::blas_api::zgeru,
385 (stream_handle(stream_id), m, n,
reinterpret_cast<const acc::blas_api::complex_double_t*
>(alpha),
386 reinterpret_cast<const acc::blas_api::complex_double_t*
>(x), incx,
387 reinterpret_cast<const acc::blas_api::complex_double_t*
>(y), incy,
388 reinterpret_cast<acc::blas_api::complex_double_t*
>(A), lda));
/// Double-complex axpy wrapper.
/// Calls acc::blas_api::zaxpy through CALL_GPU_BLAS. Unlike the other wrappers
/// it takes no stream id and always uses the handle from null_stream_handle().
/// alpha__ and x__ are cast to const acc::blas_api::complex_double_t*, y__ to
/// the non-const pointer type.
392zaxpy(
int n__, acc_complex_double_t
const* alpha__, acc_complex_double_t
const* x__,
int incx__,
393 acc_complex_double_t* y__,
int incy__)
396 CALL_GPU_BLAS(acc::blas_api::zaxpy,
397 (
null_stream_handle(), n__,
reinterpret_cast<const acc::blas_api::complex_double_t*
>(alpha__),
398 reinterpret_cast<const acc::blas_api::complex_double_t*
>(x__), incx__,
399 reinterpret_cast<acc::blas_api::complex_double_t*
>(y__), incy__));
/// Double-precision vector scaling wrapper around acc::blas_api::dscal.
/// NOTE(review): the argument list of the CALL_GPU_BLAS invocation is cut off
/// in this listing, so the handle used cannot be confirmed here; there is no
/// stream id parameter.
403dscal(
int n__,
double const* alpha__,
double * x__,
int incx__)
406 CALL_GPU_BLAS(acc::blas_api::dscal,
/// Single-precision vector scaling wrapper around acc::blas_api::sscal.
/// NOTE(review): the argument list of the CALL_GPU_BLAS invocation is cut off
/// in this listing, so the handle used cannot be confirmed here; there is no
/// stream id parameter.
411sscal(
int n__,
float const* alpha__,
float * x__,
int incx__)
414 CALL_GPU_BLAS(acc::blas_api::sscal,
// cublasXt section, compiled only when SIRIUS_CUDA is defined.
418#if defined(SIRIUS_CUDA)
/// Accessor for the cublasXt handle used by the wrappers below (declaration only).
422cublasXtHandle_t& cublasxt_handle();
// Handle set-up: create the handle, select one device (device_id) and set the
// block dimension to 4096. The enclosing function header and the origin of
// device_id are not visible in this listing -- TODO confirm.
429 CALL_GPU_BLAS(cublasXtCreate, (&cublasxt_handle()));
430 CALL_GPU_BLAS(cublasXtDeviceSelect, (cublasxt_handle(), 1, device_id));
431 CALL_GPU_BLAS(cublasXtSetBlockDim, (cublasxt_handle(), 4096));
// Handle tear-down; the enclosing function header is not visible here.
437 CALL_GPU_BLAS(cublasXtDestroy, (cublasxt_handle()));
/// Single-complex matrix-matrix product through cublasXt.
/// Calls cublasXtCgemm on cublasxt_handle(); there is no stream id parameter.
/// transa/transb are converted with get_gpublasOperation_t(); the complex
/// pointers are forwarded without the reinterpret_cast used by the per-stream
/// wrapper.
441cgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k, acc_complex_float_t
const* alpha,
442 acc_complex_float_t
const* a, int32_t lda, acc_complex_float_t
const* b, int32_t ldb,
443 acc_complex_float_t
const* beta, acc_complex_float_t* c, int32_t ldc)
446 CALL_GPU_BLAS(cublasXtCgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
447 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
/// Double-complex matrix-matrix product through cublasXt.
/// Calls cublasXtZgemm on cublasxt_handle(); there is no stream id parameter.
/// transa/transb are converted with get_gpublasOperation_t(); the complex
/// pointers are forwarded without the reinterpret_cast used by the per-stream
/// wrapper.
451zgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k, acc_complex_double_t
const* alpha,
452 acc_complex_double_t
const* a, int32_t lda, acc_complex_double_t
const* b, int32_t ldb,
453 acc_complex_double_t
const* beta, acc_complex_double_t* c, int32_t ldc)
456 CALL_GPU_BLAS(cublasXtZgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
457 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
/// Single-precision real matrix-matrix product through cublasXt.
/// Calls cublasXtSgemm on cublasxt_handle(); there is no stream id parameter.
/// transa/transb are converted with get_gpublasOperation_t() and all pointer
/// arguments are forwarded unchanged.
461sgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k,
float const* alpha,
float const* a, int32_t lda,
462 float const* b, int32_t ldb,
float const* beta,
float* c, int32_t ldc)
465 CALL_GPU_BLAS(cublasXtSgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
466 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
/// Double-precision real matrix-matrix product through cublasXt.
/// Calls cublasXtDgemm on cublasxt_handle(); there is no stream id parameter.
/// transa/transb are converted with get_gpublasOperation_t() and all pointer
/// arguments are forwarded unchanged.
470dgemm(
char transa,
char transb, int32_t m, int32_t n, int32_t k,
double const* alpha,
double const* a, int32_t lda,
471 double const* b, int32_t ldb,
double const* beta,
double* c, int32_t ldc)
474 CALL_GPU_BLAS(cublasXtDgemm, (cublasxt_handle(), get_gpublasOperation_t(transa), get_gpublasOperation_t(transb), m,
475 n, k, alpha, a, lda, b, ldb, beta, c, ldc));
/// Single-precision triangular matrix-matrix product through cublasXt.
/// Converts the four selector characters with the get_gpublas*_t() helpers and
/// calls cublasXtStrmm on cublasxt_handle(); there is no stream id parameter.
/// B__/ldb__ are passed twice, as the second operand and again as the final
/// pointer/leading-dimension pair.
/// NOTE(review): presumably the final pair is the output matrix, making the
/// update in place -- confirm against the cublasXt documentation.
479strmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__,
float const* alpha__,
float const* A__,
480 int lda__,
float* B__,
int ldb__)
482 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
483 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
484 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
485 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
487 CALL_GPU_BLAS(cublasXtStrmm,
488 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
/// Double-precision triangular matrix-matrix product through cublasXt.
/// Converts the four selector characters with the get_gpublas*_t() helpers and
/// calls cublasXtDtrmm on cublasxt_handle(); there is no stream id parameter.
/// B__/ldb__ are passed twice, as the second operand and again as the final
/// pointer/leading-dimension pair.
/// NOTE(review): presumably the final pair is the output matrix, making the
/// update in place -- confirm against the cublasXt documentation.
492dtrmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__,
double const* alpha__,
double const* A__,
493 int lda__,
double* B__,
int ldb__)
495 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
496 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
497 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
498 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
500 CALL_GPU_BLAS(cublasXtDtrmm,
501 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
/// Single-complex triangular matrix-matrix product through cublasXt.
/// Converts the four selector characters with the get_gpublas*_t() helpers and
/// calls cublasXtCtrmm on cublasxt_handle(); there is no stream id parameter.
/// The complex pointers are forwarded without casts, and B__/ldb__ are passed
/// twice, as the second operand and again as the final pair.
/// NOTE(review): presumably the final pair is the output matrix, making the
/// update in place -- confirm against the cublasXt documentation.
505ctrmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__, acc_complex_float_t
const* alpha__,
506 acc_complex_float_t
const* A__,
int lda__, acc_complex_float_t* B__,
int ldb__)
508 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
509 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
510 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
511 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
513 CALL_GPU_BLAS(cublasXtCtrmm,
514 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
/// Double-complex triangular matrix-matrix product through cublasXt.
/// Converts the four selector characters with the get_gpublas*_t() helpers and
/// calls cublasXtZtrmm on cublasxt_handle(); there is no stream id parameter.
/// The complex pointers are forwarded without casts, and B__/ldb__ are passed
/// twice, as the second operand and again as the final pair.
/// NOTE(review): presumably the final pair is the output matrix, making the
/// update in place -- confirm against the cublasXt documentation.
518ztrmm(
char side__,
char uplo__,
char transa__,
char diag__,
int m__,
int n__, acc_complex_double_t
const* alpha__,
519 acc_complex_double_t
const* A__,
int lda__, acc_complex_double_t* B__,
int ldb__)
521 acc::blas_api::side_mode_t side = get_gpublasSideMode_t(side__);
522 acc::blas_api::fill_mode_t uplo = get_gpublasFillMode_t(uplo__);
523 acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
524 acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
526 CALL_GPU_BLAS(cublasXtZtrmm,
527 (cublasxt_handle(), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__, B__, ldb__));
Interface to accelerators API.
Interface to cuBLAS / rocblas related functions.
acc::blas_api::handle_t & null_stream_handle()
Store the default (null) stream handle.
std::vector< acc::blas_api::handle_t > & stream_handles()
Store the gpublas handles associated with acc streams.
int get_device_id()
Get current device ID.
acc_stream_t stream(stream_id sid__)
Return a single device stream.
int num_streams()
Get number of streams.
Namespace of the SIRIUS library.