28#include <spfft/spfft.hpp>
/// Trait specialization: for std::complex<double> the grid type is spfft::Grid.
46struct SpFFT_Grid<std::complex<double>> {
using type = spfft::Grid;};
/// Trait specialization: for std::complex<float> the grid type is spfft::GridFloat.
50struct SpFFT_Grid<std::complex<float>> {
using type = spfft::GridFloat;};
/// Trait specialization: plain float selects the same grid type as std::complex<float>,
/// i.e. spfft::GridFloat.
53struct SpFFT_Grid<float> {
using type = spfft::GridFloat;};
/// Shorthand for the nested \c type that the SpFFT_Grid trait selects for \c T.
57using spfft_grid_type =
typename SpFFT_Grid<T>::type;
/// Trait specialization: for std::complex<float> the transform type is spfft::TransformFloat.
74struct SpFFT_Transform<std::complex<float>> {
using type = spfft::TransformFloat;};
/// Shorthand for the nested \c type that the SpFFT_Transform trait selects for \c T.
78using spfft_transform_type =
typename SpFFT_Transform<T>::type;
/// Lookup table from an SpFFT processing-unit identifier to the matching sddk::memory_t value.
80const std::map<SpfftProcessingUnitType, sddk::memory_t> spfft_memory_t = {
// Host processing unit corresponds to host memory.
81 {SPFFT_PU_HOST, sddk::memory_t::host},
// GPU processing unit corresponds to device memory.
82 {SPFFT_PU_GPU, sddk::memory_t::device}
/// SFINAE helper intended to be used as a function return type.
/** Resolves to \c void when invoking a callable of type \c F with arguments of types
 *  \c Args... yields exactly the type \c T. The comparison is done with std::is_same, so no
 *  implicit conversions are accepted and cv/reference qualifiers are not stripped.
 *  For any other result type the alias is ill-formed, which removes the function that uses
 *  it as its return type from overload resolution.
 *
 *  \tparam F     Callable type.
 *  \tparam T     Result type that the call F(Args...) is required to have.
 *  \tparam Args  Argument types of the call.
 */
template <typename F, typename T, typename... Args>
using enable_return =
    typename std::enable_if<std::is_same<typename std::result_of<F(Args...)>::type, T>::value, void>::type;
/// Fill the host space-domain buffer of an SpFFT transform from a real-valued callable.
/** The return type uses \c enable_return<F, T, int>, so this overload participates only when
 *  calling \c F with an \c int yields exactly \c T. The callable is invoked as \c fr__(i) for
 *  every index \c i in [0, spfft__.local_slice_size()).
 *
 *  Throws std::runtime_error("wrong spfft type"); presumably this is the fall-through for
 *  transform types other than C2C / R2C (that branch label is not visible in this excerpt).
 */
89template <
typename T,
typename F>
90inline enable_return<F, T, int>
// The element type of the buffer depends on the kind of transform.
93 switch (spfft__.type()) {
// Complex-to-complex: the buffer is viewed as std::complex<T>; the callable's value becomes
// the real part and the imaginary part is set to zero.
94 case SPFFT_TRANS_C2C: {
95 auto ptr =
reinterpret_cast<std::complex<T>*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
96 #pragma omp parallel for schedule(static)
97 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
98 ptr[i] = std::complex<T>(fr__(i), 0.0);
// Real-to-complex: the buffer is viewed as plain T values.
102 case SPFFT_TRANS_R2C: {
103 auto ptr =
reinterpret_cast<T*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
104 #pragma omp parallel for schedule(static)
105 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
111 throw std::runtime_error(
"wrong spfft type");
/// Overload constrained to callables whose result type is std::complex<T>.
/** The return type uses \c enable_return<F, std::complex<T>, int>, so this overload participates
 *  only when calling \c F with an \c int yields exactly \c std::complex<T>.
 *
 *  Throws std::runtime_error("wrong spfft type"); the condition guarding the throw is not
 *  visible in this excerpt. NOTE(review): the body of the R2C case is not visible either —
 *  confirm how complex input is treated for a real-to-complex transform.
 */
117template <
typename T,
typename F>
118inline enable_return<F, std::complex<T>,
int>
121 switch (spfft__.type()) {
// Complex-to-complex: the host space-domain buffer is viewed as std::complex<T> and visited
// element by element over [0, local_slice_size()).
122 case SPFFT_TRANS_C2C: {
123 auto ptr =
reinterpret_cast<std::complex<T>*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
124 #pragma omp parallel for schedule(static)
125 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
130 case SPFFT_TRANS_R2C: {
133 throw std::runtime_error(
"wrong spfft type");
/// Load data into the SpFFT transform from a raw array.
/** Thin wrapper: forwards to the callable-based \c spfft_input<T> with a lambda that returns
 *  \c data__[ir] for the requested index.
 *
 *  \param [in,out] spfft__  Transform object passed through to the callable-based overload.
 *  \param [in]     data__   Source array; read at the indices the forwarded call asks for.
 */
140inline void spfft_input(spfft_transform_type<T>& spfft__, T
const* data__)
// The lambda captures data__ by reference; it is only used during this call.
142 spfft_input<T>(spfft__, [&](
int ir){
return data__[ir];});
/// Apply a callable to the host space-domain buffer of an SpFFT transform, element by element.
/** NOTE(review): the loop bodies are not visible in this excerpt; judging by the function name,
 *  each buffer element is presumably multiplied by \c fr__(i) — confirm against the full source.
 *
 *  Throws std::runtime_error("wrong spfft type"); presumably for transform types other than
 *  C2C / R2C (the branch label is not visible here).
 *
 *  \param [in,out] spfft__  Transform whose SPFFT_PU_HOST space-domain data is accessed.
 *  \param [in]     fr__     Callable supplied by the caller.
 */
145template <
typename T,
typename F>
146inline void spfft_multiply(spfft_transform_type<T>& spfft__, F&& fr__)
148 switch (spfft__.type()) {
// Complex-to-complex: the buffer is viewed as std::complex<T>.
149 case SPFFT_TRANS_C2C: {
150 auto ptr =
reinterpret_cast<std::complex<T>*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
151 #pragma omp parallel for schedule(static)
152 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
// Real-to-complex: the buffer is viewed as plain T values.
157 case SPFFT_TRANS_R2C: {
158 auto ptr =
reinterpret_cast<T*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
159 #pragma omp parallel for schedule(static)
160 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
166 throw std::runtime_error(
"wrong spfft type");
// Copy the host space-domain buffer of the transform into the caller's array data__.
// (The function header is not visible in this excerpt.)
175 switch (spfft__.type()) {
// Complex-to-complex: the buffer holds std::complex<T>; only the real part of each element
// is written to data__.
176 case SPFFT_TRANS_C2C: {
177 auto ptr =
reinterpret_cast<std::complex<T>*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
178 #pragma omp parallel for schedule(static)
179 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
180 data__[i] = std::real(ptr[i]);
// Real-to-complex: the buffer is viewed as plain T values (loop body not visible here).
184 case SPFFT_TRANS_R2C: {
185 auto ptr =
reinterpret_cast<T*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
186 #pragma omp parallel for schedule(static)
187 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
// Presumably reached for any transform type other than the two handled above.
193 throw std::runtime_error(
"wrong spfft type");
/// Output overload taking a complex-valued destination array.
/** Reads the SPFFT_PU_HOST space-domain buffer of \c spfft__. The loop body and the R2C branch
 *  are not visible in this excerpt, so the exact per-element operation cannot be stated here.
 *
 *  Throws std::runtime_error("wrong spfft type"); the condition guarding the throw is not
 *  visible in this excerpt.
 *
 *  \param [in]  spfft__  Transform whose host buffer is read.
 *  \param [out] data__   Destination array of std::complex<T>.
 */
199inline void spfft_output(spfft_transform_type<T>& spfft__, std::complex<T>* data__)
201 switch (spfft__.type()) {
// Complex-to-complex: the buffer is viewed as std::complex<T> and visited over
// [0, local_slice_size()).
202 case SPFFT_TRANS_C2C: {
203 auto ptr =
reinterpret_cast<std::complex<T>*
>(spfft__.space_domain_data(SPFFT_PU_HOST));
204 #pragma omp parallel for schedule(static)
205 for (
int i = 0; i < spfft__.local_slice_size(); i++) {
210 case SPFFT_TRANS_R2C: {
214 throw std::runtime_error(
"wrong spfft type");
// Total number of grid points: the product of the three grid dimensions.
// NOTE(review): if dim_x()/dim_y()/dim_z() return int, the product is evaluated in int before
// any widening to the function's return type — confirm it cannot overflow for large grids.
223 return spfft__.dim_x() * spfft__.dim_y() * spfft__.dim_z();
// Local grid size, exactly as reported by the transform's local_slice_size().
230 return spfft__.local_slice_size();
MPI communicator wrapper.
int size() const
Size of the communicator (number of ranks).
int rank() const
Rank of MPI process inside communicator.
Contains declaration and implementation of mpi::Communicator class.
Memory management functions and classes.
void spfft_output(spfft_transform_type< T > &spfft__, T *data__)
Output CPU data from the CPU buffer of SpFFT.
enable_return< F, T, int > spfft_input(spfft_transform_type< T > &spfft__, F &&fr__)
Load data from real-valued lambda.
size_t spfft_grid_size_local(T const &spfft__)
Local size of the SpFFT transformation grid.
size_t spfft_grid_size(T const &spfft__)
Total size of the SpFFT transformation grid.
auto split_z_dimension(int size_z__, mpi::Communicator const &comm_fft__)
Split the z-dimension of size size_z between the MPI ranks of the FFT communicator.
Namespace of the SIRIUS library.
strong_type< int, struct __n_blocks_tag > n_blocks
Number of blocks to which the global index is split.
Contains definition of sddk::splindex_base and specializations of sddk::splindex class.
Type traits to handle SpFFT grids for different precision types.