SIRIUS 7.5.0
Electronic structure library and applications
acc_blas_api.hpp
Go to the documentation of this file.
1// Copyright (c) 2013-2017 Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
/** \file acc_blas_api.hpp
 *
 * \brief Interface to cuBLAS / rocBLAS related functions.
 */
24
25#ifndef __ACC_BLAS_API_HPP__
26#define __ACC_BLAS_API_HPP__
27
28#include <utility>
29
30#if defined(SIRIUS_CUDA)
31#include <cublas_v2.h>
32
33#elif defined(SIRIUS_ROCM)
34#include <rocblas.h>
35
36#else
37#error Either SIRIUS_CUDA or SIRIUS_ROCM must be defined!
38#endif
39
40namespace sirius {
41
42namespace acc {
43/// Internal interface to accelerated BLAS functions (CUDA or ROCM).
44namespace blas_api {
45
46#if defined(SIRIUS_CUDA)
47using handle_t = cublasHandle_t;
48using status_t = cublasStatus_t;
49using operation_t = cublasOperation_t;
50using side_mode_t = cublasSideMode_t;
51using diagonal_t = cublasDiagType_t;
52using fill_mode_t = cublasFillMode_t;
53using complex_float_t = cuComplex;
54using complex_double_t = cuDoubleComplex;
55#endif
56
57#if defined(SIRIUS_ROCM)
58using handle_t = rocblas_handle;
59using status_t = rocblas_status;
60using operation_t = rocblas_operation;
61using side_mode_t = rocblas_side;
62using diagonal_t = rocblas_diagonal;
63using fill_mode_t = rocblas_fill;
64using complex_float_t = rocblas_float_complex;
65using complex_double_t = rocblas_double_complex;
66#endif
67
68namespace operation {
69#if defined(SIRIUS_CUDA)
70constexpr auto None = CUBLAS_OP_N;
71constexpr auto Transpose = CUBLAS_OP_T;
72constexpr auto ConjugateTranspose = CUBLAS_OP_C;
73#endif
74
75#if defined(SIRIUS_ROCM)
76constexpr auto None = rocblas_operation_none;
77constexpr auto Transpose = rocblas_operation_transpose;
78constexpr auto ConjugateTranspose = rocblas_operation_conjugate_transpose;
79#endif
80} // namespace operation
81
82namespace side {
83#if defined(SIRIUS_CUDA)
84constexpr auto Left = CUBLAS_SIDE_LEFT;
85constexpr auto Right = CUBLAS_SIDE_RIGHT;
86#endif
87
88#if defined(SIRIUS_ROCM)
89constexpr auto Left = rocblas_side_left;
90constexpr auto Right = rocblas_side_right;
91#endif
92} // namespace side
93
94namespace diagonal {
95#if defined(SIRIUS_CUDA)
96constexpr auto NonUnit = CUBLAS_DIAG_NON_UNIT;
97constexpr auto Unit = CUBLAS_DIAG_UNIT;
98#endif
99
100#if defined(SIRIUS_ROCM)
101constexpr auto NonUnit = rocblas_diagonal_non_unit;
102constexpr auto Unit = rocblas_diagonal_unit;
103#endif
104} // namespace diagonal
105
106namespace fill {
107#if defined(SIRIUS_CUDA)
108constexpr auto Upper = CUBLAS_FILL_MODE_UPPER;
109constexpr auto Lower = CUBLAS_FILL_MODE_LOWER;
110#endif
111
112#if defined(SIRIUS_ROCM)
113constexpr auto Upper = rocblas_fill_upper;
114constexpr auto Lower = rocblas_fill_lower;
115#endif
116} // namespace fill
117
118namespace status {
119#if defined(SIRIUS_CUDA)
120constexpr auto Success = CUBLAS_STATUS_SUCCESS;
121#endif
122
123#if defined(SIRIUS_ROCM)
124constexpr auto Success = rocblas_status_success;
125#endif
126} // namespace status
127
128// =======================================
129// Forwarding functions of to GPU BLAS API
130// =======================================
131template <typename... ARGS>
132inline auto create(ARGS&&... args) -> status_t {
133#if defined(SIRIUS_ROCM)
134 return rocblas_create_handle(std::forward<ARGS>(args)...);
135#else
136 return cublasCreate(std::forward<ARGS>(args)...);
137#endif
138}
139
140template <typename... ARGS>
141inline auto destroy(ARGS&&... args) -> status_t {
142#if defined(SIRIUS_ROCM)
143 return rocblas_destroy_handle(std::forward<ARGS>(args)...);
144#else
145 return cublasDestroy(std::forward<ARGS>(args)...);
146#endif
147}
148
149template <typename... ARGS>
150inline auto set_stream(ARGS&&... args) -> status_t {
151#if defined(SIRIUS_ROCM)
152 return rocblas_set_stream(std::forward<ARGS>(args)...);
153#else
154 return cublasSetStream(std::forward<ARGS>(args)...);
155#endif
156}
157
158template <typename... ARGS>
159inline auto get_stream(ARGS&&... args) -> status_t {
160#if defined(SIRIUS_ROCM)
161 return rocblas_get_stream(std::forward<ARGS>(args)...);
162#else
163 return cublasGetStream(std::forward<ARGS>(args)...);
164#endif
165}
166
167
168template <typename... ARGS>
169inline auto sgemm(ARGS&&... args) -> status_t {
170#if defined(SIRIUS_ROCM)
171 return rocblas_sgemm(std::forward<ARGS>(args)...);
172#else
173 return cublasSgemm(std::forward<ARGS>(args)...);
174#endif // SIRIUS_ROCM
175}
176
177template <typename... ARGS>
178inline auto dgemm(ARGS&&... args) -> status_t {
179#if defined(SIRIUS_ROCM)
180 return rocblas_dgemm(std::forward<ARGS>(args)...);
181#else
182 return cublasDgemm(std::forward<ARGS>(args)...);
183#endif // SIRIUS_ROCM
184}
185
186template <typename... ARGS>
187inline auto cgemm(ARGS&&... args) -> status_t {
188#if defined(SIRIUS_ROCM)
189 return rocblas_cgemm(std::forward<ARGS>(args)...);
190#else
191 return cublasCgemm(std::forward<ARGS>(args)...);
192#endif // SIRIUS_ROCM
193}
194
195template <typename... ARGS>
196inline auto zgemm(ARGS&&... args) -> status_t {
197#if defined(SIRIUS_ROCM)
198 return rocblas_zgemm(std::forward<ARGS>(args)...);
199#else
200 return cublasZgemm(std::forward<ARGS>(args)...);
201#endif // SIRIUS_ROCM
202}
203
204template <typename... ARGS>
205inline auto dgemv(ARGS&&... args) -> status_t {
206#if defined(SIRIUS_ROCM)
207 return rocblas_dgemv(std::forward<ARGS>(args)...);
208#else
209 return cublasDgemv(std::forward<ARGS>(args)...);
210#endif // SIRIUS_ROCM
211}
212
213template <typename... ARGS>
214inline auto zgemv(ARGS&&... args) -> status_t {
215#if defined(SIRIUS_ROCM)
216 return rocblas_zgemv(std::forward<ARGS>(args)...);
217#else
218 return cublasZgemv(std::forward<ARGS>(args)...);
219#endif // SIRIUS_ROCM
220}
221
222template <typename... ARGS>
223inline auto strmm(ARGS&&... args) -> status_t {
224#if defined(SIRIUS_ROCM)
225 return rocblas_strmm(std::forward<ARGS>(args)...);
226#else
227 return cublasStrmm(std::forward<ARGS>(args)...);
228#endif // SIRIUS_ROCM
229}
230
231template <typename... ARGS>
232inline auto dtrmm(ARGS&&... args) -> status_t {
233#if defined(SIRIUS_ROCM)
234 return rocblas_dtrmm(std::forward<ARGS>(args)...);
235#else
236 return cublasDtrmm(std::forward<ARGS>(args)...);
237#endif // SIRIUS_ROCM
238}
239
240template <typename... ARGS>
241inline auto ctrmm(ARGS&&... args) -> status_t {
242#if defined(SIRIUS_ROCM)
243 return rocblas_ctrmm(std::forward<ARGS>(args)...);
244#else
245 return cublasCtrmm(std::forward<ARGS>(args)...);
246#endif // SIRIUS_ROCM
247}
248
249template <typename... ARGS>
250inline auto ztrmm(ARGS&&... args) -> status_t {
251#if defined(SIRIUS_ROCM)
252 return rocblas_ztrmm(std::forward<ARGS>(args)...);
253#else
254 return cublasZtrmm(std::forward<ARGS>(args)...);
255#endif // SIRIUS_ROCM
256}
257
258template <typename... ARGS>
259inline auto sger(ARGS&&... args) -> status_t {
260#if defined(SIRIUS_ROCM)
261 return rocblas_sger(std::forward<ARGS>(args)...);
262#else
263 return cublasSger(std::forward<ARGS>(args)...);
264#endif // SIRIUS_ROCM
265}
266
267template <typename... ARGS>
268inline auto dger(ARGS&&... args) -> status_t {
269#if defined(SIRIUS_ROCM)
270 return rocblas_dger(std::forward<ARGS>(args)...);
271#else
272 return cublasDger(std::forward<ARGS>(args)...);
273#endif // SIRIUS_ROCM
274}
275
276template <typename... ARGS>
277inline auto cgeru(ARGS&&... args) -> status_t {
278#if defined(SIRIUS_ROCM)
279 return rocblas_cgeru(std::forward<ARGS>(args)...);
280#else
281 return cublasCgeru(std::forward<ARGS>(args)...);
282#endif // SIRIUS_ROCM
283}
284
285template <typename... ARGS>
286inline auto zgeru(ARGS&&... args) -> status_t {
287#if defined(SIRIUS_ROCM)
288 return rocblas_zgeru(std::forward<ARGS>(args)...);
289#else
290 return cublasZgeru(std::forward<ARGS>(args)...);
291#endif // SIRIUS_ROCM
292}
293
294template <typename... ARGS>
295inline auto zaxpy(ARGS&&... args) -> status_t {
296#if defined(SIRIUS_ROCM)
297 return rocblas_zaxpy(std::forward<ARGS>(args)...);
298#else
299 return cublasZaxpy(std::forward<ARGS>(args)...);
300#endif // SIRIUS_ROCM
301}
302
303template <typename... ARGS>
304inline auto dscal(ARGS&&... args) -> status_t {
305#if defined(SIRIUS_ROCM)
306 return rocblas_dscal(std::forward<ARGS>(args)...);
307#else
308 return cublasDscal(std::forward<ARGS>(args)...);
309#endif // SIRIUS_ROCM
310}
311
312template <typename... ARGS>
313inline auto sscal(ARGS&&... args) -> status_t {
314#if defined(SIRIUS_ROCM)
315 return rocblas_sscal(std::forward<ARGS>(args)...);
316#else
317 return cublasSscal(std::forward<ARGS>(args)...);
318#endif // SIRIUS_ROCM
319}
320
321} // namespace blas
322} // namespace acc
323} // namespace sirius
324
325#endif
Namespace of the SIRIUS library.
Definition: sirius.f90:5