SIRIUS 7.5.0
Electronic structure library and applications
spheric_function_set.hpp
1#ifndef __SPHERIC_FUNCTION_SET_HPP__
2#define __SPHERIC_FUNCTION_SET_HPP__
3
5//#include "core/strong_type.hpp"
6
7namespace sirius {
8
9using lmax_t = strong_type<int, struct __lmax_t_tag>;
10
11template <typename T, typename I>
13{
14 private:
15 /// Pointer to the unit cell
16 Unit_cell const* unit_cell_{nullptr};
17 /// Text label of the function set.
18 std::string label_;
19 /// List of atoms for which the spherical expansion is defined.
20 std::vector<int> atoms_;
21 /// Split the number of atoms between MPI ranks.
22 /** If the pointer is null, spheric functions set is treated as global, without MPI distribution */
24 /// List of spheric functions.
25 std::vector<Spheric_function<function_domain_t::spectral, T>> func_;
26
27 bool all_atoms_{false};
28
29 void init(std::function<lmax_t(int)> lmax__, spheric_function_set_ptr_t<T> const* sptr__ = nullptr)
30 {
31 func_.resize(unit_cell_->num_atoms());
32
33 auto set_func = [&](int ia)
34 {
35 if (sptr__) {
37 sptr__->ptr + sptr__->lmmax * sptr__->nrmtmax * ia,
38 sptr__->lmmax, unit_cell_->atom(ia).radial_grid());
39 } else {
40 func_[ia] = Spheric_function<function_domain_t::spectral, T>(sf::lmmax(lmax__(ia)),
41 unit_cell_->atom(ia).radial_grid());
42 }
43 };
44
45 if (spl_atoms_) {
46 for (auto it : (*spl_atoms_)) {
47 set_func(atoms_[it.i]);
48 }
49 } else {
50 for (int ia : atoms_) {
51 set_func(ia);
52 }
53 }
54 }
55
56 public:
57 Spheric_function_set()
58 {
59 }
60
61 /// Constructor for all atoms.
62 Spheric_function_set(std::string label__, Unit_cell const& unit_cell__, std::function<lmax_t(int)> lmax__,
63 splindex_block<I> const* spl_atoms__ = nullptr,
64 spheric_function_set_ptr_t<T> const* sptr__ = nullptr)
65 : unit_cell_{&unit_cell__}
66 , label_{label__}
67 , spl_atoms_{spl_atoms__}
68 , all_atoms_{true}
69 {
70 atoms_.resize(unit_cell__.num_atoms());
71 std::iota(atoms_.begin(), atoms_.end(), 0);
72 if (spl_atoms_) {
73 if (spl_atoms_->size() != unit_cell__.num_atoms()) {
74 RTE_THROW("wrong split atom index");
75 }
76 }
77 init(lmax__, sptr__);
78 }
79
80 /// Constructor for a subset of atoms.
81 Spheric_function_set(std::string label__, Unit_cell const& unit_cell__, std::vector<int> atoms__,
82 std::function<lmax_t(int)> lmax__, splindex_block<I> const* spl_atoms__ = nullptr)
83 : unit_cell_{&unit_cell__}
84 , label_{label__}
85 , atoms_{atoms__}
86 , spl_atoms_{spl_atoms__}
87 , all_atoms_{false}
88 {
89 if (spl_atoms_) {
90 if (spl_atoms_->size() != static_cast<int>(atoms__.size())) {
91 RTE_THROW("wrong split atom index");
92 }
93 }
94 init(lmax__);
95 }
96
97 auto const& atoms() const
98 {
99 return atoms_;
100 }
101
102 auto& operator[](int ia__)
103 {
104 return func_[ia__];
105 }
106
107 auto const& operator[](int ia__) const
108 {
109 return func_[ia__];
110 }
111
112 inline auto const& unit_cell() const
113 {
114 return *unit_cell_;
115 }
116
117 inline void zero()
118 {
119 if (unit_cell_) {
120 for (int ia = 0; ia < unit_cell_->num_atoms(); ia++) {
121 if (func_[ia].size()) {
122 func_[ia].zero();
123 }
124 }
125 }
126 }
127
128 /// Synchronize global function.
129 /** Assuming that each MPI rank was handling part of the global spherical function, broadcast data
130 * from each rank. As a result, each rank stores a full and identical copy of global spherical function. */
131 inline void sync(splindex_block<I> const& spl_atoms__)
132 {
133 for (int i = 0; i < spl_atoms__.size(); i++) {
134 auto loc = spl_atoms__.location(typename I::global(i));
135 int ia = atoms_[i];
136 unit_cell_->comm().bcast(func_[ia].at(sddk::memory_t::host), static_cast<int>(func_[ia].size()), loc.ib);
137 }
138 }
139
141 {
142 for (int ia = 0; ia < unit_cell_->num_atoms(); ia++) {
143 if (func_[ia].size() && rhs__[ia].size()) {
144 func_[ia] += rhs__[ia];
145 }
146 }
147 return *this;
148 }
149
150 template <typename T_, typename I_>
151 friend T_
152 inner(Spheric_function_set<T_, I_> const& f1__, Spheric_function_set<T_, I_> const& f2__);
153
154 template <typename T_, typename I_>
155 friend void
156 copy(Spheric_function_set<T_, I_> const& src__, Spheric_function_set<T_, I_>& dest__);
157
158 template <typename T_, typename I_>
159 friend void
160 copy(Spheric_function_set<T_, I_> const& src__, spheric_function_set_ptr_t<T_> dest__);
161
162 template <typename T_, typename I_>
163 friend void
164 copy(spheric_function_set_ptr_t<T_> src__, Spheric_function_set<T_, I_> const& dest__);
165
166 template <typename T_, typename I_>
167 friend void
168 scale(T_ alpha__, Spheric_function_set<T_, I_>& x__);
169
170 template <typename T_, typename I_>
171 friend void
172 axpy(T_ alpha__, Spheric_function_set<T_, I_> const& x__, Spheric_function_set<T_, I_>& y__);
173};
174
175template <typename T, typename I>
176inline T inner(Spheric_function_set<T, I> const& f1__, Spheric_function_set<T, I> const& f2__)
177{
178 auto ptr = (f1__.spl_atoms_) ? f1__.spl_atoms_ : f2__.spl_atoms_;
179
180 /* if both functions are split then the split index must match */
181 if (f1__.spl_atoms_ && f2__.spl_atoms_) {
182 RTE_ASSERT(f1__.spl_atoms_ == f2__.spl_atoms_);
183 }
184
185 T result{0};
186
187 auto const& comm = f1__.unit_cell_->comm();
188
189 if (ptr) {
190 for (int i = 0; i < ptr->local_size(); i++) {
191 int ia = f1__.atoms_[(*ptr).global_index(typename I::local(i))];
192 result += inner(f1__[ia], f2__[ia]);
193 }
194 } else {
195 splindex_block<I> spl_atoms(f1__.atoms_.size(), n_blocks(comm.size()), block_id(comm.rank()));
196 for (int i = 0; i < spl_atoms.local_size(); i++) {
197 int ia = f1__.atoms_[spl_atoms.global_index(typename I::local(i))];
198 result += inner(f1__[ia], f2__[ia]);
199 }
200 }
201 comm.allreduce(&result, 1);
202 return result;
203}
204
205/// Copy from Spheric_function_set to external pointer.
206/** External pointer is assumed to be global. */
207template <typename T, typename I>
208inline void
210{
211 auto p = dest__.ptr;
212 for (auto ia : src__.atoms()) {
213 if (src__[ia].size()) {
214 if (src__[ia].angular_domain_size() > dest__.lmmax) {
215 RTE_THROW("wrong angular_domain_size");
216 }
217 sddk::mdarray<T, 2> rlm(p, dest__.lmmax, dest__.nrmtmax);
218 for (int ir = 0; ir < src__[ia].radial_grid().num_points(); ir++) {
219 for (int lm = 0; lm < src__[ia].angular_domain_size(); lm++) {
220 rlm(lm, ir) = src__[ia](lm, ir);
221 }
222 }
223 }
224 p += dest__.lmmax * dest__.nrmtmax;
225 }
226 if (src__.spl_atoms_) {
227 int ld = dest__.lmmax * dest__.nrmtmax;
228 src__.unit_cell_->comm().allgather(dest__.ptr, ld * src__.spl_atoms_->local_size(),
229 ld * src__.spl_atoms_->global_offset());
230 }
231}
232
233/// Copy from external pointer to Spheric_function_set.
234/** External pointer is assumed to be global. */
235template <typename T, typename I>
236inline void
238{
239 auto p = src__.ptr;
240 for (auto ia : dest__.atoms()) {
241 if (dest__[ia].size()) {
242 if (dest__[ia].angular_domain_size() > src__.lmmax) {
243 RTE_THROW("wrong angular_domain_size");
244 }
245 sddk::mdarray<T, 2> rlm(p, src__.lmmax, src__.nrmtmax);
246 for (int ir = 0; ir < dest__[ia].radial_grid().num_points(); ir++) {
247 for (int lm = 0; lm < dest__[ia].angular_domain_size(); lm++) {
248 dest__[ia](lm, ir) = rlm(lm, ir);
249 }
250 }
251 }
252 p += src__.lmmax * src__.nrmtmax;
253 }
254}
255
256template <typename T, typename I>
257inline void
258copy(Spheric_function_set<T, I> const& src__, Spheric_function_set<T, I>& dest__)
259{
260 for (int ia = 0; ia < src__.unit_cell_->num_atoms(); ia++) {
261 if (src__.func_[ia].size()) {
262 copy(src__.func_[ia], dest__.func_[ia]);
263 }
264 }
265}
266
267template <typename T, typename I>
268inline void
269scale(T alpha__, Spheric_function_set<T, I>& x__)
270{
271 for (int ia = 0; ia < x__.unit_cell_->num_atoms(); ia++) {
272 if (x__.func_[ia].size()) {
273 x__.func_[ia] *= alpha__;
274 }
275 }
276}
277
278template <typename T, typename I>
279inline void
280axpy(T alpha__, Spheric_function_set<T, I> const& x__, Spheric_function_set<T, I>& y__)
281{
282 for (int ia = 0; ia < x__.unit_cell_->num_atoms(); ia++) {
283 if (x__.func_[ia].size()) {
284 y__.func_[ia] += x__.func_[ia] * alpha__;
285 }
286 }
287}
288
289}
290
291#endif
std::string label_
Text label of the function set.
Spheric_function_set(std::string label__, Unit_cell const &unit_cell__, std::function< lmax_t(int)> lmax__, splindex_block< I > const *spl_atoms__=nullptr, spheric_function_set_ptr_t< T > const *sptr__=nullptr)
Constructor for all atoms.
std::vector< int > atoms_
List of atoms for which the spherical expansion is defined.
Unit_cell const * unit_cell_
Pointer to the unit cell.
std::vector< Spheric_function< function_domain_t::spectral, T > > func_
List of spheric functions.
void sync(splindex_block< I > const &spl_atoms__)
Synchronize global function.
splindex_block< I > const * spl_atoms_
Split the number of atoms between MPI ranks.
Spheric_function_set(std::string label__, Unit_cell const &unit_cell__, std::vector< int > atoms__, std::function< lmax_t(int)> lmax__, splindex_block< I > const *spl_atoms__=nullptr)
Constructor for a subset of atoms.
Function in spherical harmonics or spherical coordinates representation.
Representation of a unit cell.
Definition: unit_cell.hpp:43
Atom const & atom(int id__) const
Return const atom instance by id.
Definition: unit_cell.hpp:344
int num_atoms() const
Number of atoms in the unit cell.
Definition: unit_cell.hpp:338
value_type local_size(block_id block_id__) const
Return local size of the split index for a given block.
Definition: splindex.hpp:302
splindex< Index_t >::location_t location(typename Index_t::global idx__) const
Return "local index, rank" pair for a global index.
Definition: splindex.hpp:319
auto size() const noexcept
Return total length of the index (global number of elements).
Definition: splindex.hpp:213
int lmmax(int lmax)
Maximum number of $(\ell, m)$ combinations for a given $\ell_{max}$.
Definition: specfunc.hpp:44
int lm(int l, int m)
Get composite lm index by angular index l and azimuthal index m.
Definition: specfunc.hpp:50
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > inner(::spla::Context &spla_ctx__, sddk::memory_t mem__, spin_range spins__, W const &wf_i__, band_range br_i__, Wave_functions< T > const &wf_j__, band_range br_j__, la::dmatrix< F > &result__, int irow0__, int jcol0__)
Compute inner product between the two sets of wave-functions.
Namespace of the SIRIUS library.
Definition: sirius.f90:5
strong_type< int, struct __block_id_tag > block_id
ID of the block.
Definition: splindex.hpp:108
strong_type< int, struct __n_blocks_tag > n_blocks
Number of blocks to which the global index is split.
Definition: splindex.hpp:105
Contains definition and partial implementation of sirius::Unit_cell class.