48template <
typename FUNC>
62 std::function<
double(
const FUNC&,
const FUNC&)> inner_,
63 std::function<
void(
double, FUNC&)> scal_,
64 std::function<
void(
const FUNC&, FUNC&)> copy_,
65 std::function<
void(
double,
const FUNC&, FUNC&)> axpy_,
66 std::function<
void(
double,
double, FUNC&, FUNC&)> rotate_)
77 : size([](const FUNC&) -> double { return 0; })
78 ,
inner([](
const FUNC&,
const FUNC&) ->
double {
return 0.0; })
79 , scal([](
double, FUNC&) ->
void {})
80 ,
copy([](
const FUNC&, FUNC&) ->
void {})
81 , axpy([](
double,
const FUNC&, FUNC&) ->
void {})
82 , rotate([](
double,
double, FUNC&, FUNC&) ->
void {})
87 std::function<double(
const FUNC&)> size;
90 std::function<double(
const FUNC&,
const FUNC&)>
inner;
93 std::function<void(
double, FUNC&)> scal;
96 std::function<void(
const FUNC&, FUNC&)>
copy;
99 std::function<void(
double,
const FUNC&, FUNC&)> axpy;
102 std::function<void(
double,
double, FUNC&, FUNC&)> rotate;
106namespace mixer_impl {
110template <std::size_t FUNC_REVERSE_INDEX,
bool normalize,
typename... FUNCS>
114 const std::tuple<std::unique_ptr<FUNCS>...>& x,
const std::tuple<std::unique_ptr<FUNCS>...>& y)
117 if (std::get<FUNC_REVERSE_INDEX>(x) && std::get<FUNC_REVERSE_INDEX>(y)) {
119 auto v = std::get<FUNC_REVERSE_INDEX>(function_prop)
120 .inner(*std::get<FUNC_REVERSE_INDEX>(x), *std::get<FUNC_REVERSE_INDEX>(y));
123 auto sx = std::get<FUNC_REVERSE_INDEX>(function_prop).size(*std::get<FUNC_REVERSE_INDEX>(x));
124 auto sy = std::get<FUNC_REVERSE_INDEX>(function_prop).size(*std::get<FUNC_REVERSE_INDEX>(y));
126 throw std::runtime_error(
"[sirius::mixer::InnerProduct] sizes of two functions don't match");
141template <
bool normalize,
typename... FUNCS>
145 const std::tuple<std::unique_ptr<FUNCS>...>& x,
const std::tuple<std::unique_ptr<FUNCS>...>& y)
147 if (std::get<0>(x) && std::get<0>(y)) {
148 auto v = std::get<0>(function_prop).inner(*std::get<0>(x), *std::get<0>(y));
150 auto sx = std::get<0>(function_prop).size(*std::get<0>(x));
151 auto sy = std::get<0>(function_prop).size(*std::get<0>(y));
153 throw std::runtime_error(
"[sirius::mixer::InnerProduct] sizes of two functions don't match");
168template <std::size_t FUNC_REVERSE_INDEX,
typename... FUNCS>
172 std::tuple<std::unique_ptr<FUNCS>...>& x)
174 if (std::get<FUNC_REVERSE_INDEX>(x)) {
175 std::get<FUNC_REVERSE_INDEX>(function_prop).scal(alpha, *std::get<FUNC_REVERSE_INDEX>(x));
181template <
typename... FUNCS>
185 std::tuple<std::unique_ptr<FUNCS>...>& x)
187 if (std::get<0>(x)) {
188 std::get<0>(function_prop).scal(alpha, *std::get<0>(x));
193template <std::size_t FUNC_REVERSE_INDEX,
typename... FUNCS>
197 const std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
199 if (std::get<FUNC_REVERSE_INDEX>(x) && std::get<FUNC_REVERSE_INDEX>(y)) {
200 std::get<FUNC_REVERSE_INDEX>(function_prop)
201 .copy(*std::get<FUNC_REVERSE_INDEX>(x), *std::get<FUNC_REVERSE_INDEX>(y));
207template <
typename... FUNCS>
211 const std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
213 if (std::get<0>(x) && std::get<0>(y)) {
214 std::get<0>(function_prop).copy(*std::get<0>(x), *std::get<0>(y));
219template <std::size_t FUNC_REVERSE_INDEX,
typename... FUNCS>
223 const std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
225 if (std::get<FUNC_REVERSE_INDEX>(x) && std::get<FUNC_REVERSE_INDEX>(y)) {
226 std::get<FUNC_REVERSE_INDEX>(function_prop)
227 .axpy(alpha, *std::get<FUNC_REVERSE_INDEX>(x), *std::get<FUNC_REVERSE_INDEX>(y));
233template <
typename... FUNCS>
237 const std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
239 if (std::get<0>(x) && std::get<0>(y)) {
240 std::get<0>(function_prop).axpy(alpha, *std::get<0>(x), *std::get<0>(y));
245template <std::size_t FUNC_REVERSE_INDEX,
typename... FUNCS>
249 std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
251 if (std::get<FUNC_REVERSE_INDEX>(x) && std::get<FUNC_REVERSE_INDEX>(y)) {
252 std::get<FUNC_REVERSE_INDEX>(function_prop)
253 .rotate(c, s, *std::get<FUNC_REVERSE_INDEX>(x), *std::get<FUNC_REVERSE_INDEX>(y));
259template <
typename... FUNCS>
263 std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
265 if (std::get<0>(x) && std::get<0>(y)) {
266 std::get<0>(function_prop).rotate(c, s, *std::get<0>(x), *std::get<0>(y));
277template <
typename... FUNCS>
281 static_assert(
sizeof...(FUNCS) > 0,
"At least one function type must be provided");
283 static constexpr std::size_t number_of_functions =
sizeof...(FUNCS);
291 , max_history_(max_history)
292 , rmse_history_(max_history)
293 , output_history_(max_history)
294 , residual_history_(max_history)
298 virtual ~Mixer() =
default;
306 template <std::size_t FUNC_INDEX,
typename... ARGS>
308 const FunctionProperties<
typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type>& function_prop,
309 const typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type& init_value, ARGS&&... args)
312 throw std::runtime_error(
"Initializing function_prop after mixing not allowed!");
315 std::get<FUNC_INDEX>(functions_) = function_prop;
321 std::get<FUNC_INDEX>(input_).reset(
322 new typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type(args...));
324 for (std::size_t i = 0; i < max_history_; ++i) {
325 std::get<FUNC_INDEX>(output_history_[i])
326 .reset(
new typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type(args...));
327 std::get<FUNC_INDEX>(residual_history_[i])
328 .reset(
new typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type(args...));
332 std::get<FUNC_INDEX>(functions_).copy(init_value, *std::get<FUNC_INDEX>(output_history_[0]));
333 std::get<FUNC_INDEX>(functions_).copy(init_value, *std::get<FUNC_INDEX>(input_));
339 template <std::
size_t FUNC_INDEX>
340 void set_input(
const typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type& input)
342 if (std::get<FUNC_INDEX>(input_)) {
343 std::get<FUNC_INDEX>(functions_).copy(input, *std::get<FUNC_INDEX>(input_));
345 throw std::runtime_error(
"Mixer function not initialized!");
352 template <std::
size_t FUNC_INDEX>
353 void get_output(
typename std::tuple_element<FUNC_INDEX, std::tuple<FUNCS...>>::type& output)
355 const auto idx = idx_hist(step_);
356 if (!std::get<FUNC_INDEX>(output_history_[idx])) {
357 throw std::runtime_error(
"Mixer function not initialized!");
359 std::get<FUNC_INDEX>(functions_).copy(*std::get<FUNC_INDEX>(output_history_[idx]), output);
366 double mix(
double rms_min__)
368 this->update_residual();
370 double rmse = rmse_history_[idx_hist(step_)];
371 if (rmse < rms_min__) {
384 virtual void mix_impl() = 0;
387 void update_residual()
389 this->copy(input_, residual_history_[idx_hist(step_)]);
390 this->axpy(-1.0, output_history_[idx_hist(step_)], residual_history_[idx_hist(step_)]);
396 const auto idx = idx_hist(step_);
399 double rmse = inner_product<true>(residual_history_[idx], residual_history_[idx]);
401 rmse_history_[idx_hist(step_)] = std::sqrt(rmse);
405 std::size_t idx_hist(std::size_t step)
const
407 return step % max_history_;
410 template <
bool normalize>
411 double inner_product(
const std::tuple<std::unique_ptr<FUNCS>...>& x,
412 const std::tuple<std::unique_ptr<FUNCS>...>& y)
414 return mixer_impl::InnerProduct<
sizeof...(FUNCS) - 1, normalize, FUNCS...>::apply(functions_, x, y);
417 void scale(
double alpha, std::tuple<std::unique_ptr<FUNCS>...>& x)
419 mixer_impl::Scaling<
sizeof...(FUNCS) - 1, FUNCS...>::apply(functions_, alpha, x);
422 void copy(
const std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
424 mixer_impl::Copy<
sizeof...(FUNCS) - 1, FUNCS...>::apply(functions_, x, y);
427 void axpy(
double alpha,
const std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
429 mixer_impl::Axpy<
sizeof...(FUNCS) - 1, FUNCS...>::apply(functions_, alpha, x, y);
432 void rotate(
double c,
double s, std::tuple<std::unique_ptr<FUNCS>...>& x, std::tuple<std::unique_ptr<FUNCS>...>& y)
434 mixer_impl::Rotate<
sizeof...(FUNCS) - 1, FUNCS...>::apply(functions_, c, s, x, y);
441 std::size_t max_history_;
444 std::vector<double> rmse_history_;
447 std::tuple<FunctionProperties<FUNCS>...> functions_;
450 std::tuple<std::unique_ptr<FUNCS>...> input_;
453 std::vector<std::tuple<std::unique_ptr<FUNCS>...>> output_history_;
456 std::vector<std::tuple<std::unique_ptr<FUNCS>...>> residual_history_;
Abstract mixer for variadic number of Function objects, which are described by FunctionProperties.
void initialize_function(const FunctionProperties< typename std::tuple_element< FUNC_INDEX, std::tuple< FUNCS... > >::type > &function_prop, const typename std::tuple_element< FUNC_INDEX, std::tuple< FUNCS... > >::type &init_value, ARGS &&... args)
Mixer(std::size_t max_history)
Construct a mixer. Functions have to be initialized individually.
void set_input(const typename std::tuple_element< FUNC_INDEX, std::tuple< FUNCS... > >::type &input)
Set input for next mixing step.
double mix(double rms_min__)
Mix input and stored history. Returns the root mean square error computed by inner products of residuals.
void get_output(typename std::tuple_element< FUNC_INDEX, std::tuple< FUNCS... > >::type &output)
Access last generated output. Mixing must have been performed at least once.
void copy(T *target__, T const *source__, size_t n__)
Copy memory inside a device.
std::enable_if_t< std::is_same< T, real_type< F > >::value, void > inner(::spla::Context &spla_ctx__, sddk::memory_t mem__, spin_range spins__, W const &wf_i__, band_range br_i__, Wave_functions< T > const &wf_j__, band_range br_j__, la::dmatrix< F > &result__, int irow0__, int jcol0__)
Compute inner product between the two sets of wave-functions.
Namespace of the SIRIUS library.
Describes operations on a function type used for mixing.
FunctionProperties(std::function< double(const FUNC &)> size_, std::function< double(const FUNC &, const FUNC &)> inner_, std::function< void(double, FUNC &)> scal_, std::function< void(const FUNC &, FUNC &)> copy_, std::function< void(double, const FUNC &, FUNC &)> axpy_, std::function< void(double, double, FUNC &, FUNC &)> rotate_)
Compute inner product <x|y> between pairs of functions in tuples and accumulate in the result.