SIRIUS 7.5.0
Electronic structure library and applications
generate_fv_states.cpp
1// Copyright (c) 2013-2016 Anton Kozhevnikov, Thomas Schulthess
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted provided that
5// the following conditions are met:
6//
7// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
8// following disclaimer.
9// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
10// and the following disclaimer in the documentation and/or other materials provided with the distribution.
11//
12// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
13// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
15// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
16// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
17// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
18// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19
/** \file generate_fv_states.cpp
 *
 *  \brief Contains implementation of sirius::K_point::generate_fv_states method.
 */
24
25#include "k_point.hpp"
26#include "lapw/generate_alm_block.hpp"
27
28namespace sirius {
29
30template <typename T>
32{
33 PROFILE("sirius::K_point::generate_fv_states");
34
35 if (!ctx_.full_potential()) {
36 return;
37 }
38
39 auto const& uc = ctx_.unit_cell();
40
41 auto pcs = env::print_checksum();
42
43 auto bs = ctx_.cyclic_block_size();
44 la::dmatrix<std::complex<T>> alm_fv(uc.mt_aw_basis_size(), ctx_.num_fv_states(),
45 ctx_.blacs_grid(), bs, bs);
46
47 int atom_begin{0};
48 int mt_aw_offset{0};
49
50 /* loop over blocks of atoms */
51 for (auto na : split_in_blocks(uc.num_atoms(), 64)) {
52 /* actual number of AW radial functions in a block of atoms */
53 int num_mt_aw{0};
54 for (int i = 0; i < na; i++) {
55 int ia = atom_begin + i;
56 auto& type = uc.atom(ia).type();
57 num_mt_aw += type.mt_aw_basis_size();
58 }
59
60 /* generate complex conjugated Alm coefficients for a block of atoms */
61 auto alm = generate_alm_block<false, T>(ctx_, atom_begin, na, this->alm_coeffs_loc());
62 auto cs = alm.checksum();
63 if (pcs) {
64 print_checksum("alm", cs, RTE_OUT(this->out(0)));
65 }
66
67 /* compute F(lm, i) = A(lm, G)^{T} * evec(G, i) for the block of atoms */
68 spla::pgemm_ssb(num_mt_aw, ctx_.num_fv_states(), this->gkvec().count(), SPLA_OP_TRANSPOSE, 1.0,
69 alm.at(sddk::memory_t::host), alm.ld(),
70 &fv_eigen_vectors_slab().pw_coeffs(0, wf::spin_index(0), wf::band_index(0)),
71 fv_eigen_vectors_slab().ld(),
72 0.0, alm_fv.at(sddk::memory_t::host), alm_fv.ld(), mt_aw_offset, 0, alm_fv.spla_distribution(),
73 ctx_.spla_context());
74
75 atom_begin += na;
76 mt_aw_offset += num_mt_aw;
77 }
78
79 std::vector<int> num_mt_apw_coeffs(uc.num_atoms());
80 for (int ia = 0; ia < uc.num_atoms(); ia++) {
81 num_mt_apw_coeffs[ia] = uc.atom(ia).mt_aw_basis_size();
82 }
83 wf::Wave_functions_mt<T> alm_fv_slab(this->comm(), num_mt_apw_coeffs, wf::num_mag_dims(0),
84 wf::num_bands(ctx_.num_fv_states()), sddk::memory_t::host);
85
86 auto& one = la::constant<std::complex<T>>::one();
88
89 auto layout_in = alm_fv.grid_layout(0, 0, uc.mt_aw_basis_size(), ctx_.num_fv_states());
90 auto layout_out = alm_fv_slab.grid_layout_mt(wf::spin_index(0), wf::band_range(0, ctx_.num_fv_states()));
91 costa::transform(layout_in, layout_out, 'N', one, zero, this->comm().native());
92
93 #pragma omp parallel for
94 for (int i = 0; i < ctx_.num_fv_states(); i++) {
95 /* G+k block */
96 auto in_ptr = &fv_eigen_vectors_slab().pw_coeffs(0, wf::spin_index(0), wf::band_index(i));
97 auto out_ptr = &fv_states_->pw_coeffs(0, wf::spin_index(0), wf::band_index(i));
98 std::copy(in_ptr, in_ptr + gkvec().count(), out_ptr);
99
100 for (auto it : alm_fv_slab.spl_num_atoms()) {
101 int num_mt_aw = uc.atom(it.i).type().mt_aw_basis_size();
102 /* aw part of the muffin-tin coefficients */
103 for (int xi = 0; xi < num_mt_aw; xi++) {
104 fv_states_->mt_coeffs(xi, it.li, wf::spin_index(0), wf::band_index(i)) =
105 alm_fv_slab.mt_coeffs(xi, it.li, wf::spin_index(0), wf::band_index(i));
106 }
107 /* lo part of muffin-tin coefficients */
108 for (int xi = 0; xi < uc.atom(it.i).type().mt_lo_basis_size(); xi++) {
109 fv_states_->mt_coeffs(num_mt_aw + xi, it.li, wf::spin_index(0), wf::band_index(i)) =
110 fv_eigen_vectors_slab().mt_coeffs(xi, it.li, wf::spin_index(0), wf::band_index(i));
111 }
112 }
113 }
114 if (pcs) {
115 auto z1 = fv_states_->checksum_pw(sddk::memory_t::host, wf::spin_index(0), wf::band_range(0, ctx_.num_fv_states()));
116 auto z2 = fv_states_->checksum_mt(sddk::memory_t::host, wf::spin_index(0), wf::band_range(0, ctx_.num_fv_states()));
117 print_checksum("fv_states_pw", z1, RTE_OUT(this->out(0)));
118 print_checksum("fv_states_mt", z2, RTE_OUT(this->out(0)));
119
120 }
121}
122
124#ifdef SIRIUS_USE_FP32
126#endif
127
128} // namespace sirius
void generate_fv_states()
Generate first-variational states from eigen-vectors.
Distributed matrix.
Definition: dmatrix.hpp:56
uint32_t ld() const
Return leading dimension size.
Definition: memory.hpp:1233
Wave-functions for the muffin-tin part of LAPW.
auto const & spl_num_atoms() const
Return a split index for the number of atoms.
auto & mt_coeffs(int xi__, atom_index_t::local ia__, spin_index ispn__, band_index i__)
Return reference to the coefficient by atomic orbital index, atom, spin and band indices.
auto grid_layout_mt(spin_index ispn__, band_range b__)
Return COSTA layout for the muffin-tin part for a given spin index and band range.
Describe a range of bands.
Contains definition of sirius::K_point class.
void copy(T *target__, T const *source__, size_t n__)
Copy memory inside a device.
Definition: acc.hpp:320
void zero(T *ptr__, size_t n__)
Zero the device memory.
Definition: acc.hpp:397
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > transform(::spla::Context &spla_ctx__, sddk::memory_t mem__, la::dmatrix< F > const &M__, int irow0__, int jcol0__, real_type< F > alpha__, Wave_functions< T > const &wf_in__, spin_index s_in__, band_range br_in__, real_type< F > beta__, Wave_functions< T > &wf_out__, spin_index s_out__, band_range br_out__)
Apply linear transformation to the wave-functions.
Namespace of the SIRIUS library.
Definition: sirius.f90:5
auto split_in_blocks(int length__, int block_size__)
Split the 'length' elements into blocks with the initial block size.
Definition: splindex.hpp:43