/// Computes Y = alpha * diag(D) * X for column-major matrices, i.e.
/// Y(row, col) = alpha * diag[row] * X(row, col).
///
/// This overload is enabled for every T except acc_complex_double_t; the
/// double-complex case is handled by a separate overload that uses accCmul.
///
/// Launch layout: 2D grid, the x dimension indexes rows [0, n) and the
/// y dimension indexes columns [0, ncols). Threads outside that range do nothing.
///
/// \param diag  device-accessible array of n diagonal entries
/// \param n     number of rows of X and Y
/// \param X     input matrix; element (row, col) is X[lda_x * col + row]
/// \param lda_x leading dimension of X
/// \param ncols number of columns of X and Y
/// \param Y     output matrix; element (row, col) is Y[lda_y * col + row]
/// \param lda_y leading dimension of Y
/// \param alpha scalar prefactor
template <class T>
__global__ std::enable_if_t<!std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Compute the offsets in 64 bits: lda * col overflows int for large matrices.
        T X_elem = X[static_cast<long long>(lda_x) * col + row];
        T D      = diag[row];
        Y[static_cast<long long>(lda_y) * col + row] = alpha * D * X_elem;
    }
}
/// Computes Y = alpha * diag(D) * X for column-major matrices of
/// acc_complex_double_t, i.e. Y(row, col) = (alpha * diag[row]) * X(row, col).
///
/// Complex products go through accCmul. The presumed reason is that
/// acc_complex_double_t has no usable operator* (TODO confirm).
///
/// Launch layout: 2D grid, the x dimension indexes rows [0, n) and the
/// y dimension indexes columns [0, ncols). Threads outside that range do nothing.
///
/// \param diag  device-accessible array of n diagonal entries
/// \param n     number of rows of X and Y
/// \param X     input matrix; element (row, col) is X[lda_x * col + row]
/// \param lda_x leading dimension of X
/// \param ncols number of columns of X and Y
/// \param Y     output matrix; element (row, col) is Y[lda_y * col + row]
/// \param lda_y leading dimension of Y
/// \param alpha scalar prefactor
template <class T>
__global__ std::enable_if_t<std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Compute the offsets in 64 bits: lda * col overflows int for large matrices.
        acc_complex_double_t X_elem = X[static_cast<long long>(lda_x) * col + row];
        acc_complex_double_t D      = diag[row];
        Y[static_cast<long long>(lda_y) * col + row] = accCmul(accCmul(alpha, D), X_elem);
    }
}
/// Host-side launcher for diag_mm: computes Y = alpha * diag(D) * X.
///
/// Launches numthreads x numthreads blocks, with enough blocks to cover
/// n rows (grid x) and ncols columns (grid y). The launch goes to the default
/// stream and is asynchronous. This function neither synchronizes nor checks
/// for launch errors.
///
/// If n or ncols is not positive, there is nothing to compute and the function
/// returns without launching. A zero-sized grid would be an invalid launch
/// configuration.
template <class T>
void
call_diagmm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    if (n <= 0 || ncols <= 0) {
        return;
    }

    // NOTE(review): numthreads is defined outside the visible source; confirm its value.
    dim3 threadsPerBlock(numthreads, numthreads);

    // Ceil-divide so that partial tiles at the matrix edges are covered.
    int num_block_rows = (n + threadsPerBlock.x - 1) / threadsPerBlock.x;
    int num_block_cols = (ncols + threadsPerBlock.y - 1) / threadsPerBlock.y;
    dim3 numBlocks(num_block_rows, num_block_cols);

    diag_mm<<<numBlocks, threadsPerBlock>>>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/// Double-precision entry point: Y = alpha * diag(D) * X.
/// Forwards unchanged arguments to call_diagmm<double>, which launches the diag_mm kernel.
void
ddiagmm(const double* diag, int n, const double* X, int lda_x, int ncols, double* Y, int lda_y,
        double alpha)
{
    call_diagmm<double>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/// Single-precision entry point: Y = alpha * diag(D) * X.
/// Forwards unchanged arguments to call_diagmm<float>, which launches the diag_mm kernel.
void
sdiagmm(const float* diag, int n, const float* X, int lda_x, int ncols, float* Y, int lda_y,
        float alpha)
{
    call_diagmm<float>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/// Double-complex entry point: Y = alpha * diag(D) * X.
///
/// The std::complex<double> buffers are reinterpreted as acc_complex_double_t.
/// This assumes both types store {real, imag} with identical layout (TODO confirm).
/// The scalar alpha is converted component-wise.
void
zdiagmm(const std::complex<double>* diag, int n, const std::complex<double>* X, int lda_x, int ncols,
        std::complex<double>* Y, int lda_y, std::complex<double> alpha)
{
    using acc_z = acc_complex_double_t;

    const acc_z* diag_acc = reinterpret_cast<const acc_z*>(diag);
    const acc_z* x_acc    = reinterpret_cast<const acc_z*>(X);
    acc_z* y_acc          = reinterpret_cast<acc_z*>(Y);
    acc_z alpha_acc{alpha.real(), alpha.imag()};

    call_diagmm<acc_z>(diag_acc, n, x_acc, lda_x, ncols, y_acc, lda_y, alpha_acc);
}
Interface to the accelerator API.
Uniform interface to the runtime API of CUDA and ROCm.
Namespace for accelerator-related functions.
Namespace of the SIRIUS library.