SIRIUS 7.5.0
Electronic structure library and applications
nlcglib/preconditioner/diag_mm.cu
Go to the documentation of this file.
1// Copyright (c) 2023 Simon Pintarelli, Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
20/** \file diag_mm.cu
21 *
22 * \brief Matrix-matrix multiplication with a diagonal matrix
23 */
24
25#include "diag_mm.hpp"
27#include "core/acc/acc.hpp"
28
29using namespace sirius;
30using namespace sirius::acc;
31
/// Kernel computing Y(i, j) = alpha * diag[i] * X(i, j) for a column-major matrix.
///
/// This overload is enabled for every element type T except acc_complex_double_t.
/// Launch layout: 2D grid of 2D blocks, x covers rows and y covers columns, one thread per
/// matrix element. Threads that fall outside the n x ncols range do nothing.
///
/// \param diag   diagonal entries, at least n elements (device-accessible)
/// \param n      number of rows of X and Y
/// \param X      input matrix, element (i, j) stored at X[lda_x * j + i]
/// \param lda_x  leading dimension of X
/// \param ncols  number of columns of X and Y
/// \param Y      output matrix, element (i, j) stored at Y[lda_y * j + i]
/// \param lda_y  leading dimension of Y
/// \param alpha  scalar factor
template <class T>
__global__ std::enable_if_t<!std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Widen to 64 bit before multiplying: lda * col overflows a 32-bit int once the matrix
        // holds more than 2^31 elements.
        long long const offset_x = static_cast<long long>(lda_x) * col + row;
        long long const offset_y = static_cast<long long>(lda_y) * col + row;

        T X_elem    = X[offset_x];
        T D         = diag[row];
        Y[offset_y] = alpha * D * X_elem;
    }
}
45
/// Kernel computing Y(i, j) = alpha * diag[i] * X(i, j) for a column-major matrix of
/// acc_complex_double_t elements.
///
/// This overload is enabled only when T is acc_complex_double_t; the products are formed with
/// accCmul instead of operator*.
/// Launch layout: 2D grid of 2D blocks, x covers rows and y covers columns, one thread per
/// matrix element. Threads that fall outside the n x ncols range do nothing.
///
/// \param diag   diagonal entries, at least n elements (device-accessible)
/// \param n      number of rows of X and Y
/// \param X      input matrix, element (i, j) stored at X[lda_x * j + i]
/// \param lda_x  leading dimension of X
/// \param ncols  number of columns of X and Y
/// \param Y      output matrix, element (i, j) stored at Y[lda_y * j + i]
/// \param lda_y  leading dimension of Y
/// \param alpha  scalar factor
template <class T>
__global__ std::enable_if_t<std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Widen to 64 bit before multiplying: lda * col overflows a 32-bit int once the matrix
        // holds more than 2^31 elements.
        long long const offset_x = static_cast<long long>(lda_x) * col + row;
        long long const offset_y = static_cast<long long>(lda_y) * col + row;

        acc_complex_double_t X_elem = X[offset_x];
        acc_complex_double_t D      = diag[row];
        Y[offset_y]                 = accCmul(accCmul(alpha, D), X_elem);
    }
}
59
/// Host-side launcher for the diag_mm kernel: Y(i, j) = alpha * diag[i] * X(i, j).
///
/// Uses 32 x 32 thread blocks and enough blocks to cover an n x ncols matrix. The kernel is
/// launched without an explicit stream argument and this function does not synchronize.
/// Nothing is launched when the matrix is empty.
///
/// \param diag   diagonal entries, at least n elements
/// \param n      number of rows of X and Y
/// \param X      input matrix (column-major, leading dimension lda_x)
/// \param lda_x  leading dimension of X
/// \param ncols  number of columns of X and Y
/// \param Y      output matrix (column-major, leading dimension lda_y)
/// \param lda_y  leading dimension of Y
/// \param alpha  scalar factor
template <class T>
void
call_diagmm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    // An empty matrix needs no work. Launching anyway would request a grid with a zero
    // dimension, which is an invalid launch configuration; a negative size would be converted
    // to unsigned in the block-count computation below and request a nonsensical grid.
    if (n <= 0 || ncols <= 0) {
        return;
    }

    int numthreads = 32;
    dim3 threadsPerBlock(numthreads, numthreads);

    // ceil-division so that partially filled tiles at the matrix edge are covered as well
    int num_block_rows = (n + threadsPerBlock.x - 1) / threadsPerBlock.x;
    int num_block_cols = (ncols + threadsPerBlock.y - 1) / threadsPerBlock.y;
    dim3 numBlocks(num_block_rows, num_block_cols);

    diag_mm<<<numBlocks, threadsPerBlock>>>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
73
74extern "C" {
/// C-linkage entry point for double precision: forwards all arguments unchanged to
/// call_diagmm<double>.
void
ddiagmm(const double* diag, int n, const double* X, int lda_x, int ncols, double* Y, int lda_y, double alpha)
{
    call_diagmm<double>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
80
/// C-linkage entry point for single precision: forwards all arguments unchanged to
/// call_diagmm<float>.
void
sdiagmm(const float* diag, int n, const float* X, int lda_x, int ncols, float* Y, int lda_y, float alpha)
{
    call_diagmm<float>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/// C-linkage entry point for double-precision complex data.
///
/// The std::complex<double> pointers are reinterpreted as acc_complex_double_t pointers and
/// alpha is rebuilt from its real and imaginary parts before forwarding to call_diagmm.
void
zdiagmm(const std::complex<double>* diag, int n, const std::complex<double>* X, int lda_x, int ncols,
        std::complex<double>* Y, int lda_y, std::complex<double> alpha)
{
    auto const* diag_acc = reinterpret_cast<const acc_complex_double_t*>(diag);
    auto const* X_acc    = reinterpret_cast<const acc_complex_double_t*>(X);
    auto* Y_acc          = reinterpret_cast<acc_complex_double_t*>(Y);

    acc_complex_double_t const alpha_acc{alpha.real(), alpha.imag()};

    call_diagmm<acc_complex_double_t>(diag_acc, n, X_acc, lda_x, ncols, Y_acc, lda_y, alpha_acc);
}
94}
Interface to accelerators API.
Uniform interface to the runtime API of CUDA and ROCm.
Namespace for accelerator-related functions.
Definition: acc.cpp:30
Namespace of the SIRIUS library.
Definition: sirius.f90:5