Loading [MathJax]/extensions/TeX/AMSsymbols.js
SIRIUS 7.5.0
Electronic structure library and applications
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Pages
core/gpu_kernels/diag_mm.cu
Go to the documentation of this file.
1// Copyright (c) 2013-2023 Simon Pintarelli
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
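/** \file diag_mm.cu
 *
 * \brief GPU kernels and extern "C" launchers computing Y = alpha * diag(d) * X for column-major
 *        matrices (float, double and complex double).
 */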
24
25#include "diag_mm.hpp"
26#include "acc_runtime.hpp"
27#include "acc.hpp"
28
/** Scale the rows of a column-major matrix by a diagonal: Y(row, col) = alpha * diag[row] * X(row, col).
 *
 *  Overload for every element type except acc_complex_double_t, which uses the built-in operator*.
 *
 *  Element (row, col) of X is X[lda_x * col + row], and element (row, col) of Y is Y[lda_y * col + row].
 *  Expects a 2D launch: the x dimension of the grid covers the n rows and the y dimension covers the
 *  ncols columns. Threads outside the n x ncols range do nothing.
 *
 *  Each thread reads and writes only its own (row, col) element. Y == X with lda_y == lda_x (in-place
 *  scaling) is therefore race-free, which is why the pointers are not marked __restrict__.
 */
template <class T>
__global__ std::enable_if_t<!std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Form the offsets in 64 bits: lda * col overflows int for matrices with more than INT_MAX elements.
        long long const off_x = static_cast<long long>(lda_x) * col + row;
        long long const off_y = static_cast<long long>(lda_y) * col + row;

        T X_elem = X[off_x];
        T D      = diag[row];
        Y[off_y] = alpha * D * X_elem;
    }
}
42
/** Complex-double overload: Y(row, col) = alpha * diag[row] * X(row, col).
 *
 *  Selected only when T is acc_complex_double_t. The products go through accCmul because that type
 *  has no built-in operator*.
 *
 *  Element (row, col) of X is X[lda_x * col + row], and element (row, col) of Y is Y[lda_y * col + row].
 *  Expects a 2D launch: the x dimension of the grid covers the n rows and the y dimension covers the
 *  ncols columns. Threads outside the n x ncols range do nothing.
 */
template <class T>
__global__ std::enable_if_t<std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Form the offsets in 64 bits: lda * col overflows int for matrices with more than INT_MAX elements.
        long long const off_x = static_cast<long long>(lda_x) * col + row;
        long long const off_y = static_cast<long long>(lda_y) * col + row;

        acc_complex_double_t X_elem = X[off_x];
        acc_complex_double_t D      = diag[row];
        Y[off_y] = accCmul(accCmul(alpha, D), X_elem);
    }
}
56
/** Launch diag_mm to compute Y = alpha * diag(diag) * X for column-major matrices.
 *
 *  X and Y are n x ncols with leading dimensions lda_x and lda_y. All pointers must be device-accessible.
 *
 *  The kernel is launched on the default stream and the call returns without synchronizing.
 *  No launch error status is checked here.
 *
 *  Empty input (n <= 0 or ncols <= 0) returns immediately. A grid with a zero dimension is an
 *  invalid launch configuration and would leave a pending error for an unrelated later check.
 *
 *  The block is 32 x 32. gridDim.y is limited to 65535 blocks, so ncols must not exceed 65535 * 32.
 */
template <class T>
void
call_diagmm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    if (n <= 0 || ncols <= 0) {
        return;
    }

    // threadIdx.x runs over rows, which are contiguous in column-major storage, so warp accesses coalesce.
    constexpr int numthreads = 32;
    dim3 threadsPerBlock(numthreads, numthreads);

    // Ceil-division written as (m - 1) / b + 1: valid for m >= 1 and cannot overflow int.
    int num_block_rows = (n - 1) / numthreads + 1;
    int num_block_cols = (ncols - 1) / numthreads + 1;
    dim3 numBlocks(num_block_rows, num_block_cols);

    diag_mm<<<numBlocks, threadsPerBlock>>>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
70
extern "C" {

/** Real double precision: Y = alpha * diag(diag) * X (column-major, device-accessible pointers). */
void
ddiagmm(const double* diag, int n, const double* X, int lda_x, int ncols, double* Y, int lda_y, double alpha)
{
    call_diagmm<double>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}

/** Real single precision: Y = alpha * diag(diag) * X (column-major, device-accessible pointers). */
void
sdiagmm(const float* diag, int n, const float* X, int lda_x, int ncols, float* Y, int lda_y, float alpha)
{
    call_diagmm<float>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}

/** Complex double precision: Y = alpha * diag(diag) * X (column-major, device-accessible pointers).
 *
 *  The std::complex<double> buffers are reinterpreted, not copied, as acc_complex_double_t.
 *  This relies on both types storing a (real, imag) pair of doubles with the same layout.
 */
void
zdiagmm(const std::complex<double>* diag, int n, const std::complex<double>* X, int lda_x, int ncols,
        std::complex<double>* Y, int lda_y, std::complex<double> alpha)
{
    using zt = acc_complex_double_t;

    zt const alpha_acc{alpha.real(), alpha.imag()};
    auto const* d_ptr = reinterpret_cast<zt const*>(diag);
    auto const* x_ptr = reinterpret_cast<zt const*>(X);
    auto*       y_ptr = reinterpret_cast<zt*>(Y);

    call_diagmm(d_ptr, n, x_ptr, lda_x, ncols, y_ptr, lda_y, alpha_acc);
}

} // extern "C"
Interface to accelerators API.
Uniform interface to the runtime API of CUDA and ROCm.