SIRIUS 7.5.0
Electronic structure library and applications
wave_functions.hpp
// Copyright (c) 2013-2022 Anton Kozhevnikov, Thomas Schulthess
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
// the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
// following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
// and the following disclaimer in the documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

/** \file wave_functions.hpp
 *
 *  \brief Contains declaration and implementation of Wave_functions class.
 */

#ifndef __WAVE_FUNCTIONS_HPP__
#define __WAVE_FUNCTIONS_HPP__

#include <cstdlib>
#include <iostream>
#include <costa/layout.hpp>
#include <costa/grid2grid/transformer.hpp>
#include "core/la/linalg.hpp"
#include "core/strong_type.hpp"
#include "core/hdf5_tree.hpp"
#include "core/fft/gvec.hpp"
#include "core/env/env.hpp"
#include "core/rte/rte.hpp"
#include "core/time_tools.hpp"
#include "SDDK/memory.hpp"

namespace sirius {

#if defined(SIRIUS_GPU)
extern "C" {

void
add_square_sum_gpu_double(std::complex<double> const* wf__, int num_rows_loc__, int nwf__, int reduced__,
                          int mpi_rank__, double* result__);

void
add_square_sum_gpu_float(std::complex<float> const* wf__, int num_rows_loc__, int nwf__, int reduced__,
                         int mpi_rank__, float* result__);

void
scale_matrix_columns_gpu_double(int nrow__, int ncol__, std::complex<double>* mtrx__, double* a__);

void
scale_matrix_columns_gpu_float(int nrow__, int ncol__, std::complex<float>* mtrx__, float* a__);

void
add_checksum_gpu_double(void const* wf__, int ld__, int num_rows_loc__, int nwf__, void* result__);

void
add_checksum_gpu_float(void const* wf__, int ld__, int num_rows_loc__, int nwf__, void* result__);

void
inner_diag_local_gpu_double_complex_double(void const* wf1__, int ld1__, void const* wf2__, int ld2__, int ngv_loc__,
                                           int nwf__, void* result__);

void
inner_diag_local_gpu_double_double(void const* wf1__, int ld1__, void const* wf2__, int ld2__, int ngv_loc__,
                                   int nwf__, int reduced__, void* result__);

void
axpby_gpu_double_complex_double(int nwf__, void const* alpha__, void const* x__, int ld1__,
                                void const* beta__, void* y__, int ld2__, int ngv_loc__);

void
axpby_gpu_double_double(int nwf__, void const* alpha__, void const* x__, int ld1__,
                        void const* beta__, void* y__, int ld2__, int ngv_loc__);

void
axpy_scatter_gpu_double_complex_double(int nwf__, void const* alpha__, void const* x__, int ld1__,
                                       void const* idx__, void* y__, int ld2__, int ngv_loc__);

void
axpy_scatter_gpu_double_double(int nwf__, void const* alpha__, void const* x__, int ld1__,
                               void const* idx__, void* y__, int ld2__, int ngv_loc__);
}
#endif

/// Add checksum for the arrays on GPUs.
template <typename T>
auto checksum_gpu(std::complex<T> const* wf__, int ld__, int num_rows_loc__, int nwf__)
{
    std::complex<T> cs{0};
#if defined(SIRIUS_GPU)
    sddk::mdarray<std::complex<T>, 1> cs1(nwf__, sddk::memory_t::host, "checksum");
    cs1.allocate(sddk::memory_t::device).zero(sddk::memory_t::device);

    if (std::is_same<T, float>::value) {
        add_checksum_gpu_float(wf__, ld__, num_rows_loc__, nwf__, cs1.at(sddk::memory_t::device));
    } else if (std::is_same<T, double>::value) {
        add_checksum_gpu_double(wf__, ld__, num_rows_loc__, nwf__, cs1.at(sddk::memory_t::device));
    } else {
        std::stringstream s;
        s << "Precision type not yet implemented";
        RTE_THROW(s);
    }
    cs1.copy_to(sddk::memory_t::host);
    cs = cs1.checksum();
#endif
    return cs;
}

/// Namespace for the wave-functions.
namespace wf {

/* strong types used to disambiguate the integer arguments of the wave-function API
 * (reconstructed declarations; the extraction dropped the linked lines) */
using spin_index = strong_type<int, struct __spin_index_tag>;
//using atom_index = strong_type<int, struct __atom_index_tag>;
using band_index = strong_type<int, struct __band_index_tag>;

using num_bands = strong_type<int, struct __num_bands_tag>;
using num_spins = strong_type<int, struct __num_spins_tag>;
using num_mag_dims = strong_type<int, struct __num_mag_dims_tag>;

/// Describe a range of bands.
class band_range
{
  private:
    int begin_;
    int end_;

  public:
    band_range(int begin__, int end__)
        : begin_{begin__}
        , end_{end__}
    {
        RTE_ASSERT(begin_ >= 0);
        RTE_ASSERT(end_ >= 0);
        RTE_ASSERT(begin_ <= end_);
    }
    band_range(int size__)
        : begin_{0}
        , end_{size__}
    {
        RTE_ASSERT(size__ >= 0);
    }
    inline auto begin() const
    {
        return begin_;
    }
    inline auto end() const
    {
        return end_;
    }
    inline auto size() const
    {
        return end_ - begin_;
    }
};

/// Describe a range of spins.
/** Only 3 combinations of spin range are allowed:
    - [0, 1)
    - [1, 2)
    - [0, 2)
*/
class spin_range
{
  private:
    int begin_;
    int end_;
    int spinor_index_;

  public:
    spin_range(int begin__, int end__)
        : begin_{begin__}
        , end_{end__}
    {
        RTE_ASSERT(begin_ >= 0);
        RTE_ASSERT(end_ >= 0);
        RTE_ASSERT(begin_ <= end_);
        RTE_ASSERT(end_ <= 2);
        /* if size of the spin range is 2, this is a full-spinor case */
        if (this->size() == 2) {
            spinor_index_ = 0;
        } else {
            spinor_index_ = begin_;
        }
    }
    spin_range(int ispn__)
        : begin_{ispn__}
        , end_{ispn__ + 1}
        , spinor_index_{ispn__}
    {
        RTE_ASSERT(begin_ >= 0);
        RTE_ASSERT(end_ >= 0);
        RTE_ASSERT(begin_ <= end_);
        RTE_ASSERT(end_ <= 2);
    }
    inline auto begin() const
    {
        return spin_index(begin_);
    }
    inline auto end() const
    {
        return spin_index(end_);
    }
    inline int size() const
    {
        return end_ - begin_;
    }
    inline int spinor_index() const
    {
        return spinor_index_;
    }
};
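
// Usage sketch (illustrative only): the three allowed spin ranges and iteration over
// a range with the strong-typed spin_index.
//
//     wf::band_range br(0, 10);   // bands [0, 10)
//     wf::spin_range up(0);       // [0, 1) : spin-up component only
//     wf::spin_range dn(1);       // [1, 2) : spin-dn component only
//     wf::spin_range both(0, 2);  // [0, 2) : full spinor
//     for (auto s = both.begin(); s != both.end(); s++) {
//         /* s is a wf::spin_index */
//     }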

enum class copy_to : unsigned int
{
    none   = 0b0000,
    device = 0b0001,
    host   = 0b0010
};
inline copy_to operator|(copy_to a__, copy_to b__)
{
    return static_cast<copy_to>(static_cast<unsigned int>(a__) | static_cast<unsigned int>(b__));
}

/// Helper class to allocate and copy wave-functions to/from device.
class device_memory_guard
{
  private:
    void* obj_{nullptr};
    sddk::memory_t mem_{sddk::memory_t::host};
    copy_to copy_to_{wf::copy_to::none};
    std::function<void(void*, sddk::memory_t, wf::copy_to)> handler_;

    device_memory_guard(device_memory_guard const&) = delete;
    device_memory_guard& operator=(device_memory_guard const&) = delete;

  public:
    device_memory_guard()
    {
    }

    template <typename T>
    device_memory_guard(T const& obj__, sddk::memory_t mem__, copy_to copy_to__)
        : obj_{const_cast<T*>(&obj__)}
        , mem_{mem__}
        , copy_to_{copy_to__}
    {
        if (is_device_memory(mem_)) {
            auto obj = static_cast<T*>(obj_);
            obj->allocate(mem_);
            if (static_cast<unsigned int>(copy_to_) & static_cast<unsigned int>(copy_to::device)) {
                obj->copy_to(mem_);
            }
        }
        handler_ = [](void* p__, sddk::memory_t mem__, wf::copy_to copy_to__)
        {
            if (p__) {
                auto obj = static_cast<T*>(p__);
                if (is_device_memory(mem__)) {
                    if (static_cast<unsigned int>(copy_to__) & static_cast<unsigned int>(copy_to::host)) {
                        obj->copy_to(sddk::memory_t::host);
                    }
                    obj->deallocate(mem__);
                }
            }
        };
    }
    device_memory_guard(device_memory_guard&& src__)
    {
        this->obj_     = src__.obj_;
        src__.obj_     = nullptr;
        this->handler_ = src__.handler_;
        this->mem_     = src__.mem_;
        this->copy_to_ = src__.copy_to_;
    }
    device_memory_guard& operator=(device_memory_guard&& src__)
    {
        if (this != &src__) {
            this->obj_     = src__.obj_;
            src__.obj_     = nullptr;
            this->handler_ = src__.handler_;
            this->mem_     = src__.mem_;
            this->copy_to_ = src__.copy_to_;
        }
        return *this;
    }

    ~device_memory_guard()
    {
        /* the handler is empty for a default-constructed guard */
        if (handler_) {
            handler_(obj_, mem_, copy_to_);
        }
    }
};

/* forward declaration */
template <typename T>
class Wave_functions_fft;

/// Base class for the wave-functions.
/** Wave-functions are represented by a set of plane-wave and muffin-tin coefficients stored consecutively in a 2D array.
 *  The leading dimension of this array is the sum of the number of plane-waves and the number of muffin-tin coefficients.
 \verbatim

      band index
    +-----------+
    |           |
    |           |
 ig |  PW part  |
    |           |
    |           |
    +-----------+
    |           |
 xi |  MT part  |
    |           |
    +-----------+

 \endverbatim
 */
template <typename T>
class Wave_functions_base
{
  protected:
    /// Local number of plane-wave coefficients.
    int num_pw_{0};
    /// Local number of muffin-tin coefficients.
    int num_mt_{0};
    /// Number of magnetic dimensions (0, 1, or 3).
    /** This helps to distinguish between non-magnetic, collinear and full spinor wave-functions. */
    num_mag_dims num_md_{0};
    /// Total number of wave-functions.
    num_bands num_wf_{0};
    /// Number of spin components (1 or 2).
    num_spins num_sc_{0};
    /// Friend class declaration.
    /** Wave_functions_fft needs access to data to alias the pointers and avoid copy in trivial cases. */
    friend class Wave_functions_fft<T>;
    /// Data storage for the wave-functions.
    /** Wave-functions are stored as two independent arrays for spin-up and spin-dn. The plane-wave and muffin-tin
        coefficients are stored consecutively. */
    std::array<sddk::mdarray<std::complex<T>, 2>, 2> data_;

  public:
    /// Constructor.
    Wave_functions_base()
    {
    }
    /// Constructor.
    Wave_functions_base(int num_pw__, int num_mt__, num_mag_dims num_md__, num_bands num_wf__,
                        sddk::memory_t default_mem__)
        : num_pw_{num_pw__}
        , num_mt_{num_mt__}
        , num_md_{num_md__}
        , num_wf_{num_wf__}
    {
        if (!(num_md_.get() == 0 || num_md_.get() == 1 || num_md_.get() == 3)) {
            RTE_THROW("wrong number of magnetic dimensions");
        }

        if (num_md_.get() == 0) {
            num_sc_ = num_spins(1);
        } else {
            num_sc_ = num_spins(2);
        }
        for (int is = 0; is < num_sc_.get(); is++) {
            data_[is] = sddk::mdarray<std::complex<T>, 2>(num_pw_ + num_mt_, num_wf_.get(),
                                                          sddk::get_memory_pool(default_mem__), "Wave_functions_base::data_");
        }
    }

    /// Return an instance of the memory guard.
    /** When the instance is created, it allocates the GPU memory and optionally copies data to the GPU. When the
        instance is destroyed, the data is optionally copied to host and GPU memory is deallocated. */
    auto memory_guard(sddk::memory_t mem__, wf::copy_to copy_to__ = copy_to::none) const
    {
        return device_memory_guard(*this, mem__, copy_to__);
    }
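
    // Usage sketch (illustrative only; `wf` is an assumed, already constructed set of
    // wave-functions): keep the data on the GPU for the lifetime of the guard. The
    // copy_to flags control the data movement and can be combined with the operator|
    // defined above.
    //
    //     {
    //         auto mg = wf.memory_guard(sddk::memory_t::device, wf::copy_to::device | wf::copy_to::host);
    //         /* ... work with wf.at(sddk::memory_t::device, ...) here ... */
    //     } /* data is copied back to host and device memory is deallocated */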

    /// Return number of spin components.
    inline auto num_sc() const
    {
        return num_sc_;
    }

    /// Return number of magnetic dimensions.
    inline auto num_md() const
    {
        return num_md_;
    }

    /// Return number of wave-functions.
    inline auto num_wf() const
    {
        return num_wf_;
    }

    /// Return leading dimension of the wave-functions coefficients array.
    inline auto ld() const
    {
        return num_pw_ + num_mt_;
    }

    /// Return the actual spin index of the wave-functions.
    /** Return 0 if the wave-functions are non-magnetic, otherwise return the input spin index. */
    inline auto actual_spin_index(spin_index s__) const
    {
        if (num_sc_.get() == 2) {
            return s__;
        } else {
            return spin_index(0);
        }
    }

    /// Zero a spin component of the wave-functions in a band range.
    inline void
    zero(sddk::memory_t mem__, spin_index s__, band_range br__)
    {
        if (this->ld()) {
            if (is_host_memory(mem__)) {
                for (int ib = br__.begin(); ib < br__.end(); ib++) {
                    auto ptr = data_[s__.get()].at(mem__, 0, ib);
                    std::fill(ptr, ptr + this->ld(), 0);
                }
            }
            if (is_device_memory(mem__)) {
                acc::zero(data_[s__.get()].at(mem__, 0, br__.begin()), this->ld(), this->ld(), br__.size());
            }
        }
    }

    /// Zero all wave-functions.
    inline void
    zero(sddk::memory_t mem__)
    {
        if (this->ld()) {
            for (int is = 0; is < num_sc_.get(); is++) {
                data_[is].zero(mem__);
            }
        }
    }

    /// Return const pointer to the wave-function coefficient at a given index, spin and band.
    inline std::complex<T> const*
    at(sddk::memory_t mem__, int i__, spin_index s__, band_index b__) const
    {
        return data_[s__.get()].at(mem__, i__, b__.get());
    }

    /// Return pointer to the wave-function coefficient at a given index, spin and band.
    inline auto
    at(sddk::memory_t mem__, int i__, spin_index s__, band_index b__)
    {
        return data_[s__.get()].at(mem__, i__, b__.get());
    }

    /// Allocate wave-functions.
    /** This function is primarily called by a memory_guard to allocate GPU memory. */
    inline void
    allocate(sddk::memory_t mem__)
    {
        for (int s = 0; s < num_sc_.get(); s++) {
            data_[s].allocate(sddk::get_memory_pool(mem__));
        }
    }

    /// Deallocate wave-functions.
    /** This function is primarily called by a memory_guard to deallocate GPU memory. */
    inline void
    deallocate(sddk::memory_t mem__)
    {
        for (int s = 0; s < num_sc_.get(); s++) {
            data_[s].deallocate(mem__);
        }
    }

    /// Copy data to host or device.
    inline void
    copy_to(sddk::memory_t mem__)
    {
        for (int s = 0; s < num_sc_.get(); s++) {
            data_[s].copy_to(mem__);
        }
    }
};

/// Wave-functions for the muffin-tin part of LAPW.
template <typename T>
class Wave_functions_mt : public Wave_functions_base<T>
{
  protected:
    /// Communicator that is used to split atoms between MPI ranks.
    mpi::Communicator const& comm_{mpi::Communicator::self()};
    /// Total number of atoms.
    int num_atoms_{0};
    /// Distribution of atoms between MPI ranks.
    splindex_block<atom_index_t> spl_num_atoms_;
    /// Local size of muffin-tin coefficients for each rank.
    /** Each rank stores local fraction of atoms. Each atom has a set of MT coefficients. */
    mpi::block_data_descriptor mt_coeffs_distr_;
    /// Local offset in the block of MT coefficients for current rank.
    /** The size of the vector is equal to the local number of atoms for the current rank. */
    std::vector<int> offset_in_local_mt_coeffs_;
    /// Number of muffin-tin coefficients for each atom.
    std::vector<int> num_mt_coeffs_;

    /// Calculate the local number of muffin-tin coefficients.
    /** Compute the local fraction of atoms and then sum the muffin-tin coefficients for this fraction. */
    static int get_local_num_mt_coeffs(std::vector<int> num_mt_coeffs__, mpi::Communicator const& comm__)
    {
        int num_atoms = static_cast<int>(num_mt_coeffs__.size());
        splindex_block<atom_index_t> spl_atoms(num_atoms, n_blocks(comm__.size()), block_id(comm__.rank()));
        int result{0};
        for (auto it : spl_atoms) {
            result += num_mt_coeffs__[it.i];
        }
        return result;
    }

    /// Construct without muffin-tin part.
    Wave_functions_mt(mpi::Communicator const& comm__, num_mag_dims num_md__, num_bands num_wf__,
                      sddk::memory_t default_mem__, int num_pw__)
        : Wave_functions_base<T>(num_pw__, 0, num_md__, num_wf__, default_mem__)
        , comm_{comm__}
        , mt_coeffs_distr_{mpi::block_data_descriptor(comm__.size())}
    {
    }

  public:
    /// Constructor.
    Wave_functions_mt()
    {
    }

    /// Constructor.
    Wave_functions_mt(mpi::Communicator const& comm__, std::vector<int> num_mt_coeffs__, num_mag_dims num_md__,
                      num_bands num_wf__, sddk::memory_t default_mem__, int num_pw__ = 0)
        : Wave_functions_base<T>(num_pw__, get_local_num_mt_coeffs(num_mt_coeffs__, comm__), num_md__, num_wf__,
                                 default_mem__)
        , comm_{comm__}
        , num_atoms_{static_cast<int>(num_mt_coeffs__.size())}
        , spl_num_atoms_{splindex_block<atom_index_t>(num_atoms_, n_blocks(comm__.size()), block_id(comm__.rank()))}
        , num_mt_coeffs_{num_mt_coeffs__}
    {
        mt_coeffs_distr_ = mpi::block_data_descriptor(comm_.size());

        for (int ia = 0; ia < num_atoms_; ia++) {
            auto rank = spl_num_atoms_.location(atom_index_t::global(ia)).ib;
            if (rank == comm_.rank()) {
                offset_in_local_mt_coeffs_.push_back(mt_coeffs_distr_.counts[rank]);
            }
            /* increment local number of MT coeffs. for a given rank */
            mt_coeffs_distr_.counts[rank] += num_mt_coeffs__[ia];
        }
        mt_coeffs_distr_.calc_offsets();
    }

    /// Return reference to the coefficient by atomic orbital index, atom, spin and band indices.
    inline auto&
    mt_coeffs(int xi__, atom_index_t::local ia__, spin_index ispn__, band_index i__)
    {
        return this->data_[ispn__.get()](this->num_pw_ + xi__ + offset_in_local_mt_coeffs_[ia__.get()], i__.get());
    }

    /// Return const reference to the coefficient by atomic orbital index, atom, spin and band indices.
    inline auto const&
    mt_coeffs(int xi__, atom_index_t::local ia__, spin_index ispn__, band_index i__) const
    {
        return this->data_[ispn__.get()](this->num_pw_ + xi__ + offset_in_local_mt_coeffs_[ia__.get()], i__.get());
    }

    using Wave_functions_base<T>::at;

    /// Return const pointer to the coefficient by atomic orbital index, atom, spin and band indices.
    inline std::complex<T> const*
    at(sddk::memory_t mem__, int xi__, atom_index_t::local ia__, spin_index s__, band_index b__) const
    {
        return this->data_[s__.get()].at(mem__, this->num_pw_ + xi__ + offset_in_local_mt_coeffs_[ia__.get()], b__.get());
    }

    /// Return pointer to the coefficient by atomic orbital index, atom, spin and band indices.
    inline auto
    at(sddk::memory_t mem__, int xi__, atom_index_t::local ia__, spin_index s__, band_index b__)
    {
        return this->data_[s__.get()].at(mem__, this->num_pw_ + xi__ + offset_in_local_mt_coeffs_[ia__.get()], b__.get());
    }

    /// Return a split index for the number of atoms.
    inline auto const&
    spl_num_atoms() const
    {
        return spl_num_atoms_;
    }

    /// Copy muffin-tin coefficients to host or GPU memory.
    /** This functionality is required for the application of LAPW overlap operator to the wave-functions, which
     *  is always done on the CPU. */
    inline void
    copy_mt_to(sddk::memory_t mem__, spin_index s__, band_range br__)
    {
        if (this->ld() && this->num_mt_) {
            auto ptr     = this->data_[s__.get()].at(sddk::memory_t::host, this->num_pw_, br__.begin());
            auto ptr_gpu = this->data_[s__.get()].at(sddk::memory_t::device, this->num_pw_, br__.begin());
            if (is_device_memory(mem__)) {
                acc::copyin(ptr_gpu, this->ld(), ptr, this->ld(), this->num_mt_, br__.size());
            }
            if (is_host_memory(mem__)) {
                acc::copyout(ptr, this->ld(), ptr_gpu, this->ld(), this->num_mt_, br__.size());
            }
        }
    }

    /// Return COSTA layout for the muffin-tin part for a given spin index and band range.
    auto
    grid_layout_mt(spin_index ispn__, band_range b__)
    {
        std::vector<int> rowsplit(comm_.size() + 1);
        rowsplit[0] = 0;
        for (int i = 0; i < comm_.size(); i++) {
            rowsplit[i + 1] = rowsplit[i] + mt_coeffs_distr_.counts[i];
        }
        std::vector<int> colsplit({0, b__.size()});
        std::vector<int> owners(comm_.size());
        for (int i = 0; i < comm_.size(); i++) {
            owners[i] = i;
        }
        costa::block_t localblock;
        localblock.data = this->num_mt_ ?
            this->data_[ispn__.get()].at(sddk::memory_t::host, this->num_pw_, b__.begin()) : nullptr;
        localblock.ld  = this->ld();
        localblock.row = comm_.rank();
        localblock.col = 0;

        return costa::custom_layout<std::complex<T>>(comm_.size(), 1, rowsplit.data(), colsplit.data(),
                                                     owners.data(), 1, &localblock, 'C');
    }

    /// Compute checksum of the muffin-tin coefficients.
    inline auto
    checksum_mt(sddk::memory_t mem__, spin_index s__, band_range br__) const
    {
        std::complex<T> cs{0};
        if (this->num_mt_ && br__.size()) {
            if (is_host_memory(mem__)) {
                for (int ib = br__.begin(); ib < br__.end(); ib++) {
                    auto ptr = this->data_[s__.get()].at(mem__, this->num_pw_, ib);
                    cs = std::accumulate(ptr, ptr + this->num_mt_, cs);
                }
            }
            if (is_device_memory(mem__)) {
                auto ptr = this->data_[s__.get()].at(mem__, this->num_pw_, br__.begin());
                cs = checksum_gpu<T>(ptr, this->ld(), this->num_mt_, br__.size());
            }
        }
        comm_.allreduce(&cs, 1);
        return cs;
    }

    /// Return vector of muffin-tin coefficients for all atoms.
    auto
    num_mt_coeffs() const
    {
        return num_mt_coeffs_;
    }

    /// Return const reference to the communicator.
    auto const&
    comm() const
    {
        return comm_;
    }

    /// Return const reference to the distribution of the muffin-tin coefficients.
    auto const&
    mt_coeffs_distr() const
    {
        return mt_coeffs_distr_;
    }
};

/// Wave-functions representation.
/** Wave-functions consist of two parts: plane-wave part and muffin-tin part. Wave-functions have one or two spin
 *  components. In case of collinear magnetism each component represents a pure (up- or dn-) spinor state and they
 *  are independent. In the non-collinear case the two components represent a full spinor state.
 *
 *  \tparam T Precision type of the wave-functions (double or float).
 */
template <typename T>
class Wave_functions : public Wave_functions_mt<T>
{
  private:
    /// Pointer to G+k- vectors object.
    std::shared_ptr<fft::Gvec> gkvec_;

  public:
    /// Constructor for pure plane-wave functions.
    Wave_functions(std::shared_ptr<fft::Gvec> gkvec__, num_mag_dims num_md__, num_bands num_wf__,
                   sddk::memory_t default_mem__)
        : Wave_functions_mt<T>(gkvec__->comm(), num_md__, num_wf__, default_mem__, gkvec__->count())
        , gkvec_{gkvec__}
    {
    }

    /// Constructor for wave-functions with plane-wave and muffin-tin parts (LAPW case).
    Wave_functions(std::shared_ptr<fft::Gvec> gkvec__, std::vector<int> num_mt_coeffs__, num_mag_dims num_md__,
                   num_bands num_wf__, sddk::memory_t default_mem__)
        : Wave_functions_mt<T>(gkvec__->comm(), num_mt_coeffs__, num_md__, num_wf__, default_mem__, gkvec__->count())
        , gkvec_{gkvec__}
    {
    }

    /// Return reference to the plane-wave coefficient for a given plane-wave, spin and band indices.
    inline auto&
    pw_coeffs(int ig__, spin_index ispn__, band_index i__)
    {
        return this->data_[ispn__.get()](ig__, i__.get());
    }

    inline auto& pw_coeffs(spin_index ispn__)
    {
        return this->data_[ispn__.get()];
    }

    inline auto const& pw_coeffs(spin_index ispn__) const
    {
        return this->data_[ispn__.get()];
    }

    /// Return COSTA layout for the plane-wave part for a given spin index and band range.
    auto
    grid_layout_pw(spin_index ispn__, band_range b__) const
    {
        PROFILE("sirius::wf::Wave_functions::grid_layout_pw");

        std::vector<int> rowsplit(this->comm_.size() + 1);
        rowsplit[0] = 0;
        for (int i = 0; i < this->comm_.size(); i++) {
            rowsplit[i + 1] = rowsplit[i] + gkvec_->gvec_count(i);
        }
        std::vector<int> colsplit({0, b__.size()});
        std::vector<int> owners(this->comm_.size());
        for (int i = 0; i < this->comm_.size(); i++) {
            owners[i] = i;
        }
        costa::block_t localblock;
        localblock.data = const_cast<std::complex<T>*>(this->data_[ispn__.get()].at(sddk::memory_t::host, 0, b__.begin()));
        localblock.ld  = this->ld();
        localblock.row = this->comm_.rank();
        localblock.col = 0;

        return costa::custom_layout<std::complex<T>>(this->comm_.size(), 1, rowsplit.data(), colsplit.data(),
                                                     owners.data(), 1, &localblock, 'C');
    }

    auto const& gkvec() const
    {
        RTE_ASSERT(gkvec_ != nullptr);
        return *gkvec_;
    }

    auto gkvec_sptr() const
    {
        return gkvec_;
    }

    inline auto checksum_pw(sddk::memory_t mem__, spin_index s__, band_range b__) const
    {
        std::complex<T> cs{0};
        if (b__.size()) {
            if (is_host_memory(mem__)) {
                for (int ib = b__.begin(); ib < b__.end(); ib++) {
                    auto ptr = this->data_[s__.get()].at(mem__, 0, ib);
                    cs = std::accumulate(ptr, ptr + this->num_pw_, cs);
                }
            }
            if (is_device_memory(mem__)) {
                auto ptr = this->data_[s__.get()].at(mem__, 0, b__.begin());
                cs = checksum_gpu<T>(ptr, this->ld(), this->num_pw_, b__.size());
            }
            this->comm_.allreduce(&cs, 1);
        }
        return cs;
    }

    inline auto checksum(sddk::memory_t mem__, spin_index s__, band_range b__) const
    {
        return this->checksum_pw(mem__, s__, b__) + this->checksum_mt(mem__, s__, b__);
    }

    inline auto checksum(sddk::memory_t mem__, band_range b__) const
    {
        std::complex<T> cs{0};
        for (int is = 0; is < this->num_sc().get(); is++) {
            cs += this->checksum(mem__, wf::spin_index(is), b__);
        }
        return cs;
    }
};
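
// Construction sketch (illustrative only; `gkvec` is an assumed, already constructed
// std::shared_ptr<fft::Gvec>):
//
//     /* 10 non-magnetic pure plane-wave functions in host memory */
//     wf::Wave_functions<double> phi(gkvec, wf::num_mag_dims(0), wf::num_bands(10), sddk::memory_t::host);
//     phi.zero(sddk::memory_t::host);
//     /* set the G=0 coefficient of the first band */
//     phi.pw_coeffs(0, wf::spin_index(0), wf::band_index(0)) = 1.0;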

struct shuffle_to
{
    /// Do nothing.
    static const unsigned int none = 0b0000;
    /// Shuffle to FFT distribution.
    static const unsigned int fft_layout = 0b0001;
    /// Shuffle back to the default slab distribution.
    static const unsigned int wf_layout = 0b0010;
};

/// Wave-functions in the FFT-friendly distribution.
/** To reduce the FFT MPI communication volume, it is often beneficial to redistribute wave-functions from
 *  a default slab layout to a FFT-friendly layout. Often this is a full swap from G-vector to band distribution.
 *  In general this is a redistribution of data from [N x 1] to [M x K] MPI grids.
 \verbatim
                  band index                      band index                     band index
               ┌──────────────┐               ┌───────┬──────┐               ┌───┬───┬───┬──┐
               │              │               │       │      │               │   │   │   │  │
               │              │               │       │      │               │   │   │   │  │
               ├──────────────┤               │       │      │               │   │   │   │  │
               │              │               │       │      │               │   │   │   │  │
               │              │    partial    │       │      │     full      │   │   │   │  │
               │              │     swap      │       │      │     swap      │   │   │   │  │
 G+k index     ├──────────────┤      ->       ├───────┼──────┤      ->       ├───┼───┼───┼──┤
 (distributed) │              │               │       │      │               │   │   │   │  │
               │              │               │       │      │               │   │   │   │  │
               │              │               │       │      │               │   │   │   │  │
               ├──────────────┤               │       │      │               │   │   │   │  │
               │              │               │       │      │               │   │   │   │  │
               │              │               │       │      │               │   │   │   │  │
               └──────────────┘               └───────┴──────┘               └───┴───┴───┴──┘

 \endverbatim

 *  Wave-functions in FFT distribution are scalar with only one spin component.
 *
 *  \tparam T Precision type of the wave-functions (double or float).
 */
template <typename T>
class Wave_functions_fft : public Wave_functions_base<T>
{
  private:
    /// Pointer to the FFT-friendly G+k vector distribution.
    std::shared_ptr<fft::Gvec_fft> gkvec_fft_;
    /// Split number of wave-functions between column communicator.
    splindex_block<> spl_num_wf_;
    /// Pointer to the original wave-functions.
    Wave_functions<T>* wf_{nullptr};
    /// Spin-index of the wave-function component.
    spin_index s_{0};
    /// Range of bands in the input wave-functions to be swapped.
    band_range br_{0};
    /// Direction of the reshuffling: to FFT layout or back to WF layout or both.
    unsigned int shuffle_flag_{0};
    /// True if the FFT wave-functions are also available on the device.
    bool on_device_{false};

    /// Return COSTA grid layout description.
    auto grid_layout(int n__)
    {
        PROFILE("sirius::wf::Wave_functions_fft::grid_layout");

        auto& comm_row = gkvec_fft_->comm_fft();
        auto& comm_col = gkvec_fft_->comm_ortho_fft();

        std::vector<int> rowsplit(comm_row.size() + 1);
        rowsplit[0] = 0;
        for (int i = 0; i < comm_row.size(); i++) {
            rowsplit[i + 1] = rowsplit[i] + gkvec_fft_->count(i);
        }

        std::vector<int> colsplit(comm_col.size() + 1);
        colsplit[0] = 0;
        for (int i = 0; i < comm_col.size(); i++) {
            colsplit[i + 1] = colsplit[i] + spl_num_wf_.local_size(block_id(i));
        }

        std::vector<int> owners(gkvec_fft_->gvec().comm().size());
        for (int i = 0; i < gkvec_fft_->gvec().comm().size(); i++) {
            owners[i] = i;
        }
        costa::block_t localblock;
        localblock.data = this->data_[0].at(sddk::memory_t::host);
        localblock.ld  = this->ld();
        localblock.row = gkvec_fft_->comm_fft().rank();
        localblock.col = comm_col.rank();

        return costa::custom_layout<std::complex<T>>(comm_row.size(), comm_col.size(), rowsplit.data(),
                                                     colsplit.data(), owners.data(), 1, &localblock, 'C');
    }

    /// Shuffle wave-function to the FFT distribution.
    void shuffle_to_fft_layout(spin_index ispn__, band_range b__)
    {
        PROFILE("shuffle_to_fft_layout");

        auto sp = wf_->actual_spin_index(ispn__);
        auto t0 = ::sirius::time_now();
        if (false) {
            auto layout_in  = wf_->grid_layout_pw(sp, b__);
            auto layout_out = this->grid_layout(b__.size());

            costa::transform(layout_in, layout_out, 'N', la::constant<std::complex<T>>::one(),
                             la::constant<std::complex<T>>::zero(), wf_->gkvec().comm().native());
        } else {
            /*
             * old implementation (to be removed when performance of COSTA is understood)
             */
            auto& comm_col = gkvec_fft_->comm_ortho_fft();

            /* in the full-potential case the leading dimension is larger than the number of plane-wave
             * coefficients, so we have to copy data into temporary storage with the necessary leading
             * dimension */
            sddk::mdarray<std::complex<T>, 2> wf_tmp;
            if (wf_->ld() == wf_->num_pw_) { /* pure plane-wave coeffs */
                auto ptr = (wf_->num_pw_ == 0) ? nullptr : wf_->data_[sp.get()].at(sddk::memory_t::host, 0, b__.begin());
                wf_tmp = sddk::mdarray<std::complex<T>, 2>(ptr, wf_->num_pw_, b__.size());
            } else {
                wf_tmp = sddk::mdarray<std::complex<T>, 2>(wf_->num_pw_, b__.size(),
                                                           sddk::get_memory_pool(sddk::memory_t::host));
                for (int i = 0; i < b__.size(); i++) {
                    auto in_ptr = wf_->data_[sp.get()].at(sddk::memory_t::host, 0, b__.begin() + i);
                    std::copy(in_ptr, in_ptr + wf_->num_pw_, wf_tmp.at(sddk::memory_t::host, 0, i));
                }
            }

            auto* send_buf = (wf_tmp.ld() == 0) ? nullptr : wf_tmp.at(sddk::memory_t::host);

            /* local number of columns */
            int n_loc = spl_num_wf_.local_size();

            sddk::mdarray<std::complex<T>, 1> recv_buf(gkvec_fft_->count() * n_loc,
                                                       sddk::get_memory_pool(sddk::memory_t::host), "recv_buf");

            auto& row_distr = gkvec_fft_->gvec_slab();

            /* send and receive dimensions */
            mpi::block_data_descriptor sd(comm_col.size()), rd(comm_col.size());
            for (int j = 0; j < comm_col.size(); j++) {
                sd.counts[j] = spl_num_wf_.local_size(block_id(j)) * row_distr.counts[comm_col.rank()];
                rd.counts[j] = spl_num_wf_.local_size(block_id(comm_col.rank())) * row_distr.counts[j];
            }
            sd.calc_offsets();
            rd.calc_offsets();

            comm_col.alltoall(send_buf, sd.counts.data(), sd.offsets.data(), recv_buf.at(sddk::memory_t::host),
                              rd.counts.data(), rd.offsets.data());

            /* reorder received blocks */
            #pragma omp parallel for
            for (int i = 0; i < n_loc; i++) {
                for (int j = 0; j < comm_col.size(); j++) {
                    int offset = row_distr.offsets[j];
                    int count  = row_distr.counts[j];
                    if (count) {
                        auto from = &recv_buf[offset * n_loc + count * i];
                        std::copy(from, from + count, this->data_[0].at(sddk::memory_t::host, offset, i));
                    }
                }
            }
        }

        if (env::print_performance() && wf_->gkvec().comm().rank() == 0) {
            auto t = ::sirius::time_interval(t0);
            std::cout << "[transform_to_fft_layout] throughput: "
                      << 2 * sizeof(T) * wf_->gkvec().num_gvec() * b__.size() / std::pow(2.0, 30) / t
                      << " Gb/sec" << std::endl;
        }
    }

    /// Shuffle wave-function to the original slab layout.
    void shuffle_to_wf_layout(spin_index ispn__, band_range b__)
    {
        PROFILE("shuffle_to_wf_layout");

        auto sp = wf_->actual_spin_index(ispn__);
        auto pp = env::print_performance();

        auto t0 = ::sirius::time_now();
        if (false) {
            auto layout_in  = this->grid_layout(b__.size());
            auto layout_out = wf_->grid_layout_pw(sp, b__);

            costa::transform(layout_in, layout_out, 'N', la::constant<std::complex<T>>::one(),
                             la::constant<std::complex<T>>::zero(), wf_->gkvec().comm().native());
        } else {

            auto& comm_col = gkvec_fft_->comm_ortho_fft();

            /* local number of columns */
            int n_loc = spl_num_wf_.local_size();

            /* send buffer */
            sddk::mdarray<std::complex<T>, 1> send_buf(gkvec_fft_->count() * n_loc,
                                                       sddk::get_memory_pool(sddk::memory_t::host), "send_buf");

            auto& row_distr = gkvec_fft_->gvec_slab();

            /* reorder sending blocks */
            #pragma omp parallel for
            for (int i = 0; i < n_loc; i++) {
                for (int j = 0; j < comm_col.size(); j++) {
                    int offset = row_distr.offsets[j];
                    int count  = row_distr.counts[j];
                    if (count) {
                        auto from = this->data_[0].at(sddk::memory_t::host, offset, i);
                        std::copy(from, from + count, &send_buf[offset * n_loc + count * i]);
                    }
                }
            }
            /* send and receive dimensions */
            mpi::block_data_descriptor sd(comm_col.size()), rd(comm_col.size());
            for (int j = 0; j < comm_col.size(); j++) {
                sd.counts[j] = spl_num_wf_.local_size(block_id(comm_col.rank())) * row_distr.counts[j];
                rd.counts[j] = spl_num_wf_.local_size(block_id(j)) * row_distr.counts[comm_col.rank()];
            }
            sd.calc_offsets();
            rd.calc_offsets();

#if !defined(NDEBUG)
            for (int i = 0; i < n_loc; i++) {
                for (int j = 0; j < comm_col.size(); j++) {
                    int offset = row_distr.offsets[j];
                    int count  = row_distr.counts[j];
                    for (int igg = 0; igg < count; igg++) {
                        if (send_buf[offset * n_loc + count * i + igg] != this->data_[0](offset + igg, i)) {
                            RTE_THROW("wrong packing of send buffer");
                        }
                    }
                }
            }
#endif
            /* full-potential wave-functions have an extra muffin-tin part;
             * that makes the PW block of data not consecutive and thus we need to copy to a consecutive buffer
             * for alltoall */
            sddk::mdarray<std::complex<T>, 2> wf_tmp;
            if (wf_->ld() == wf_->num_pw_) { /* pure plane-wave coeffs */
                auto ptr = (wf_->num_pw_ == 0) ? nullptr : wf_->data_[sp.get()].at(sddk::memory_t::host, 0, b__.begin());
                wf_tmp = sddk::mdarray<std::complex<T>, 2>(ptr, wf_->num_pw_, b__.size());
            } else {
                wf_tmp = sddk::mdarray<std::complex<T>, 2>(wf_->num_pw_, b__.size(),
                                                           sddk::get_memory_pool(sddk::memory_t::host));
            }

            auto* recv_buf = (wf_tmp.ld() == 0) ? nullptr : wf_tmp.at(sddk::memory_t::host);

            comm_col.alltoall(send_buf.at(sddk::memory_t::host), sd.counts.data(), sd.offsets.data(), recv_buf,
                              rd.counts.data(), rd.offsets.data());

            if (wf_->ld() != wf_->num_pw_) {
                for (int i = 0; i < b__.size(); i++) {
                    auto out_ptr = wf_->data_[sp.get()].at(sddk::memory_t::host, 0, b__.begin() + i);
                    std::copy(wf_tmp.at(sddk::memory_t::host, 0, i),
                              wf_tmp.at(sddk::memory_t::host, 0, i) + wf_->num_pw_, out_ptr);
                }
            }
        }
        if (pp && wf_->gkvec().comm().rank() == 0) {
            auto t = ::sirius::time_interval(t0);
            std::cout << "[transform_from_fft_layout] throughput: "
                      << 2 * sizeof(T) * wf_->gkvec().num_gvec() * b__.size() / std::pow(2.0, 30) / t
                      << " Gb/sec" << std::endl;
        }
    }

  public:
    /// Constructor.
    Wave_functions_fft()
    {
    }

    /// Constructor.
    Wave_functions_fft(std::shared_ptr<fft::Gvec_fft> gkvec_fft__, Wave_functions<T>& wf__, spin_index s__,
                       band_range br__, unsigned int shuffle_flag___)
        : gkvec_fft_{gkvec_fft__}
        , wf_{&wf__}
        , s_{s__}
        , br_{br__}
        , shuffle_flag_{shuffle_flag___}
    {
        auto& comm_col = gkvec_fft_->comm_ortho_fft();
        spl_num_wf_    = splindex_block<>(br__.size(), n_blocks(comm_col.size()), block_id(comm_col.rank()));
        this->num_mt_  = 0;
        this->num_md_  = wf::num_mag_dims(0);
        this->num_sc_  = wf::num_spins(1);
        this->num_wf_  = wf::num_bands(spl_num_wf_.local_size());

        auto sp = wf_->actual_spin_index(s__);

        /* special case when wave-functions are not redistributed */
        if (comm_col.size() == 1) {
            auto i       = wf::band_index(br__.begin());
            auto ptr     = wf__.at(sddk::memory_t::host, 0, sp, i);
            auto ptr_gpu = wf__.data_[sp.get()].on_device() ? wf__.at(sddk::memory_t::device, 0, sp, i) : nullptr;
            if (ptr_gpu) {
                on_device_ = true;
            }
            /* make alias to the fraction of the wave-functions */
            this->data_[0] = sddk::mdarray<std::complex<T>, 2>(ptr, ptr_gpu, wf__.ld(), this->num_wf_.get());
            this->num_pw_  = wf_->num_pw_;
        } else {
            /* do wave-functions swap */
            this->data_[0] = sddk::mdarray<std::complex<T>, 2>(gkvec_fft__->count(), this->num_wf_.get(),
                                                               sddk::get_memory_pool(sddk::memory_t::host), "Wave_functions_fft.data");
            this->num_pw_  = gkvec_fft__->count();

            if (shuffle_flag_ & shuffle_to::fft_layout) {
                if (wf__.data_[sp.get()].on_device()) {
                    /* copy block of wave-functions to host memory before calling COSTA */
                    auto ptr     = wf__.at(sddk::memory_t::host, 0, sp, wf::band_index(br__.begin()));
                    auto ptr_gpu = wf__.at(sddk::memory_t::device, 0, sp, wf::band_index(br__.begin()));
                    acc::copyout(ptr, wf__.ld(), ptr_gpu, wf__.ld(), wf__.num_pw_, br__.size());
                }
                shuffle_to_fft_layout(s__, br__);
            }
        }
    }

    /// Move assignment operator.
    Wave_functions_fft& operator=(Wave_functions_fft&& src__)
    {
        if (this != &src__) {
            gkvec_fft_    = src__.gkvec_fft_;
            spl_num_wf_   = src__.spl_num_wf_;
            wf_           = src__.wf_;
            src__.wf_     = nullptr;
            s_            = src__.s_;
            br_           = src__.br_;
            shuffle_flag_ = src__.shuffle_flag_;
            on_device_    = src__.on_device_;
            this->num_pw_ = src__.num_pw_;
            this->num_mt_ = src__.num_mt_;
            this->num_md_ = src__.num_md_;
            this->num_wf_ = src__.num_wf_;
            this->num_sc_ = src__.num_sc_;
            for (int is = 0; is < this->num_sc_.get(); is++) {
                this->data_[is] = std::move(src__.data_[is]);
            }
        }
        return *this;
    }

    /// Destructor.
    ~Wave_functions_fft()
    {
        if (wf_) {
            auto& comm_col = gkvec_fft_->comm_ortho_fft();
            if ((comm_col.size() != 1) && (shuffle_flag_ & shuffle_to::wf_layout)) {
                shuffle_to_wf_layout(s_, br_);
                auto sp = wf_->actual_spin_index(s_);
                if (wf_->data_[sp.get()].on_device()) {
                    /* copy block of wave-functions to device memory after calling COSTA */
                    auto ptr     = wf_->at(sddk::memory_t::host, 0, sp, wf::band_index(br_.begin()));
                    auto ptr_gpu = wf_->at(sddk::memory_t::device, 0, sp, wf::band_index(br_.begin()));
                    acc::copyin(ptr_gpu, wf_->ld(), ptr, wf_->ld(), wf_->num_pw_, br_.size());
                }
            }
        }
    }

    /// Return local number of wave-functions.
    /** The wave-function band index is distributed over the columns of the MPI grid. Each group of FFT
     *  communicators works on its local set of wave-functions. */
    int num_wf_local() const
    {
        return spl_num_wf_.local_size();
    }

    /// Return the split index for the number of wave-functions.
    auto spl_num_wf() const
    {
        return spl_num_wf_;
    }

    /// Return reference to the plane-wave coefficient.
    inline std::complex<T>& pw_coeffs(int ig__, band_index b__)
    {
        return this->data_[0](ig__, b__.get());
    }

    /// Return pointer to the beginning of wave-functions cast to real type as required by the SpFFT library.
    T* pw_coeffs_spfft(sddk::memory_t mem__, band_index b__)
    {
        return reinterpret_cast<T*>(this->data_[0].at(mem__, 0, b__.get()));
    }

    /// Return true if data is available on the device memory.
    inline auto on_device() const
    {
        return on_device_;
    }

    /// Return const pointer to the data for a given plane-wave and band indices.
    inline std::complex<T> const*
    at(sddk::memory_t mem__, int i__, band_index b__) const
    {
        return this->data_[0].at(mem__, i__, b__.get());
    }

    /// Return pointer to the data for a given plane-wave and band indices.
    inline auto
    at(sddk::memory_t mem__, int i__, band_index b__)
    {
        return this->data_[0].at(mem__, i__, b__.get());
    }
};
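
// Usage sketch (illustrative only; `gkvec_fft` is an assumed std::shared_ptr<fft::Gvec_fft>
// consistent with phi.gkvec()): shuffle a band range to the FFT layout and back.
//
//     {
//         wf::Wave_functions_fft<double> phi_fft(gkvec_fft, phi, wf::spin_index(0), wf::band_range(0, 10),
//                                                wf::shuffle_to::fft_layout | wf::shuffle_to::wf_layout);
//         for (int i = 0; i < phi_fft.num_wf_local(); i++) {
//             /* pass phi_fft.pw_coeffs_spfft(sddk::memory_t::host, wf::band_index(i)) to the FFT driver */
//         }
//     } /* destructor shuffles the data back to the slab layout of phi */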

/// For real-type F (double or float).
template <typename T, typename F>
static inline std::enable_if_t<std::is_scalar<F>::value, F>
inner_diag_local_aux(std::complex<T> z1__, std::complex<T> z2__)
{
    return z1__.real() * z2__.real() + z1__.imag() * z2__.imag();
}

/// For complex-type F (complex<double> or complex<float>).
template <typename T, typename F>
static inline std::enable_if_t<!std::is_scalar<F>::value, F>
inner_diag_local_aux(std::complex<T> z1__, std::complex<T> z2__)
{
    return std::conj(z1__) * z2__;
}

template <typename T, typename F>
static auto
inner_diag_local(sddk::memory_t mem__, wf::Wave_functions<T> const& lhs__, wf::Wave_functions_base<T> const& rhs__,
                 wf::spin_range spins__, wf::num_bands num_wf__)
{
    RTE_ASSERT(lhs__.ld() == rhs__.ld());
    if (spins__.size() == 2) {
        if (lhs__.num_md() != wf::num_mag_dims(3)) {
            RTE_THROW("Wave-functions are not spinors");
        }
        if (rhs__.num_md() != wf::num_mag_dims(3)) {
            RTE_THROW("Wave-functions are not spinors");
        }
    }

    std::vector<F> result(num_wf__.get(), 0);

    if (is_host_memory(mem__)) {
        for (auto s = spins__.begin(); s != spins__.end(); s++) {
            auto s1 = lhs__.actual_spin_index(s);
            auto s2 = rhs__.actual_spin_index(s);
            for (int i = 0; i < num_wf__.get(); i++) {
                auto ptr1 = lhs__.at(mem__, 0, s1, wf::band_index(i));
                auto ptr2 = rhs__.at(mem__, 0, s2, wf::band_index(i));
                for (int j = 0; j < lhs__.ld(); j++) {
                    result[i] += inner_diag_local_aux<T, F>(ptr1[j], ptr2[j]);
                }
                /* gamma-point case */
                if (std::is_same<F, real_type<F>>::value) {
                    if (lhs__.comm().rank() == 0) {
                        result[i] = F(2.0) * result[i] - F(std::real(std::conj(ptr1[0]) * ptr2[0]));
                    } else {
                        result[i] *= F(2.0);
                    }
                }
            }
        }
    } else {
#if defined(SIRIUS_GPU)
        int reduced{0};
        /* gamma-point case */
        if (std::is_same<F, real_type<F>>::value) {
            reduced = lhs__.comm().rank() + 1;
        }
        sddk::mdarray<F, 1> result_gpu(num_wf__.get());
        result_gpu.allocate(mem__).zero(mem__);

        for (auto s = spins__.begin(); s != spins__.end(); s++) {
            auto s1   = lhs__.actual_spin_index(s);
            auto s2   = rhs__.actual_spin_index(s);
            auto ptr1 = lhs__.at(mem__, 0, s1, wf::band_index(0));
            auto ptr2 = rhs__.at(mem__, 0, s2, wf::band_index(0));
            if (std::is_same<T, double>::value) {
                if (std::is_same<F, double>::value) {
                    inner_diag_local_gpu_double_double(ptr1, lhs__.ld(), ptr2, rhs__.ld(), lhs__.ld(), num_wf__.get(),
                                                       reduced, result_gpu.at(mem__));
                }
                if (std::is_same<F, std::complex<double>>::value) {
                    inner_diag_local_gpu_double_complex_double(ptr1, lhs__.ld(), ptr2, rhs__.ld(), lhs__.ld(),
                                                               num_wf__.get(), result_gpu.at(mem__));
                }
            }
        }
        result_gpu.copy_to(sddk::memory_t::host);
        for (int i = 0; i < num_wf__.get(); i++) {
            result[i] = result_gpu[i];
        }
#endif
    }
    return result;
}

template <typename T, typename F>
auto
inner_diag(sddk::memory_t mem__, wf::Wave_functions<T> const& lhs__, wf::Wave_functions_base<T> const& rhs__,
           wf::spin_range spins__, wf::num_bands num_wf__)
{
    PROFILE("wf::inner_diag");
    auto result = inner_diag_local<T, F>(mem__, lhs__, rhs__, spins__, num_wf__);
    lhs__.comm().allreduce(result);
    return result;
}
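
// Usage sketch (illustrative only): diagonal of the overlap matrix <phi_i|phi_i> for
// 10 bands of a non-magnetic set; F = std::complex<double> for a general k-point.
//
//     auto d = wf::inner_diag<double, std::complex<double>>(sddk::memory_t::host, phi, phi,
//                                                           wf::spin_range(0), wf::num_bands(10));
//     /* d[i] = <phi_i|phi_i>, reduced over all MPI ranks */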

/// For real-type F (double or float).
template <typename T, typename F>
static inline std::enable_if_t<std::is_scalar<F>::value, std::complex<T>>
axpby_aux(F a__, std::complex<T> x__, F b__, std::complex<T> y__)
{
    return std::complex<T>(a__ * x__.real() + b__ * y__.real(), a__ * x__.imag() + b__ * y__.imag());
}

/// For complex-type F (complex<double> or complex<float>).
template <typename T, typename F>
static inline std::enable_if_t<!std::is_scalar<F>::value, std::complex<T>>
axpby_aux(F a__, std::complex<T> x__, F b__, std::complex<T> y__)
{
    auto z1 = F(x__.real(), x__.imag());
    auto z2 = F(y__.real(), y__.imag());
    auto z3 = a__ * z1 + b__ * z2;
    return std::complex<T>(z3.real(), z3.imag());
}

/// Perform y <- a * x + b * y type of operation.
template <typename T, typename F>
void axpby(sddk::memory_t mem__, wf::spin_range spins__, wf::band_range br__, F const* alpha__,
           wf::Wave_functions<T> const* x__, F const* beta__, wf::Wave_functions<T>* y__)
{
    PROFILE("wf::axpby");
    if (x__) {
        RTE_ASSERT(x__->ld() == y__->ld());
    }
    if (is_host_memory(mem__)) {
        for (auto s = spins__.begin(); s != spins__.end(); s++) {
            auto spy = y__->actual_spin_index(s);
            auto spx = x__ ? x__->actual_spin_index(s) : spy;
            #pragma omp parallel for
            for (int i = 0; i < br__.size(); i++) {
                auto ptr_y = y__->at(sddk::memory_t::host, 0, spy, wf::band_index(br__.begin() + i));
                if (x__) {
                    auto ptr_x = x__->at(sddk::memory_t::host, 0, spx, wf::band_index(br__.begin() + i));
                    if (beta__[i] == F(0)) {
                        for (int j = 0; j < y__->ld(); j++) {
                            ptr_y[j] = axpby_aux<T, F>(alpha__[i], ptr_x[j], 0.0, 0.0);
                        }
                    } else if (alpha__[i] == F(0)) {
                        for (int j = 0; j < y__->ld(); j++) {
                            ptr_y[j] = axpby_aux<T, F>(0.0, 0.0, beta__[i], ptr_y[j]);
                        }
                    } else {
                        for (int j = 0; j < y__->ld(); j++) {
                            ptr_y[j] = axpby_aux<T, F>(alpha__[i], ptr_x[j], beta__[i], ptr_y[j]);
                        }
                    }
                } else {
                    for (int j = 0; j < y__->ld(); j++) {
                        ptr_y[j] = axpby_aux<T, F>(0.0, 0.0, beta__[i], ptr_y[j]);
                    }
                }
            }
        }
    } else {
#if defined(SIRIUS_GPU)
        for (auto s = spins__.begin(); s != spins__.end(); s++) {
            auto spy   = y__->actual_spin_index(s);
            auto spx   = x__ ? x__->actual_spin_index(s) : spy;
            auto ptr_y = y__->at(mem__, 0, spy, wf::band_index(br__.begin()));
            auto ptr_x = x__ ? x__->at(mem__, 0, spx, wf::band_index(br__.begin())) : nullptr;

            sddk::mdarray<F, 1> alpha;
            if (x__) {
                alpha = sddk::mdarray<F, 1>(const_cast<F*>(alpha__), br__.size());
                alpha.allocate(mem__).copy_to(mem__);
            }
            sddk::mdarray<F, 1> beta(const_cast<F*>(beta__), br__.size());
            beta.allocate(mem__).copy_to(mem__);

            auto ldx       = x__ ? x__->ld() : 0;
            auto ptr_alpha = x__ ? alpha.at(mem__) : nullptr;

            if (std::is_same<T, double>::value) {
                if (std::is_same<F, double>::value) {
                    axpby_gpu_double_double(br__.size(), ptr_alpha, ptr_x, ldx,
                                            beta.at(mem__), ptr_y, y__->ld(), y__->ld());
                }
                if (std::is_same<F, std::complex<double>>::value) {
                    axpby_gpu_double_complex_double(br__.size(), ptr_alpha, ptr_x, ldx,
                                                    beta.at(mem__), ptr_y, y__->ld(), y__->ld());
                }
            }
            if (std::is_same<T, float>::value) {
                RTE_THROW("[wf::axpby] implement GPU kernel for float");
            }
        }
#endif
    }
}
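
// Usage sketch (illustrative only): y <- 0.5 * x + 1.0 * y for 10 bands of one spin
// component; alpha and beta are per-band arrays.
//
//     std::vector<std::complex<double>> alpha(10, 0.5), beta(10, 1.0);
//     wf::axpby(sddk::memory_t::host, wf::spin_range(0), wf::band_range(0, 10),
//               alpha.data(), &x, beta.data(), &y);
//     /* passing x = nullptr scales y by beta only */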

template <typename T, typename F, typename G>
void axpy_scatter(sddk::memory_t mem__, wf::spin_range spins__, F const* alphas__,
                  Wave_functions<T> const* x__, G const* idx__, Wave_functions<T>* y__, int n__)
{
    PROFILE("wf::axpy_scatter");
    if (is_host_memory(mem__)) {
        for (auto s = spins__.begin(); s != spins__.end(); s++) {
            auto spy = y__->actual_spin_index(s);
            auto spx = x__ ? x__->actual_spin_index(s) : spy;
            #pragma omp parallel for
            for (int i = 0; i < n__; i++) {
                auto ii    = idx__[i];
                auto alpha = alphas__[i];

                auto ptr_y = y__->at(sddk::memory_t::host, 0, spy, wf::band_index(ii));
                auto ptr_x = x__->at(sddk::memory_t::host, 0, spx, wf::band_index(i));
                for (int j = 0; j < y__->ld(); j++) {
                    ptr_y[j] += alpha * ptr_x[j];
                }
            }
        }
    } else {
#if defined(SIRIUS_GPU)
        for (auto s = spins__.begin(); s != spins__.end(); s++) {
            auto spy = y__->actual_spin_index(s);
            auto spx = x__ ? x__->actual_spin_index(s) : spy;

            auto ptr_y = y__->at(mem__, 0, spy, wf::band_index(0));
            auto ptr_x = x__->at(mem__, 0, spx, wf::band_index(0));

            sddk::mdarray<F, 1> alpha(const_cast<F*>(alphas__), n__);
            alpha.allocate(mem__).copy_to(mem__);

            sddk::mdarray<G, 1> idx(const_cast<G*>(idx__), n__);
            idx.allocate(mem__).copy_to(mem__);

            if (std::is_same<T, double>::value) {
                if (std::is_same<F, double>::value) {
                    axpy_scatter_gpu_double_double(
                        n__, alpha.at(mem__), ptr_x, x__->ld(), idx.at(mem__), ptr_y, y__->ld(), y__->ld());
                }
                if (std::is_same<F, std::complex<double>>::value) {
                    axpy_scatter_gpu_double_complex_double(
                        n__, alpha.at(mem__), ptr_x, x__->ld(), idx.at(mem__), ptr_y, y__->ld(), y__->ld());
                }
            }
            if (std::is_same<T, float>::value) {
                RTE_THROW("[wf::axpy_scatter] implement GPU kernel for float");
            }
        }
#endif
    }
}
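
// Usage sketch (illustrative only): y[idx[i]] += alpha[i] * x[i] for n source bands;
// idx maps each source band of x to a target band of y.
//
//     std::vector<std::complex<double>> alpha(n, 1.0);
//     std::vector<int> idx(n); /* fill with target band indices */
//     wf::axpy_scatter(sddk::memory_t::host, wf::spin_range(0), alpha.data(), &x, idx.data(), &y, n);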

/// Copy wave-functions.
template <typename T, typename F = T>
void copy(sddk::memory_t mem__, Wave_functions<T> const& in__, wf::spin_index s_in__, wf::band_range br_in__,
          Wave_functions<F>& out__, wf::spin_index s_out__, wf::band_range br_out__)
{
    PROFILE("wf::copy");
    RTE_ASSERT(br_in__.size() == br_out__.size());
    if (in__.ld() != out__.ld()) {
        std::stringstream s;
        s << "Leading dimensions of wave-functions do not match" << std::endl
          << "  in__.ld() = " << in__.ld() << std::endl
          << "  out__.ld() = " << out__.ld() << std::endl;
        RTE_THROW(s);
    }

    auto in_ptr  = in__.at(mem__, 0, s_in__, wf::band_index(br_in__.begin()));
    auto out_ptr = out__.at(mem__, 0, s_out__, wf::band_index(br_out__.begin()));

    if (is_host_memory(mem__)) {
        std::copy(in_ptr, in_ptr + in__.ld() * br_in__.size(), out_ptr);
    } else {
        if (!std::is_same<T, F>::value) {
            RTE_THROW("copy of different types on GPU is not implemented");
        }
        acc::copy(reinterpret_cast<std::complex<T>*>(out_ptr), in_ptr, in__.ld() * br_in__.size());
    }
}
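
// Usage sketch (illustrative only): copy the first 5 bands of phi into bands [5, 10) of psi.
//
//     wf::copy(sddk::memory_t::host, phi, wf::spin_index(0), wf::band_range(0, 5),
//              psi, wf::spin_index(0), wf::band_range(5, 10));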

/// Apply linear transformation to the wave-functions.
/**
 * \tparam T Precision type of the wave-functions (float or double).
 * \tparam F Type of the subspace and transformation matrix (float or double for Gamma-point calculation,
 *           complex<float> or complex<double> otherwise).
 * \param [in] spla_ctx Context of the SPLA library.
 * \param [in] mem      Location of the input wave-functions (host or device).
 * \param [in] M        The whole transformation matrix.
 * \param [in] irow0    Location of the 1st row of the transformation sub-matrix.
 * \param [in] jcol0    Location of the 1st column of the transformation sub-matrix.
 */
template <typename T, typename F>
inline std::enable_if_t<std::is_same<T, real_type<F>>::value, void>
transform(::spla::Context& spla_ctx__, sddk::memory_t mem__, la::dmatrix<F> const& M__, int irow0__, int jcol0__,
          real_type<F> alpha__, Wave_functions<T> const& wf_in__, spin_index s_in__, band_range br_in__,
          real_type<F> beta__, Wave_functions<T>& wf_out__, spin_index s_out__, band_range br_out__)
{
    PROFILE("wf::transform");

    RTE_ASSERT(wf_in__.ld() == wf_out__.ld());

    /* spla manages the resources through the context which can be updated during the call;
     * that's why the const must be removed here */
    auto& spla_mat_dist = const_cast<la::dmatrix<F>&>(M__).spla_distribution();

    /* for the Gamma-point case (the transformation matrix is real) we treat the complex wave-function
     * coefficients as a doubled list of real values */
    int ld = wf_in__.ld();
    if (std::is_same<F, real_type<F>>::value) {
        ld *= 2;
    }

    F const* mtrx_ptr = M__.size_local() ? M__.at(sddk::memory_t::host, 0, 0) : nullptr;

    F const* in_ptr = reinterpret_cast<F const*>(wf_in__.at(mem__, 0, s_in__, band_index(br_in__.begin())));

    F* out_ptr = reinterpret_cast<F*>(wf_out__.at(mem__, 0, s_out__, band_index(br_out__.begin())));

    spla::pgemm_sbs(ld, br_out__.size(), br_in__.size(), alpha__, in_ptr, ld, mtrx_ptr, M__.ld(), irow0__, jcol0__,
                    spla_mat_dist, beta__, out_ptr, ld, spla_ctx__);
}

template <typename T, typename F>
inline std::enable_if_t<!std::is_same<T, real_type<F>>::value, void>
transform(::spla::Context& spla_ctx__, sddk::memory_t mem__, la::dmatrix<F> const& M__, int irow0__, int jcol0__,
          real_type<F> alpha__, Wave_functions<T> const& wf_in__, spin_index s_in__, band_range br_in__,
          real_type<F> beta__, Wave_functions<T>& wf_out__, spin_index s_out__, band_range br_out__)
{
    if (is_device_memory(mem__)) {
        RTE_THROW("wf::transform(): mixed FP32/FP64 precision is implemented only for CPU");
    }
    RTE_ASSERT(wf_in__.ld() == wf_out__.ld());
    for (int j = 0; j < br_out__.size(); j++) {
        for (int k = 0; k < wf_in__.ld(); k++) {
            auto wf_out_ptr = wf_out__.at(sddk::memory_t::host, k, s_out__, wf::band_index(j + br_out__.begin()));
            std::complex<real_type<F>> z(0, 0);
            for (int i = 0; i < br_in__.size(); i++) {
                auto wf_in_ptr = wf_in__.at(sddk::memory_t::host, k, s_in__, wf::band_index(i + br_in__.begin()));
                z += static_cast<std::complex<real_type<F>>>(*wf_in_ptr) * M__(irow0__ + i, jcol0__ + j);
            }
            if (beta__ == 0) {
                *wf_out_ptr = alpha__ * z;
            } else {
                *wf_out_ptr = alpha__ * z + static_cast<std::complex<real_type<F>>>(*wf_out_ptr) * beta__;
            }
        }
    }
}
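
// Usage sketch (illustrative only; `spla_ctx` and the distributed matrix `M` are assumed
// to exist): psi <- phi * M for a 10x10 sub-block of M starting at (0, 0).
//
//     wf::transform(spla_ctx, sddk::memory_t::host, M, 0, 0,
//                   1.0, phi, wf::spin_index(0), wf::band_range(0, 10),
//                   0.0, psi, wf::spin_index(0), wf::band_range(0, 10));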

/// Scale the G=0 component of the wave-functions.
/** This is needed for the Gamma-point calculation to exclude the double-counting of the G=0 term.
 */
template <typename T>
inline void
scale_gamma_wf(sddk::memory_t mem__, wf::Wave_functions<T> const& wf__, wf::spin_range spins__,
               wf::band_range br__, T* scale__)
{
    RTE_ASSERT(spins__.size() == 1);

    auto& wf = const_cast<Wave_functions<T>&>(wf__);
    RTE_ASSERT(wf.num_sc() == wf::num_spins(1)); // TODO: might be too strong a check

    /* rank 0 stores the G=0 component */
    if (wf.comm().rank() != 0) {
        return;
    }

    /* the data is complex, so the stride between the real parts of the G=0 coefficients of
     * consecutive bands is twice the leading dimension; for the Gamma-point case the G=0
     * coefficient is real and only its real part needs to be scaled */
    auto ld = wf.ld() * 2;

    auto sp = wf.actual_spin_index(spins__.begin());

    auto ptr = wf.at(mem__, 0, sp, wf::band_index(br__.begin()));
    auto m   = br__.size();

    if (is_device_memory(mem__)) {
#if defined(SIRIUS_GPU)
        if (std::is_same<T, double>::value) {
            acc::blas::dscal(m, reinterpret_cast<double*>(scale__), reinterpret_cast<double*>(ptr), ld);
        } else if (std::is_same<T, float>::value) {
            acc::blas::sscal(m, reinterpret_cast<float*>(scale__), reinterpret_cast<float*>(ptr), ld);
        }
#else
        RTE_THROW("not compiled with GPU support!");
#endif
    } else {
        if (std::is_same<T, double>::value) {
            FORTRAN(dscal)(&m, reinterpret_cast<double*>(scale__), reinterpret_cast<double*>(ptr), &ld);
        } else if (std::is_same<T, float>::value) {
            FORTRAN(sscal)(&m, reinterpret_cast<float*>(scale__), reinterpret_cast<float*>(ptr), &ld);
        }
    }
}

/// Compute the inner product between two sets of wave-functions.
/**
 * \tparam T Precision type of the wave-functions (float or double).
 * \tparam F Type of the subspace (float or double for Gamma-point calculation,
 *           complex<float> or complex<double> otherwise).
 *
 * \param [in]  spla_ctx Context of the SPLA library.
 * \param [in]  mem      Location of the input wave-functions (host or device).
 * \param [in]  spins    Spin range of the wave-functions.
 * \param [in]  wf_i     Left hand side of <wf_i | wf_j> product.
 * \param [in]  br_i     Band range of the <wf_i| wave-functions.
 * \param [in]  wf_j     Right hand side of <wf_i | wf_j> product.
 * \param [in]  br_j     Band range of the |wf_j> wave-functions.
 * \param [out] result   Resulting inner product matrix.
 * \param [in]  irow0    Starting row of the output sub-block.
 * \param [in]  jcol0    Starting column of the output sub-block.
 * \return None
 *
 * Depending on the spin range this function computes the inner product between individual spin components
 * or between full spinor wave functions:
 * \f[
 *    M_{irow0+i,jcol0+j} = \sum_{\sigma=s0}^{s1} \langle \phi_{i0 + i}^{\sigma} | \phi_{j0 + j}^{\sigma} \rangle
 * \f]
 * where i0 and j0, as well as the dimensions of the resulting inner product matrix, are determined by the band
 * ranges of the bra- and ket- states.
 *
 * The location of the wave-functions data is determined by the mem parameter. The result is always returned in
 * CPU memory. If the resulting matrix is allocated on GPU memory, the result is copied to the GPU as well.
 */
template <typename F, typename W, typename T>
inline std::enable_if_t<std::is_same<T, real_type<F>>::value, void>
inner(::spla::Context& spla_ctx__, sddk::memory_t mem__, spin_range spins__, W const& wf_i__, band_range br_i__,
      Wave_functions<T> const& wf_j__, band_range br_j__, la::dmatrix<F>& result__, int irow0__, int jcol0__)
{
    PROFILE("wf::inner");

    RTE_ASSERT(wf_i__.ld() == wf_j__.ld());
    // RTE_ASSERT((wf_i__.gkvec().reduced() == std::is_same<F, real_type<F>>::value));
    RTE_ASSERT((wf_j__.gkvec().reduced() == std::is_same<F, real_type<F>>::value));

    if (spins__.size() == 2) {
        if (wf_i__.num_md() != wf::num_mag_dims(3)) {
            RTE_THROW("input wave-functions are not 2-component spinors");
        }
        if (wf_j__.num_md() != wf::num_mag_dims(3)) {
            RTE_THROW("input wave-functions are not 2-component spinors");
        }
    }

    auto spla_mat_dist = wf_i__.comm().size() > result__.comm().size()
                             ? spla::MatrixDistribution::create_mirror(wf_i__.comm().native())
                             : result__.spla_distribution();

    auto ld = wf_i__.ld();

    F alpha = 1.0;
    /* inner product matrix is real */
    if (std::is_same<F, real_type<F>>::value) {
        alpha = 2.0;
        ld *= 2;
    }

    T scale_half(0.5);
    T scale_two(2.0);

    /* for the Gamma-point case, the contribution of the G = 0 vector must not be counted double -> multiply by 0.5 */
    if (is_real_v<F>) {
        scale_gamma_wf(mem__, wf_j__, spins__, br_j__, &scale_half);
    }

    F beta = 0.0;

    F* result_ptr = result__.size_local() ? result__.at(sddk::memory_t::host, 0, 0) : nullptr;

    for (auto s = spins__.begin(); s != spins__.end(); s++) {
        auto s_i = wf_i__.actual_spin_index(s);
        auto s_j = wf_j__.actual_spin_index(s);
        auto wf_i_ptr = wf_i__.at(mem__, 0, s_i, wf::band_index(br_i__.begin()));
        auto wf_j_ptr = wf_j__.at(mem__, 0, s_j, wf::band_index(br_j__.begin()));

        spla::pgemm_ssb(br_i__.size(), br_j__.size(), ld, SPLA_OP_CONJ_TRANSPOSE,
                        alpha,
                        reinterpret_cast<F const*>(wf_i_ptr), ld,
                        reinterpret_cast<F const*>(wf_j_ptr), ld,
                        beta,
                        result_ptr, result__.ld(), irow0__, jcol0__, spla_mat_dist, spla_ctx__);
        beta = 1.0;
    }

    /* for the Gamma-point case, the G = 0 vector is rescaled back */
    if (is_real_v<F>) {
        scale_gamma_wf(mem__, wf_j__, spins__, br_j__, &scale_two);
    }

    /* make sure the result is updated on the device as well */
    if (result__.on_device()) {
        result__.copy_to(sddk::memory_t::device, irow0__, jcol0__, br_i__.size(), br_j__.size());
    }
}
1690
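/// Inner product for the case when the precision of the wave-functions differs from that of the result.
/** A plain CPU reference implementation for mixed FP32/FP64 arithmetic: the matrix elements are accumulated
    element-wise on the host with a triple loop instead of calling SPLA; device memory is not supported. */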
1691template <typename T, typename F>
1692inline std::enable_if_t<!std::is_same<T, real_type<F>>::value, void>
1693inner(::spla::Context& spla_ctx__, sddk::memory_t mem__, spin_range spins__, Wave_functions<T> const& wf_i__,
1694 band_range br_i__, Wave_functions<T> const& wf_j__, band_range br_j__, la::dmatrix<F>& result__,
1695 int irow0__, int jcol0__)
1696{
1697 if (is_device_memory(mem__)) {
1698 RTE_THROW("wf::inner(): mixed FP32/FP64 precision is implemented only for CPU");
1699 }
1700 RTE_ASSERT(wf_i__.ld() == wf_j__.ld());
1701 RTE_ASSERT((wf_i__.gkvec().reduced() == std::is_same<F, real_type<F>>::value));
1702 RTE_ASSERT((wf_j__.gkvec().reduced() == std::is_same<F, real_type<F>>::value));
1703 for (int i = 0; i < br_i__.size(); i++) {
1704 for (int j = 0; j < br_j__.size(); j++) {
1705 result__(irow0__ + i, jcol0__ + j) = 0.0;
1706 }
1707 }
1708 for (auto s = spins__.begin(); s != spins__.end(); s++) {
1709 auto s_i = wf_i__.actual_spin_index(s);
1710 auto s_j = wf_j__.actual_spin_index(s);
1711 int nk = wf_i__.ld();
1712
1713 for (int i = 0; i < br_i__.size(); i++) {
1714 for (int j = 0; j < br_j__.size(); j++) {
1715 auto wf_i_ptr = wf_i__.at(sddk::memory_t::host, 0, s_i, wf::band_index(br_i__.begin() + i));
1716 auto wf_j_ptr = wf_j__.at(sddk::memory_t::host, 0, s_j, wf::band_index(br_j__.begin() + j));
1717 F z = 0.0;
1718
1719 for (int k = 0; k < nk; k++) {
1720 z += inner_diag_local_aux<T, F>(wf_i_ptr[k], wf_j_ptr[k]);
1721 }
1722 result__(irow0__ + i, jcol0__ + j) += z;
1723 }
1724 }
1725 }
1726}
1727
1728/// Orthogonalize n new wave-functions to the N old wave-functions
1729/** Orthogonalize sets of wave-functions.
1730\tparam T Precision of the wave-functions (float or double).
1731\tparam F Type of the inner-product matrix (float, double or complex).
1732\param [in] spla_ctx SPLA library context.
1733\param [in] mem Location of the wave-functions data.
1734\param [in] spins Spin index range.
1735\param [in] br_old Band range of the functions that are already orthogonal and that will be projected out.
1736\param [in] br_new Band range of the functions that need to be orthogonalized.
1737\param [in] wf_i The <wf_i| states used to compute overlap matrix O_{ij}.
1738\param [in] wf_j The |wf_j> states used to compute overlap matrix O_{ij}.
1739\param [out] wfs List of wave-function sets (typically phi, hphi and sphi).
1740\param [out] o Work matrix used to compute the overlap <wf_i|wf_j>.
1741\param [out] tmp Temporary wave-functions to store intermediate results.
1742\param [in] project_out Project out old subspace (if this was not done before).
1743\return Number of linearly independent wave-functions found.
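
A minimal usage sketch (illustrative only; spla_ctx, the wave-function sets phi, hphi and tmp, the
overlap matrix o and the band counts N and n are assumed to be created elsewhere):
\code
std::vector<wf::Wave_functions<double>*> wfs({&phi, &hphi});
wf::orthogonalize(spla_ctx, sddk::memory_t::host, wf::spin_range(0), wf::band_range(0, N),
                  wf::band_range(N, N + n), phi, phi, wfs, o, tmp, true);
\endcode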
1744*/
1745template <typename T, typename F>
1746int
1747orthogonalize(::spla::Context& spla_ctx__, sddk::memory_t mem__, spin_range spins__, band_range br_old__,
1748 band_range br_new__, Wave_functions<T> const& wf_i__, Wave_functions<T> const& wf_j__,
1749 std::vector<Wave_functions<T>*> wfs__, la::dmatrix<F>& o__, Wave_functions<T>& tmp__, bool project_out__)
1750{
1751 PROFILE("wf::orthogonalize");
1752
1753 /* number of new states */
1754 int n = br_new__.size();
1755
1756 auto pp = env::print_performance();
1757
1758 auto& comm = wf_i__.gkvec().comm();
1759
1760 int K{0}; /* local leading dimension (number of coefficients per band), used only for the flop count */
1761
1762 if (pp) {
1763 K = wf_i__.ld();
1764 if (is_real_v<F>) {
1765 K *= 2;
1766 }
1767 }
1768
1769// //auto sddk_debug_ptr = utils::get_env<int>("SDDK_DEBUG");
1770// //int sddk_debug = (sddk_debug_ptr) ? (*sddk_debug_ptr) : 0;
1771//
1772 /* prefactor for the matrix multiplication in complex or real arithmetic (in Giga-operations) */
1773 double ngop{8e-9}; // default value for the complex type: 8 real flops per multiply-add
1774 if (is_real_v<F>) { // real type: 2 real flops per multiply-add
1775 ngop = 2e-9;
1776 }
1777
1778 if (pp) {
1779 comm.barrier();
1780 }
1781 auto t0 = ::sirius::time_now();
1782
1783 double gflops{0};
1784
1785 /* project out the old subspace:
1786 * |\tilde phi_new> = |phi_new> - |phi_old><phi_old|phi_new>
1787 * H|\tilde phi_new> = H|phi_new> - H|phi_old><phi_old|phi_new>
1788 * S|\tilde phi_new> = S|phi_new> - S|phi_old><phi_old|phi_new> */
1789 if (br_old__.size() > 0 && project_out__) {
1790 inner(spla_ctx__, mem__, spins__, wf_i__, br_old__, wf_j__, br_new__, o__, 0, 0);
1791 for (auto s = spins__.begin(); s != spins__.end(); s++) {
1792 for (auto wf: wfs__) {
1793 auto sp = wf->actual_spin_index(s);
1794 transform(spla_ctx__, mem__, o__, 0, 0, -1.0, *wf, sp, br_old__, 1.0, *wf, sp, br_new__);
1795 }
1796 }
1797 if (pp) {
1798 /* inner and transform have the same number of flops */
1799 gflops += spins__.size() * static_cast<int>(1 + wfs__.size()) * ngop * br_old__.size() * n * K;
1800 }
1801 }
1802
1803// if (sddk_debug >= 2) {
1804// if (o__.comm().rank() == 0) {
1805// RTE_OUT(std::cout) << "check QR decomposition, matrix size : " << n__ << std::endl;
1806// }
1807// inner(spla_ctx__, spins__, *wfs__[idx_bra__], N__, n__, *wfs__[idx_ket__], N__, n__, o__, 0, 0);
1808//
1809// linalg(lib_t::scalapack).geqrf(n__, n__, o__, 0, 0);
1810// auto diag = o__.get_diag(n__);
1811// if (o__.comm().rank() == 0) {
1812// for (int i = 0; i < n__; i++) {
1813// if (std::abs(diag[i]) < std::numeric_limits<real_type<T>>::epsilon() * 10) {
1814// RTE_OUT(std::cout) << "small norm: " << i << " " << diag[i] << std::endl;
1815// }
1816// }
1817// }
1818//
1819// if (o__.comm().rank() == 0) {
1820// RTE_OUT(std::cout) << "check eigen-values, matrix size : " << n__ << std::endl;
1821// }
1822// inner(spla_ctx__, spins__, *wfs__[idx_bra__], N__, n__, *wfs__[idx_ket__], N__, n__, o__, 0, 0);
1823//
1824// // if (sddk_debug >= 3) {
1825// // save_to_hdf5("nxn_overlap.h5", o__, n__);
1826// //}
1827//
1828// std::vector<real_type<F>> eo(n__);
1829// dmatrix<F> evec(o__.num_rows(), o__.num_cols(), o__.blacs_grid(), o__.bs_row(), o__.bs_col());
1830//
1831// auto solver = (o__.comm().size() == 1) ? Eigensolver_factory("lapack", nullptr) :
1832// Eigensolver_factory("scalapack", nullptr);
1833// solver->solve(n__, o__, eo.data(), evec);
1834//
1835// if (o__.comm().rank() == 0) {
1836// for (int i = 0; i < n__; i++) {
1837// if (eo[i] < 1e-6) {
1838// RTE_OUT(std::cout) << "small eigen-value " << i << " " << eo[i] << std::endl;
1839// }
1840// }
1841// }
1842// }
1843//
1844 /* orthogonalize new n x n block */
1845 inner(spla_ctx__, mem__, spins__, wf_i__, br_new__, wf_j__, br_new__, o__, 0, 0);
1846 if (pp) {
1847 gflops += spins__.size() * ngop * n * n * K;
1848 }
1849
1850 /* At this point the overlap matrix is computed for the new block and stored on the CPU. We
1851 * now have the following choices:
1852 * - mem: CPU
1853 * - o is not distributed
1854 * - potrf is computed on CPU with lapack
1855 * - trtri is computed on CPU with lapack
1856 * - trmm is computed on CPU with blas
1857 * - o is distributed
1858 * - potrf is computed on CPU with scalapack
1859 * - trtri is computed on CPU with scalapack
1860 * - trmm is computed on CPU with wf::transform
1861 *
1862 * - mem: GPU
1863 * - o is not distributed
1864 * - potrf is computed on CPU with lapack; later with cuSolver
1865 * - trtri is computed on CPU with lapack; later with cuSolver
1866 * - trmm is computed on GPU with cublas
1867 *
1868 * - o is distributed
1869 * - potrf is computed on CPU with scalapack
1870 * - trtri is computed on CPU with scalapack
1871 * - trmm is computed on GPU with wf::transform
1872 */
1873 // TODO: test magma and cuSolver
1874 /*
1875 * potrf from cuSolver works in a standalone test, but not here; here it returns -1;
1876 * disabled pending further investigation
1877 *
1878 */
1879 auto la = la::lib_t::lapack;
1880 auto la1 = la::lib_t::blas;
1881 auto mem = sddk::memory_t::host;
1882 /* if matrix is distributed, we use ScaLAPACK for Cholesky factorization */
1883 if (o__.comm().size() > 1) {
1884 la = la::lib_t::scalapack;
1885 }
1886 if (is_device_memory(mem__)) {
1887 /* this is for trmm */
1888 la1 = la::lib_t::gpublas;
1889 /* this is for potrf */
1890 //if (o__.comm().size() == 1) {
1891 // mem = mem__;
1892 // la = sddk::linalg_t::gpublas;
1893 //}
1894 }
1895
1896 /* compute the transformation matrix: the inverse of the upper Cholesky factor U of O = U^H U, so that phi * U^{-1} is orthonormal */
1897 PROFILE_START("wf::orthogonalize|tmtrx");
1898 auto o_ptr = (o__.size_local() == 0) ? nullptr : o__.at(mem);
1899 if (la == la::lib_t::scalapack) {
1900 o__.make_real_diag(n);
1901 }
1902 /* Cholesky factorization */
1903 if (int info = la::wrap(la).potrf(n, o_ptr, o__.ld(), o__.descriptor())) {
1904 std::stringstream s;
1905 s << "error in Cholesky factorization, info = " << info << std::endl
1906 << "number of existing states: " << br_old__.size() << std::endl
1907 << "number of new states: " << br_new__.size();
1908 RTE_THROW(s);
1909 }
1910 /* inversion of triangular matrix */
1911 if (la::wrap(la).trtri(n, o_ptr, o__.ld(), o__.descriptor())) {
1912 RTE_THROW("error in inversion");
1913 }
1914 PROFILE_STOP("wf::orthogonalize|tmtrx");
1915
1916 /* single MPI rank and precision types of wave-functions and transformation matrices match */
1917 if (o__.comm().size() == 1 && std::is_same<T, real_type<F>>::value) {
1918 PROFILE_START("wf::orthogonalize|trans");
1919 if (is_device_memory(mem__)) {
1920 o__.copy_to(mem__, 0, 0, n, n);
1921 }
1922 int sid{0};
1923 for (auto s = spins__.begin(); s != spins__.end(); s++) {
1924 /* multiplication by triangular matrix */
1925 for (auto& wf : wfs__) {
1926 auto sp = wf->actual_spin_index(s);
1927 auto ptr = reinterpret_cast<F*>(wf->at(mem__, 0, sp, wf::band_index(br_new__.begin())));
1928 int ld = wf->ld();
1929 /* Gamma-point case */
1930 if (is_real_v<F>) {
1931 ld *= 2;
1932 }
1933
1934 la::wrap(la1).trmm('R', 'U', 'N', ld, n, &la::constant<F>::one(),
1935 o__.at(mem__), o__.ld(), ptr, ld, acc::stream_id(sid++));
1936 }
1937 }
1938 if (la1 == la::lib_t::gpublas || la1 == la::lib_t::cublasxt || la1 == la::lib_t::magma) {
1939 /* sync stream only if processing unit is GPU */
1940 for (int i = 0; i < sid; i++) {
1941 acc::sync_stream(acc::stream_id(i));
1942 }
1943 }
1944 if (pp) {
1945 gflops += spins__.size() * wfs__.size() * ngop * 0.5 * n * n * K;
1946 }
1947 PROFILE_STOP("wf::orthogonalize|trans");
1948 } else {
1949 /* o must be an upper triangular matrix: explicitly zero the lower part */
1950 for (int i = 0; i < n; i++) {
1951 for (int j = i + 1; j < n; j++) {
1952 o__.set(j, i, 0);
1953 }
1954 }
1955
1956 /* the wave-functions are transformed in place, so they can't serve as the output buffer;
1957 * write to tmp instead and then overwrite the original */
1958 for (auto s = spins__.begin(); s != spins__.end(); s++) {
1959 for (auto wf: wfs__) {
1960 auto sp = wf->actual_spin_index(s);
1961 auto sp1 = tmp__.actual_spin_index(s);
1962 auto br1 = wf::band_range(0, br_new__.size());
1963 transform(spla_ctx__, mem__, o__, 0, 0, 1.0, *wf, sp, br_new__, 0.0, tmp__, sp1, br1);
1964 copy(mem__, tmp__, sp1, br1, *wf, sp, br_new__);
1965 }
1966 }
1967 }
1968
1969//== //
1970//== // if (sddk_debug >= 1) {
1971//== // //auto cs = o__.checksum(n__, n__);
1972//== // //if (o__.comm().rank() == 0) {
1973//== // // //utils::print_checksum("n x n overlap", cs);
1974//== // //}
1975//== // if (o__.comm().rank() == 0) {
1976//== // RTE_OUT(std::cout) << "check diagonal" << std::endl;
1977//== // }
1978//== // auto diag = o__.get_diag(n__);
1979//== // for (int i = 0; i < n__; i++) {
1980//== // if (std::real(diag[i]) <= 0 || std::imag(diag[i]) > 1e-12) {
1981//== // RTE_OUT(std::cout) << "wrong diagonal: " << i << " " << diag[i] << std::endl;
1982//== // }
1983//== // }
1984//== // if (o__.comm().rank() == 0) {
1985//== // RTE_OUT(std::cout) << "check hermitian" << std::endl;
1986//== // }
1987//== // auto d = check_hermitian(o__, n__);
1988//== // if (o__.comm().rank() == 0) {
1989//== // if (d > 1e-12) {
1990//== // std::stringstream s;
1991//== // s << "matrix is not hermitian, max diff = " << d;
1992//== // WARNING(s);
1993//== // } else {
1994//== // RTE_OUT(std::cout) << "OK! n x n overlap matrix is hermitian" << std::endl;
1995//== // }
1996//== // }
1997//== //
1998//== // }
1999//==
2000//== // if (sddk_debug >= 1) {
2001//== // inner(spla_ctx__, spins__, *wfs__[idx_bra__], N__, n__, *wfs__[idx_ket__], N__, n__, o__, 0, 0);
2002//== // auto err = check_identity(o__, n__);
2003//== // if (o__.comm().rank() == 0) {
2004//== // RTE_OUT(std::cout) << "orthogonalization error : " << err << std::endl;
2005//== // }
2006//== // }
2007 if (pp) {
2008 comm.barrier();
2009 auto t = ::sirius::time_interval(t0);
2010 if (comm.rank() == 0) {
2011 RTE_OUT(std::cout) << "effective performance : " << gflops / t << " GFlop/s/rank, "
2012 << gflops * comm.size() / t << " GFlop/s" << std::endl;
2013 }
2014 }
2015
2016 return 0;
2017}
2018
2019} // namespace wf
2020
2021} // namespace sirius
2022
2023#endif