25#ifndef __ACC_BLAS_API_HPP__
26#define __ACC_BLAS_API_HPP__
30#if defined(SIRIUS_CUDA)
33#elif defined(SIRIUS_ROCM)
37#error Either SIRIUS_CUDA or SIRIUS_ROCM must be defined!
// Backend-neutral aliases for the BLAS handle, status code, enum parameter
// types and complex scalar types. The branch matching the compiled backend
// supplies the definitions.
// NOTE(review): if both SIRIUS_CUDA and SIRIUS_ROCM were defined, both branches
// would be active and the aliases would be redefined with different types;
// the header-level include selection only pulls in the CUDA headers in that case.
#if defined(SIRIUS_CUDA)
using handle_t = cublasHandle_t;
using status_t = cublasStatus_t;
using operation_t = cublasOperation_t;
using side_mode_t = cublasSideMode_t;
using diagonal_t = cublasDiagType_t;
using fill_mode_t = cublasFillMode_t;
using complex_float_t = cuComplex;
using complex_double_t = cuDoubleComplex;
#endif

#if defined(SIRIUS_ROCM)
using handle_t = rocblas_handle;
using status_t = rocblas_status;
using operation_t = rocblas_operation;
using side_mode_t = rocblas_side;
using diagonal_t = rocblas_diagonal;
using fill_mode_t = rocblas_fill;
using complex_float_t = rocblas_float_complex;
using complex_double_t = rocblas_double_complex;
#endif
// Matrix operation selectors mapped onto the active backend's constants:
// None = use the matrix as-is, Transpose = op(A) = A^T,
// ConjugateTranspose = op(A) = A^H.
// NOTE(review): the enclosing scope of these constants is not visible in this
// file; upstream presumably groups them in a dedicated namespace — confirm
// before relying on the unqualified names.
#if defined(SIRIUS_CUDA)
constexpr auto None = CUBLAS_OP_N;
constexpr auto Transpose = CUBLAS_OP_T;
constexpr auto ConjugateTranspose = CUBLAS_OP_C;
#endif

#if defined(SIRIUS_ROCM)
constexpr auto None = rocblas_operation_none;
constexpr auto Transpose = rocblas_operation_transpose;
constexpr auto ConjugateTranspose = rocblas_operation_conjugate_transpose;
#endif
// Side selectors (which side the special matrix multiplies from), mapped onto
// the active backend's constants.
#if defined(SIRIUS_CUDA)
constexpr auto Left = CUBLAS_SIDE_LEFT;
constexpr auto Right = CUBLAS_SIDE_RIGHT;
#endif

#if defined(SIRIUS_ROCM)
constexpr auto Left = rocblas_side_left;
constexpr auto Right = rocblas_side_right;
#endif
// Diagonal selectors for triangular routines, mapped onto the active backend's
// constants: NonUnit = diagonal is read from the matrix, Unit = diagonal is
// assumed to be all ones.
#if defined(SIRIUS_CUDA)
constexpr auto NonUnit = CUBLAS_DIAG_NON_UNIT;
constexpr auto Unit = CUBLAS_DIAG_UNIT;
#endif

#if defined(SIRIUS_ROCM)
constexpr auto NonUnit = rocblas_diagonal_non_unit;
constexpr auto Unit = rocblas_diagonal_unit;
#endif
// Fill-mode selectors (which triangle of the matrix is referenced), mapped
// onto the active backend's constants.
#if defined(SIRIUS_CUDA)
constexpr auto Upper = CUBLAS_FILL_MODE_UPPER;
constexpr auto Lower = CUBLAS_FILL_MODE_LOWER;
#endif

#if defined(SIRIUS_ROCM)
constexpr auto Upper = rocblas_fill_upper;
constexpr auto Lower = rocblas_fill_lower;
#endif
// The active backend's "success" status value; compare the status_t returned
// by the forwarding wrappers against this.
#if defined(SIRIUS_CUDA)
constexpr auto Success = CUBLAS_STATUS_SUCCESS;
#endif

#if defined(SIRIUS_ROCM)
constexpr auto Success = rocblas_status_success;
#endif
131template <
typename... ARGS>
132inline auto create(ARGS&&... args) -> status_t {
133#if defined(SIRIUS_ROCM)
134 return rocblas_create_handle(std::forward<ARGS>(args)...);
136 return cublasCreate(std::forward<ARGS>(args)...);
140template <
typename... ARGS>
141inline auto destroy(ARGS&&... args) -> status_t {
142#if defined(SIRIUS_ROCM)
143 return rocblas_destroy_handle(std::forward<ARGS>(args)...);
145 return cublasDestroy(std::forward<ARGS>(args)...);
149template <
typename... ARGS>
150inline auto set_stream(ARGS&&... args) -> status_t {
151#if defined(SIRIUS_ROCM)
152 return rocblas_set_stream(std::forward<ARGS>(args)...);
154 return cublasSetStream(std::forward<ARGS>(args)...);
158template <
typename... ARGS>
159inline auto get_stream(ARGS&&... args) -> status_t {
160#if defined(SIRIUS_ROCM)
161 return rocblas_get_stream(std::forward<ARGS>(args)...);
163 return cublasGetStream(std::forward<ARGS>(args)...);
168template <
typename... ARGS>
169inline auto sgemm(ARGS&&... args) -> status_t {
170#if defined(SIRIUS_ROCM)
171 return rocblas_sgemm(std::forward<ARGS>(args)...);
173 return cublasSgemm(std::forward<ARGS>(args)...);
177template <
typename... ARGS>
178inline auto dgemm(ARGS&&... args) -> status_t {
179#if defined(SIRIUS_ROCM)
180 return rocblas_dgemm(std::forward<ARGS>(args)...);
182 return cublasDgemm(std::forward<ARGS>(args)...);
186template <
typename... ARGS>
187inline auto cgemm(ARGS&&... args) -> status_t {
188#if defined(SIRIUS_ROCM)
189 return rocblas_cgemm(std::forward<ARGS>(args)...);
191 return cublasCgemm(std::forward<ARGS>(args)...);
195template <
typename... ARGS>
196inline auto zgemm(ARGS&&... args) -> status_t {
197#if defined(SIRIUS_ROCM)
198 return rocblas_zgemm(std::forward<ARGS>(args)...);
200 return cublasZgemm(std::forward<ARGS>(args)...);
204template <
typename... ARGS>
205inline auto dgemv(ARGS&&... args) -> status_t {
206#if defined(SIRIUS_ROCM)
207 return rocblas_dgemv(std::forward<ARGS>(args)...);
209 return cublasDgemv(std::forward<ARGS>(args)...);
213template <
typename... ARGS>
214inline auto zgemv(ARGS&&... args) -> status_t {
215#if defined(SIRIUS_ROCM)
216 return rocblas_zgemv(std::forward<ARGS>(args)...);
218 return cublasZgemv(std::forward<ARGS>(args)...);
222template <
typename... ARGS>
223inline auto strmm(ARGS&&... args) -> status_t {
224#if defined(SIRIUS_ROCM)
225 return rocblas_strmm(std::forward<ARGS>(args)...);
227 return cublasStrmm(std::forward<ARGS>(args)...);
231template <
typename... ARGS>
232inline auto dtrmm(ARGS&&... args) -> status_t {
233#if defined(SIRIUS_ROCM)
234 return rocblas_dtrmm(std::forward<ARGS>(args)...);
236 return cublasDtrmm(std::forward<ARGS>(args)...);
240template <
typename... ARGS>
241inline auto ctrmm(ARGS&&... args) -> status_t {
242#if defined(SIRIUS_ROCM)
243 return rocblas_ctrmm(std::forward<ARGS>(args)...);
245 return cublasCtrmm(std::forward<ARGS>(args)...);
249template <
typename... ARGS>
250inline auto ztrmm(ARGS&&... args) -> status_t {
251#if defined(SIRIUS_ROCM)
252 return rocblas_ztrmm(std::forward<ARGS>(args)...);
254 return cublasZtrmm(std::forward<ARGS>(args)...);
258template <
typename... ARGS>
259inline auto sger(ARGS&&... args) -> status_t {
260#if defined(SIRIUS_ROCM)
261 return rocblas_sger(std::forward<ARGS>(args)...);
263 return cublasSger(std::forward<ARGS>(args)...);
267template <
typename... ARGS>
268inline auto dger(ARGS&&... args) -> status_t {
269#if defined(SIRIUS_ROCM)
270 return rocblas_dger(std::forward<ARGS>(args)...);
272 return cublasDger(std::forward<ARGS>(args)...);
276template <
typename... ARGS>
277inline auto cgeru(ARGS&&... args) -> status_t {
278#if defined(SIRIUS_ROCM)
279 return rocblas_cgeru(std::forward<ARGS>(args)...);
281 return cublasCgeru(std::forward<ARGS>(args)...);
285template <
typename... ARGS>
286inline auto zgeru(ARGS&&... args) -> status_t {
287#if defined(SIRIUS_ROCM)
288 return rocblas_zgeru(std::forward<ARGS>(args)...);
290 return cublasZgeru(std::forward<ARGS>(args)...);
294template <
typename... ARGS>
295inline auto zaxpy(ARGS&&... args) -> status_t {
296#if defined(SIRIUS_ROCM)
297 return rocblas_zaxpy(std::forward<ARGS>(args)...);
299 return cublasZaxpy(std::forward<ARGS>(args)...);
303template <
typename... ARGS>
304inline auto dscal(ARGS&&... args) -> status_t {
305#if defined(SIRIUS_ROCM)
306 return rocblas_dscal(std::forward<ARGS>(args)...);
308 return cublasDscal(std::forward<ARGS>(args)...);
312template <
typename... ARGS>
313inline auto sscal(ARGS&&... args) -> status_t {
314#if defined(SIRIUS_ROCM)
315 return rocblas_sscal(std::forward<ARGS>(args)...);
317 return cublasSscal(std::forward<ARGS>(args)...);
/// Namespace of the SIRIUS library.