32#ifndef __SMOOTH_PERIODIC_FUNCTION_HPP__
33#define __SMOOTH_PERIODIC_FUNCTION_HPP__
39check_smooth_periodic_function_ptr(smooth_periodic_function_ptr_t<T>
const& ptr__,
40 fft::spfft_transform_type<T>
const& spfft__)
42 if (spfft__.dim_x() != ptr__.size_x) {
44 s <<
"x-dimensions don't match" << std::endl
45 <<
" spfft__.dim_x() : " << spfft__.dim_x() << std::endl
46 <<
" ptr__.size_x : " << ptr__.size_x;
49 if (spfft__.dim_y() != ptr__.size_y) {
51 s <<
"y-dimensions don't match" << std::endl
52 <<
" spfft__.dim_y() : " << spfft__.dim_y() << std::endl
53 <<
" ptr__.size_y : " << ptr__.size_y;
56 if (ptr__.offset_z < 0) {
57 if (spfft__.dim_z() != ptr__.size_z) {
59 s <<
"global z-dimensions don't match" << std::endl
60 <<
" spfft__.dim_z() : " << spfft__.dim_z() << std::endl
61 <<
" ptr__.size_z : " << ptr__.size_z;
65 if ((spfft__.local_z_length() != ptr__.size_z) || (spfft__.local_z_offset() != ptr__.offset_z)) {
66 RTE_THROW(
"local z-dimensions don't match");
85 fft::spfft_transform_type<T>*
spfft_{
nullptr};
88 std::shared_ptr<fft::Gvec_fft>
gvecp_{
nullptr};
105 template <
typename F>
109 template <
typename F>
113 template <
typename F>
129 :
spfft_{const_cast<fft::spfft_transform_type<T>*>(&spfft__)}
132 auto& mp = sddk::get_memory_pool(sddk::memory_t::host);
135 check_smooth_periodic_function_ptr(*sptr__, spfft__);
138 RTE_THROW(
"Input pointer is null");
141 bool is_local_rg = (sptr__->offset_z >= 0);
143 int offs = (is_local_rg) ? 0 : spfft__.dim_x() * spfft__.dim_y() * spfft__.local_z_offset();
153 "Smooth_periodic_function.f_pw_local_");
155 if (
gvecp_->comm_ortho_fft().size() != 1) {
173 inline T
const& value(
int ir__)
const
178 inline T& value(
int ir__)
180 return const_cast<T&
>(
static_cast<Smooth_periodic_function<T> const&
>(*this).value(ir__));
183 inline auto values() -> sddk::mdarray<T, 1>&
188 inline auto values() const -> const sddk::mdarray<T,1>&
193 inline auto f_pw_local(
int ig__) -> std::complex<T>&
198 inline auto f_pw_local(
int ig__)
const ->
const std::complex<T>&
203 inline auto f_pw_local() -> sddk::mdarray<std::complex<T>, 1>&
208 inline auto f_pw_local() const -> const sddk::mdarray<std::complex<T>, 1>&
213 inline auto& f_pw_fft(
int ig__)
222 if (
gvecp_->gvec().comm().rank() == 0) {
225 gvecp_->gvec().comm().bcast(&z, 1, 0);
231 RTE_ASSERT(
spfft_ !=
nullptr);
235 auto const& spfft()
const
237 RTE_ASSERT(
spfft_ !=
nullptr);
243 RTE_ASSERT(
gvecp_ !=
nullptr);
247 auto gvec_fft()
const
252 void fft_transform(
int direction__)
254 PROFILE(
"sirius::Smooth_periodic_function::fft_transform");
256 RTE_ASSERT(
gvecp_ !=
nullptr);
258 auto frg_ptr = (
spfft_->local_slice_size() == 0) ?
nullptr : &
f_rg_[0];
260 switch (direction__) {
262 if (
gvecp_->comm_ortho_fft().size() != 1) {
265 spfft_->backward(
reinterpret_cast<real_type<T> const*
>(
f_pw_fft_.at(sddk::memory_t::host)), SPFFT_PU_HOST);
271 spfft_->forward(SPFFT_PU_HOST,
reinterpret_cast<real_type<T>*
>(
f_pw_fft_.at(sddk::memory_t::host)),
273 if (
gvecp_->comm_ortho_fft().size() != 1) {
274 int count =
gvecp_->gvec_slab().counts[
gvecp_->comm_ortho_fft().rank()];
275 int offset =
gvecp_->gvec_slab().offsets[
gvecp_->comm_ortho_fft().rank()];
277 count *
sizeof(std::complex<T>));
282 throw std::runtime_error(
"wrong FFT direction");
287 inline auto gather_f_pw()
const
289 PROFILE(
"sirius::Smooth_periodic_function::gather_f_pw");
291 std::vector<std::complex<T>> fpw(
gvecp_->gvec().num_gvec());
292 gvec().comm().allgather(&
f_pw_local_[0], fpw.data(), gvec().count(), gvec().offset());
297 inline void scatter_f_pw(std::vector<std::complex<T>>
const& f_pw__)
299 std::copy(&f_pw__[
gvecp_->gvec().offset()], &f_pw__[
gvecp_->gvec().offset()] +
gvecp_->gvec().count(),
303 Smooth_periodic_function<T>& operator+=(Smooth_periodic_function<T>
const& rhs__)
307 #pragma omp for schedule(static) nowait
308 for (
int irloc = 0; irloc < this->
spfft_->local_slice_size(); irloc++) {
309 this->
f_rg_(irloc) += rhs__.value(irloc);
311 #pragma omp for schedule(static) nowait
312 for (
int igloc = 0; igloc < this->
gvecp_->gvec().count(); igloc++) {
313 this->
f_pw_local_(igloc) += rhs__.f_pw_local(igloc);
319 Smooth_periodic_function<T>& operator*=(T alpha__)
323 #pragma omp for schedule(static) nowait
324 for (
int irloc = 0; irloc < this->
spfft_->local_slice_size(); irloc++) {
325 this->
f_rg_(irloc) *= alpha__;
327 #pragma omp for schedule(static) nowait
328 for (
int igloc = 0; igloc < this->
gvecp_->gvec().count(); igloc++) {
335 inline T checksum_rg()
const
338 mpi::Communicator(this->
spfft_->communicator()).allreduce(&cs, 1);
342 inline auto checksum_pw()
const
344 auto cs = this->f_pw_local_.
checksum();
345 this->
gvecp_->gvec().comm().allreduce(&cs, 1);
349 inline uint64_t hash_f_pw()
const
352 gvecp_->gvec().comm().bcast(&h, 1, 0);
354 for (
int r = 1; r <
gvecp_->gvec().comm().size(); r++) {
356 gvecp_->gvec().comm().bcast(&h, 1, r);
361 inline uint64_t hash_f_rg()
const
363 auto comm = mpi::Communicator(
spfft_->communicator());
366 for (
int r = 0; r < comm.size(); r++) {
372 comm.bcast(&h, 1, r);
384 fft::spfft_transform_type<T>*
spfft_{
nullptr};
387 std::shared_ptr<fft::Gvec_fft>
gvecp_{
nullptr};
402 for (
int x : {0, 1, 2}) {
403 (*this)[x] = Smooth_periodic_function<T>(spfft__, gvecp__);
407 Smooth_periodic_vector_function<T>& operator=(Smooth_periodic_vector_function<T>&& src__) =
default;
409 spfft::Transform& spfft()
const
411 RTE_ASSERT(
spfft_ !=
nullptr);
415 auto gvec_fft()
const
417 RTE_ASSERT(
gvecp_ !=
nullptr);
428 PROFILE(
"sirius::gradient");
432 #pragma omp parallel for schedule(static)
433 for (
int igloc = 0; igloc < f__.gvec().count(); igloc++) {
434 auto G = f__.gvec().template gvec_cart<index_domain_t::local>(igloc);
435 for (
int x : {0, 1, 2}) {
436 g[x].f_pw_local(igloc) = f__.f_pw_local(igloc) * std::complex<T>(0, G[x]);
447 PROFILE(
"sirius::divergence");
452 for (
int x : {0, 1, 2}) {
453 for (
int igloc = 0; igloc < f.gvec().count(); igloc++) {
454 auto G = f.gvec().template gvec_cart<index_domain_t::local>(igloc);
455 f.f_pw_local(igloc) += g__[x].f_pw_local(igloc) * std::complex<T>(0, G[x]);
466 PROFILE(
"sirius::laplacian");
470 #pragma omp parallel for schedule(static)
471 for (
int igloc = 0; igloc < f__.gvec().count(); igloc++) {
472 auto G = f__.gvec().template gvec_cart<index_domain_t::local>(igloc);
473 g.f_pw_local(igloc) = f__.f_pw_local(igloc) * std::complex<T>(-std::pow(G.length(), 2), 0);
480inline Smooth_periodic_function<T>
481dot(Smooth_periodic_vector_function<T>& vf__, Smooth_periodic_vector_function<T>& vg__)
484 PROFILE(
"sirius::dot");
486 Smooth_periodic_function<T> result(vf__.spfft(), vf__.gvec_fft());
488 #pragma omp parallel for schedule(static)
489 for (
int ir = 0; ir < vf__.spfft().local_slice_size(); ir++) {
491 for (
int x : {0, 1, 2}) {
492 d += vf__[x].value(ir) * vg__[x].value(ir);
494 result.value(ir) = d;
501template <
typename T,
typename F>
505 RTE_ASSERT(&f__.spfft() == &g__.spfft());
510 for (
int irloc = 0; irloc < f__.spfft().local_slice_size(); irloc++) {
511 result_rg +=
conj(f__.value(irloc)) * g__.value(irloc) * theta__(irloc);
521inner_local(Smooth_periodic_function<T>
const& f__, Smooth_periodic_function<T>
const& g__)
523 return inner_local(f__, g__, [](
int ir){
return 1;});
526template <
typename T,
typename F>
528inner(Smooth_periodic_function<T>
const& f__, Smooth_periodic_function<T>
const& g__, F&& theta__)
530 PROFILE(
"sirius::inner");
532 T result_rg =
inner_local(f__, g__, std::forward<F>(theta__));
533 mpi::Communicator(f__.spfft().communicator()).allreduce(&result_rg, 1);
543 return inner(f__, g__, [](
int ir){
return 1;});
551 auto& spfft = src__.spfft();
552 check_smooth_periodic_function_ptr(dest__, spfft);
555 RTE_THROW(
"Output pointer is null");
558 bool is_local_rg = (dest__.offset_z >= 0);
560 int offs = (is_local_rg) ? 0 : spfft.dim_x() * spfft.dim_y() * spfft.local_z_offset();
563 std::copy(src__.values().at(sddk::memory_t::host),
564 src__.values().at(sddk::memory_t::host) + spfft.local_slice_size(),
578 auto& spfft = dest__.spfft();
579 check_smooth_periodic_function_ptr(src__, spfft);
582 RTE_THROW(
"Input pointer is null");
585 bool is_local_rg = (src__.offset_z >= 0);
587 int offs = (is_local_rg) ? 0 : spfft.dim_x() * spfft.dim_y() * spfft.local_z_offset();
590 std::copy(src__.ptr + offs, src__.ptr + offs + spfft.local_slice_size(),
591 dest__.values().at(sddk::memory_t::host));
596copy(Smooth_periodic_function<T>
const& src__, Smooth_periodic_function<T>& dest__)
598 copy(src__.f_rg_, dest__.f_rg_);
599 copy(src__.f_pw_local_, dest__.f_pw_local_);
604scale(T alpha__, Smooth_periodic_function<T>& x__)
606 for (
size_t i = 0; i < x__.f_rg_.size(); i++) {
607 x__.f_rg_[i] *= alpha__;
609 for (
size_t i = 0; i < x__.f_pw_local_.size(); i++) {
610 x__.f_pw_local_[i] *= alpha__;
616axpy(T alpha__, Smooth_periodic_function<T>
const& x__, Smooth_periodic_function<T>& y__)
618 for (
size_t i = 0; i < x__.f_rg_.size(); i++) {
619 y__.f_rg_[i] += x__.f_rg_[i] * alpha__;
621 for (
size_t i = 0; i < x__.f_pw_local_.size(); i++) {
622 y__.f_pw_local_[i] += x__.f_pw_local_[i] * alpha__;
Representation of a smooth (Fourier-transformable) periodic function.
sddk::mdarray< std::complex< T >, 1 > f_pw_local_
Local set of plane-wave expansion coefficients.
fft::spfft_transform_type< T > * spfft_
FFT driver.
void gather_f_pw_fft()
Gather plane-wave coefficients for the subsequent FFT call.
sddk::mdarray< T, 1 > f_rg_
Function on the regular real-space grid.
void zero()
Zero the values on the regular real-space grid and plane-wave coefficients.
Smooth_periodic_function(fft::spfft_transform_type< T > const &spfft__, std::shared_ptr< fft::Gvec_fft > gvecp__, smooth_periodic_function_ptr_t< T > const *sptr__=nullptr)
Constructor.
Smooth_periodic_function()
Default constructor.
sddk::mdarray< std::complex< T >, 1 > f_pw_fft_
Storage of the PW coefficients for the FFT transformation.
std::shared_ptr< fft::Gvec_fft > gvecp_
Distribution of G-vectors.
auto f_0() const
Return plane-wave coefficient for G=0 component.
Vector of the smooth periodic functions.
fft::spfft_transform_type< T > * spfft_
FFT driver.
std::shared_ptr< fft::Gvec_fft > gvecp_
Distribution of G-vectors.
Smooth_periodic_vector_function()
Default constructor does nothing.
MPI communicator wrapper.
void allgather(T *buffer__, int const *recvcounts__, int const *displs__) const
In-place MPI_Allgatherv.
uint64_t hash(uint64_t h__=5381) const
Compute hash of the array.
void zero(memory_t mem__, size_t idx0__, size_t n__)
Zero n elements starting from idx0.
T checksum(size_t idx0__, size_t size__) const
Compute checksum.
Contains helper functions for the interface with SpFFT library.
Declaration and implementation of Gvec class.
Memory management functions and classes.
void spfft_output(spfft_transform_type< T > &spfft__, T *data__)
Output CPU data from the CPU buffer of SpFFT.
enable_return< F, T, int > spfft_input(spfft_transform_type< T > &spfft__, F &&fr__)
Load data from real-valued lambda.
size_t spfft_grid_size_local(T const &spfft__)
Local size of the SpFFT transformation grid.
size_t spfft_grid_size(T const &spfft__)
Total size of the SpFFT transformation grid.
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > inner(::spla::Context &spla_ctx__, sddk::memory_t mem__, spin_range spins__, W const &wf_i__, band_range br_i__, Wave_functions< T > const &wf_j__, band_range br_j__, la::dmatrix< F > &result__, int irow0__, int jcol0__)
Compute inner product between the two sets of wave-functions.
Namespace of the SIRIUS library.
Smooth_periodic_vector_function< T > gradient(Smooth_periodic_function< T > &f__)
Gradient of the function in the plane-wave domain.
T inner_local(Smooth_periodic_function< T > const &f__, Smooth_periodic_function< T > const &g__, F &&theta__)
Compute local contribution to inner product <f|g>
Smooth_periodic_function< T > laplacian(Smooth_periodic_function< T > &f__)
Laplacian of the function in the plane-wave domain.
Smooth_periodic_function< T > divergence(Smooth_periodic_vector_function< T > &g__)
Divergence of the vector function.
auto conj(double x__)
Return complex conjugate of a number. For a real value this is the number itself.
Contains typedefs, enums and simple descriptors.