/**
 * Kernel: multiply every row of X by the matching diagonal entry and a scalar,
 *
 *     Y(row, col) = alpha * diag[row] * X(row, col)
 *
 * for 0 <= row < n and 0 <= col < ncols. Element (row, col) of X is read at
 * offset lda_x * col + row and written to Y at offset lda_y * col + row.
 *
 * This overload takes part in overload resolution for every T except
 * acc_complex_double_t, and relies on T providing operator*.
 *
 * Launch layout: one thread per matrix element; the x dimension of the grid
 * indexes rows and the y dimension indexes columns. Threads that fall outside
 * the n x ncols range do nothing.
 *
 * \param diag   diagonal entries, at least n elements
 * \param n      number of rows
 * \param X      input matrix
 * \param lda_x  distance, in elements, between consecutive columns of X
 * \param ncols  number of columns
 * \param Y      output matrix
 * \param lda_y  distance, in elements, between consecutive columns of Y
 * \param alpha  scalar factor
 */
template <class T>
__global__ std::enable_if_t<!std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Compute the linear offsets in 64-bit: lda * col can exceed INT_MAX
        // for large matrices even though each factor fits into an int.
        long long x_offset = static_cast<long long>(lda_x) * col + row;
        long long y_offset = static_cast<long long>(lda_y) * col + row;

        T X_elem = X[x_offset];
        T D      = diag[row];

        Y[y_offset] = alpha * D * X_elem;
    }
}
/**
 * Kernel: multiply every row of X by the matching diagonal entry and a scalar,
 *
 *     Y(row, col) = alpha * diag[row] * X(row, col)
 *
 * for 0 <= row < n and 0 <= col < ncols. Element (row, col) of X is read at
 * offset lda_x * col + row and written to Y at offset lda_y * col + row.
 *
 * This overload is selected only when T is acc_complex_double_t; the products
 * are formed with accCmul instead of operator*.
 *
 * Launch layout: one thread per matrix element; the x dimension of the grid
 * indexes rows and the y dimension indexes columns. Threads that fall outside
 * the n x ncols range do nothing.
 *
 * \param diag   diagonal entries, at least n elements
 * \param n      number of rows
 * \param X      input matrix
 * \param lda_x  distance, in elements, between consecutive columns of X
 * \param ncols  number of columns
 * \param Y      output matrix
 * \param lda_y  distance, in elements, between consecutive columns of Y
 * \param alpha  scalar factor
 */
template <class T>
__global__ std::enable_if_t<std::is_same<acc_complex_double_t, T>::value>
diag_mm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (col < ncols && row < n) {
        // Compute the linear offsets in 64-bit: lda * col can exceed INT_MAX
        // for large matrices even though each factor fits into an int.
        long long x_offset = static_cast<long long>(lda_x) * col + row;
        long long y_offset = static_cast<long long>(lda_y) * col + row;

        acc_complex_double_t X_elem = X[x_offset];
        acc_complex_double_t D      = diag[row];

        // Same association as the scalar kernel: (alpha * D) * X_elem.
        Y[y_offset] = accCmul(accCmul(alpha, D), X_elem);
    }
}
/**
 * Host-side launcher for the diag_mm kernels.
 *
 * Covers the n x ncols index range with a 2D grid of square thread blocks
 * (rows along x, columns along y) and launches diag_mm with the arguments
 * forwarded unchanged. No stream is passed to the launch, so the kernel is
 * enqueued on the default stream; the launch status is not inspected here.
 *
 * The pointers are handed straight to the kernel — presumably they must be
 * device-accessible; verify against the callers.
 *
 * \param diag   diagonal entries, at least n elements
 * \param n      number of rows
 * \param X      input matrix
 * \param lda_x  distance, in elements, between consecutive columns of X
 * \param ncols  number of columns
 * \param Y      output matrix
 * \param lda_y  distance, in elements, between consecutive columns of Y
 * \param alpha  scalar factor
 */
template <class T>
void
call_diagmm(const T* diag, int n, const T* X, int lda_x, int ncols, T* Y, int lda_y, T alpha)
{
    // Nothing to compute; a grid with a zero-sized dimension is not a valid
    // launch configuration, so bail out before building one.
    if (n <= 0 || ncols <= 0) {
        return;
    }

    // NOTE(review): the declaration of the block edge length is not visible in
    // this excerpt. Any value whose square does not exceed the device limit on
    // threads per block gives the same result; 16 (256 threads) is a
    // conservative choice — confirm against the original source.
    const int numthreads = 16;
    dim3 threadsPerBlock(numthreads, numthreads);

    // Ceiling division so that partially filled blocks at the edges are included.
    int num_block_rows = (n + threadsPerBlock.x - 1) / threadsPerBlock.x;
    int num_block_cols = (ncols + threadsPerBlock.y - 1) / threadsPerBlock.y;
    dim3 numBlocks(num_block_rows, num_block_cols);

    diag_mm<<<numBlocks, threadsPerBlock>>>(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/**
 * Double-precision entry point: Y(row, col) = alpha * diag[row] * X(row, col).
 *
 * Forwards all arguments unchanged to call_diagmm.
 *
 * NOTE(review): the return type and any linkage specification (e.g. an
 * enclosing extern "C" block) are not visible in this excerpt — confirm.
 *
 * \param diag   diagonal entries, at least n elements
 * \param n      number of rows
 * \param X      input matrix
 * \param lda_x  distance, in elements, between consecutive columns of X
 * \param ncols  number of columns
 * \param Y      output matrix
 * \param lda_y  distance, in elements, between consecutive columns of Y
 * \param alpha  scalar factor
 */
void
ddiagmm(const double* diag, int n, const double* X, int lda_x, int ncols, double* Y, int lda_y, double alpha)
{
    call_diagmm(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/**
 * Single-precision entry point: Y(row, col) = alpha * diag[row] * X(row, col).
 *
 * Forwards all arguments unchanged to call_diagmm.
 *
 * NOTE(review): the return type and any linkage specification (e.g. an
 * enclosing extern "C" block) are not visible in this excerpt — confirm.
 *
 * \param diag   diagonal entries, at least n elements
 * \param n      number of rows
 * \param X      input matrix
 * \param lda_x  distance, in elements, between consecutive columns of X
 * \param ncols  number of columns
 * \param Y      output matrix
 * \param lda_y  distance, in elements, between consecutive columns of Y
 * \param alpha  scalar factor
 */
void
sdiagmm(const float* diag, int n, const float* X, int lda_x, int ncols, float* Y, int lda_y, float alpha)
{
    call_diagmm(diag, n, X, lda_x, ncols, Y, lda_y, alpha);
}
/**
 * Double-complex entry point: Y(row, col) = alpha * diag[row] * X(row, col).
 *
 * The std::complex<double> pointers are reinterpreted as acc_complex_double_t
 * pointers and alpha is rebuilt from its real and imaginary parts before
 * everything is forwarded to call_diagmm.
 *
 * NOTE(review): the pointer casts assume acc_complex_double_t has the same
 * size and (real, imaginary) layout as std::complex<double> — confirm against
 * the type's definition. The return type and any linkage specification
 * (e.g. an enclosing extern "C" block) are not visible in this excerpt either.
 *
 * \param diag   diagonal entries, at least n elements
 * \param n      number of rows
 * \param X      input matrix
 * \param lda_x  distance, in elements, between consecutive columns of X
 * \param ncols  number of columns
 * \param Y      output matrix
 * \param lda_y  distance, in elements, between consecutive columns of Y
 * \param alpha  scalar factor
 */
void
zdiagmm(const std::complex<double>* diag, int n, const std::complex<double>* X, int lda_x, int ncols,
        std::complex<double>* Y, int lda_y, std::complex<double> alpha)
{
    call_diagmm(reinterpret_cast<const acc_complex_double_t*>(diag), n,
                reinterpret_cast<const acc_complex_double_t*>(X), lda_x, ncols,
                reinterpret_cast<acc_complex_double_t*>(Y), lda_y,
                acc_complex_double_t{alpha.real(), alpha.imag()});
}
Interface to accelerators API.
Uniform interface to the runtime API of CUDA and ROCm.