1#ifndef __SPHERIC_FUNCTION_SET_HPP__
2#define __SPHERIC_FUNCTION_SET_HPP__
9using lmax_t = strong_type<int, struct __lmax_t_tag>;
11template <
typename T,
typename I>
25 std::vector<Spheric_function<function_domain_t::spectral, T>>
func_;
27 bool all_atoms_{
false};
33 auto set_func = [&](
int ia)
37 sptr__->ptr + sptr__->lmmax * sptr__->nrmtmax * ia,
40 func_[ia] = Spheric_function<function_domain_t::spectral, T>(
sf::lmmax(lmax__(ia)),
57 Spheric_function_set()
74 RTE_THROW(
"wrong split atom index");
91 RTE_THROW(
"wrong split atom index");
97 auto const& atoms()
const
102 auto& operator[](
int ia__)
107 auto const& operator[](
int ia__)
const
112 inline auto const& unit_cell()
const
121 if (
func_[ia].size()) {
133 for (
int i = 0; i < spl_atoms__.
size(); i++) {
134 auto loc = spl_atoms__.
location(
typename I::global(i));
136 unit_cell_->comm().bcast(
func_[ia].at(sddk::memory_t::host),
static_cast<int>(
func_[ia].size()), loc.ib);
143 if (
func_[ia].size() && rhs__[ia].size()) {
144 func_[ia] += rhs__[ia];
150 template <
typename T_,
typename I_>
152 inner(Spheric_function_set<T_, I_>
const& f1__, Spheric_function_set<T_, I_>
const& f2__);
154 template <
typename T_,
typename I_>
156 copy(Spheric_function_set<T_, I_>
const& src__, Spheric_function_set<T_, I_>& dest__);
158 template <
typename T_,
typename I_>
160 copy(Spheric_function_set<T_, I_>
const& src__, spheric_function_set_ptr_t<T_> dest__);
162 template <
typename T_,
typename I_>
164 copy(spheric_function_set_ptr_t<T_> src__, Spheric_function_set<T_, I_>
const& dest__);
166 template <
typename T_,
typename I_>
168 scale(T_ alpha__, Spheric_function_set<T_, I_>& x__);
170 template <
typename T_,
typename I_>
172 axpy(T_ alpha__, Spheric_function_set<T_, I_>
const& x__, Spheric_function_set<T_, I_>& y__);
175template <
typename T,
typename I>
176inline T
inner(Spheric_function_set<T, I>
const& f1__, Spheric_function_set<T, I>
const& f2__)
178 auto ptr = (f1__.spl_atoms_) ? f1__.spl_atoms_ : f2__.spl_atoms_;
181 if (f1__.spl_atoms_ && f2__.spl_atoms_) {
182 RTE_ASSERT(f1__.spl_atoms_ == f2__.spl_atoms_);
187 auto const& comm = f1__.unit_cell_->comm();
190 for (
int i = 0; i < ptr->local_size(); i++) {
191 int ia = f1__.atoms_[(*ptr).global_index(
typename I::local(i))];
192 result +=
inner(f1__[ia], f2__[ia]);
195 splindex_block<I> spl_atoms(f1__.atoms_.size(),
n_blocks(comm.size()),
block_id(comm.rank()));
196 for (
int i = 0; i < spl_atoms.local_size(); i++) {
197 int ia = f1__.atoms_[spl_atoms.global_index(
typename I::local(i))];
198 result +=
inner(f1__[ia], f2__[ia]);
201 comm.allreduce(&result, 1);
207template <
typename T,
typename I>
212 for (
auto ia : src__.atoms()) {
213 if (src__[ia].size()) {
214 if (src__[ia].angular_domain_size() > dest__.lmmax) {
215 RTE_THROW(
"wrong angular_domain_size");
218 for (
int ir = 0; ir < src__[ia].radial_grid().num_points(); ir++) {
219 for (
int lm = 0;
lm < src__[ia].angular_domain_size();
lm++) {
220 rlm(
lm, ir) = src__[ia](
lm, ir);
224 p += dest__.lmmax * dest__.nrmtmax;
227 int ld = dest__.lmmax * dest__.nrmtmax;
235template <
typename T,
typename I>
240 for (
auto ia : dest__.atoms()) {
241 if (dest__[ia].size()) {
242 if (dest__[ia].angular_domain_size() > src__.lmmax) {
243 RTE_THROW(
"wrong angular_domain_size");
246 for (
int ir = 0; ir < dest__[ia].radial_grid().num_points(); ir++) {
247 for (
int lm = 0;
lm < dest__[ia].angular_domain_size();
lm++) {
248 dest__[ia](
lm, ir) = rlm(
lm, ir);
252 p += src__.lmmax * src__.nrmtmax;
256template <
typename T,
typename I>
258copy(Spheric_function_set<T, I>
const& src__, Spheric_function_set<T, I>& dest__)
260 for (
int ia = 0; ia < src__.unit_cell_->num_atoms(); ia++) {
261 if (src__.func_[ia].size()) {
262 copy(src__.func_[ia], dest__.func_[ia]);
267template <
typename T,
typename I>
269scale(T alpha__, Spheric_function_set<T, I>& x__)
271 for (
int ia = 0; ia < x__.unit_cell_->num_atoms(); ia++) {
272 if (x__.func_[ia].size()) {
273 x__.func_[ia] *= alpha__;
278template <
typename T,
typename I>
280axpy(T alpha__, Spheric_function_set<T, I>
const& x__, Spheric_function_set<T, I>& y__)
282 for (
int ia = 0; ia < x__.unit_cell_->num_atoms(); ia++) {
283 if (x__.func_[ia].size()) {
284 y__.func_[ia] += x__.func_[ia] * alpha__;
std::string label_
Text label of the function set.
Spheric_function_set(std::string label__, Unit_cell const &unit_cell__, std::function< lmax_t(int)> lmax__, splindex_block< I > const *spl_atoms__=nullptr, spheric_function_set_ptr_t< T > const *sptr__=nullptr)
Constructor for all atoms.
std::vector< int > atoms_
List of atoms for which the spherical expansion is defined.
Unit_cell const * unit_cell_
Pointer to the unit cell.
std::vector< Spheric_function< function_domain_t::spectral, T > > func_
List of spheric functions.
void sync(splindex_block< I > const &spl_atoms__)
Synchronize global function.
splindex_block< I > const * spl_atoms_
Split the number of atoms between MPI ranks.
Spheric_function_set(std::string label__, Unit_cell const &unit_cell__, std::vector< int > atoms__, std::function< lmax_t(int)> lmax__, splindex_block< I > const *spl_atoms__=nullptr)
Constructor for a subset of atoms.
Function in spherical harmonics or spherical coordinates representation.
Representation of a unit cell.
Atom const & atom(int id__) const
Return const atom instance by id.
int num_atoms() const
Number of atoms in the unit cell.
value_type local_size(block_id block_id__) const
Return local size of the split index for a given block.
splindex< Index_t >::location_t location(typename Index_t::global idx__) const
Return "local index, rank" pair for a global index.
auto size() const noexcept
Return total length of the index (global number of elements).
int lmmax(int lmax)
Maximum number of combinations for a given .
int lm(int l, int m)
Get composite lm index by angular index l and azimuthal index m.
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > inner(::spla::Context &spla_ctx__, sddk::memory_t mem__, spin_range spins__, W const &wf_i__, band_range br_i__, Wave_functions< T > const &wf_j__, band_range br_j__, la::dmatrix< F > &result__, int irow0__, int jcol0__)
Compute inner product between the two sets of wave-functions.
Namespace of the SIRIUS library.
strong_type< int, struct __block_id_tag > block_id
ID of the block.
strong_type< int, struct __n_blocks_tag > n_blocks
Number of blocks to which the global index is split.
Contains definition and partial implementation of sirius::Unit_cell class.