#ifndef __COMMUNICATOR_HPP__
#define __COMMUNICATOR_HPP__

#include <mpi.h>
#include <algorithm>
#include <cassert>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>
#define CALL_MPI(func__, args__)                                                              \
{                                                                                             \
    if (func__ args__ != MPI_SUCCESS) {                                                       \
        std::printf("error in %s at line %i of file %s\n", #func__, __LINE__, __FILE__);      \
        MPI_Abort(MPI_COMM_WORLD, -1);                                                        \
    }                                                                                         \
}
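/* Usage sketch (illustrative): the macro takes the MPI function name and its parenthesized
   argument list as two separate macro arguments, so a checked barrier call is written as

       CALL_MPI(MPI_Barrier, (MPI_COMM_WORLD));

   which prints a diagnostic and aborts if the call does not return MPI_SUCCESS. */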
/// Type of MPI reduction.
enum class op_t { sum, max, min, land };

template <op_t op__>
struct op_wrapper;

template <>
struct op_wrapper<op_t::sum> {
    operator MPI_Op() const noexcept { return MPI_SUM; }
};

template <>
struct op_wrapper<op_t::max> {
    operator MPI_Op() const noexcept { return MPI_MAX; }
};

template <>
struct op_wrapper<op_t::min> {
    operator MPI_Op() const noexcept { return MPI_MIN; }
};

template <>
struct op_wrapper<op_t::land> {
    operator MPI_Op() const noexcept { return MPI_LAND; }
};
/// Map a C++ type to the corresponding MPI data type.
template <typename T>
struct type_wrapper;

template <>
struct type_wrapper<float> {
    operator MPI_Datatype() const noexcept { return MPI_FLOAT; }
};

template <>
struct type_wrapper<std::complex<float>> {
    operator MPI_Datatype() const noexcept { return MPI_C_FLOAT_COMPLEX; }
};

template <>
struct type_wrapper<double> {
    operator MPI_Datatype() const noexcept { return MPI_DOUBLE; }
};

template <>
struct type_wrapper<std::complex<double>> {
    operator MPI_Datatype() const noexcept { return MPI_C_DOUBLE_COMPLEX; }
};

template <>
struct type_wrapper<long double> {
    operator MPI_Datatype() const noexcept { return MPI_LONG_DOUBLE; }
};

template <>
struct type_wrapper<int> {
    operator MPI_Datatype() const noexcept { return MPI_INT; }
};

template <>
struct type_wrapper<short> {
    operator MPI_Datatype() const noexcept { return MPI_SHORT; }
};

template <>
struct type_wrapper<char> {
    operator MPI_Datatype() const noexcept { return MPI_CHAR; }
};

template <>
struct type_wrapper<unsigned char> {
    operator MPI_Datatype() const noexcept { return MPI_UNSIGNED_CHAR; }
};

template <>
struct type_wrapper<unsigned long long> {
    operator MPI_Datatype() const noexcept { return MPI_UNSIGNED_LONG_LONG; }
};

template <>
struct type_wrapper<unsigned long> {
    operator MPI_Datatype() const noexcept { return MPI_UNSIGNED_LONG; }
};

template <>
struct type_wrapper<bool> {
    operator MPI_Datatype() const noexcept { return MPI_C_BOOL; }
};

template <>
struct type_wrapper<uint32_t> {
    operator MPI_Datatype() const noexcept { return MPI_UINT32_T; }
};
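/* Usage sketch (illustrative): the wrappers convert implicitly to MPI_Datatype and MPI_Op,
   so the MPI type and reduction operation can be selected from template parameters, e.g.

       std::vector<double> x(100, 1.0);
       CALL_MPI(MPI_Allreduce, (MPI_IN_PLACE, x.data(), static_cast<int>(x.size()),
                                type_wrapper<double>(), op_wrapper<op_t::sum>(), MPI_COMM_WORLD));
*/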
/// Descriptor of block-distributed data: per-rank counts and offsets.
struct block_data_descriptor
{
    int num_ranks{-1};
    std::vector<int> counts;
    std::vector<int> offsets;

    block_data_descriptor(int num_ranks__)
        : num_ranks(num_ranks__)
    {
        counts  = std::vector<int>(num_ranks, 0);
        offsets = std::vector<int>(num_ranks, 0);
    }

    /// Compute offsets as the cumulative sum of counts.
    void calc_offsets()
    {
        for (int i = 1; i < num_ranks; i++) {
            offsets[i] = offsets[i - 1] + counts[i - 1];
        }
    }

    /// Total size of the described data (valid after calc_offsets()).
    inline int size() const { return counts.back() + offsets.back(); }
};
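/* Usage sketch (illustrative; "comm" stands for any communicator instance): describe a
   distribution with 10 elements per rank and query the global size.

       block_data_descriptor bd(comm.size());
       std::fill(bd.counts.begin(), bd.counts.end(), 10);
       bd.calc_offsets();          // offsets[i] becomes the sum of counts[0..i-1]
       int ntot = bd.size();       // == 10 * comm.size()
*/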
/// Wrapper around a non-blocking MPI request.
class Request
{
  private:
    MPI_Request handler_;

  public:
    void wait() { CALL_MPI(MPI_Wait, (&handler_, MPI_STATUS_IGNORE)); }

    MPI_Request& handler() { return handler_; }
};
/// Deleter used by the smart pointer that owns an allocated MPI communicator.
struct mpi_comm_deleter
{
    void operator()(MPI_Comm* comm__) const
    {
        int mpi_finalized_flag;
        MPI_Finalized(&mpi_finalized_flag);
        if (!mpi_finalized_flag) {
            CALL_MPI(MPI_Comm_free, (comm__));
        }
        delete comm__;
    }
};
    /// MPI initialization.
    static void initialize(int required__)
    {
        int provided;

        MPI_Init_thread(NULL, NULL, required__, &provided);
        MPI_Query_thread(&provided);
        if ((provided < required__) && (Communicator::world().rank() == 0)) {
            std::printf("Warning! Required level of thread support is not provided.\n");
            std::printf("provided: %d \nrequired: %d\n", provided, required__);
        }
    }
    /// True if MPI has already been finalized.
    static bool is_finalized()
    {
        int mpi_finalized_flag;
        MPI_Finalized(&mpi_finalized_flag);
        return mpi_finalized_flag == true;
    }
    void abort(int errcode__) const
    {
        CALL_MPI(MPI_Abort, (this->native(), errcode__));
    }
    inline Communicator cart_create(int ndims__, int const* dims__, int const* periods__) const
    {
        auto comm_sptr = std::shared_ptr<MPI_Comm>(new MPI_Comm, mpi_comm_deleter());
        CALL_MPI(MPI_Cart_create, (this->native(), ndims__, dims__, periods__, 0, comm_sptr.get()));
        return Communicator(comm_sptr);
    }
    inline Communicator cart_sub(int const* remain_dims__) const
    {
        auto comm_sptr = std::shared_ptr<MPI_Comm>(new MPI_Comm, mpi_comm_deleter());
        CALL_MPI(MPI_Cart_sub, (this->native(), remain_dims__, comm_sptr.get()));
        return Communicator(comm_sptr);
    }
    inline Communicator split(int color__) const
    {
        auto comm_sptr = std::shared_ptr<MPI_Comm>(new MPI_Comm, mpi_comm_deleter());
        CALL_MPI(MPI_Comm_split, (this->native(), color__, rank(), comm_sptr.get()));
        return Communicator(comm_sptr);
    }
    inline Communicator duplicate() const
    {
        auto comm_sptr = std::shared_ptr<MPI_Comm>(new MPI_Comm, mpi_comm_deleter());
        CALL_MPI(MPI_Comm_dup, (this->native(), comm_sptr.get()));
        return Communicator(comm_sptr);
    }
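    /* Usage sketch (illustrative; "nrow" is a hypothetical block size): group each block of
       nrow consecutive ranks into its own sub-communicator by splitting on a common color.

           auto row_comm = comm.split(comm.rank() / nrow);
    */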
    /// Mapping between Fortran and SIRIUS MPI communicators.
    static Communicator const& map_fcomm(int fcomm__)
    {
        static std::map<int, Communicator> fcomm_map;
        if (!fcomm_map.count(fcomm__)) {
            fcomm_map[fcomm__] = Communicator(MPI_Comm_f2c(fcomm__));
        }
        auto& comm = fcomm_map[fcomm__];
        return comm;
    }
    /// Pack a pair of indices into a single MPI tag; the shift leaves the low bits free.
    static int get_tag(int i__, int j__)
    {
        return (j__ * (j__ + 1) / 2 + i__ + 1) << 6;
    }
    static std::string processor_name()
    {
        char name[MPI_MAX_PROCESSOR_NAME];
        int len;
        CALL_MPI(MPI_Get_processor_name, (name, &len));
        return std::string(name, len);
    }
    /// Rank of MPI process inside communicator with associated Cartesian partitioning.
    inline int cart_rank(std::vector<int> const& coords__) const
    {
        if (this->native() == MPI_COMM_SELF) {
            return 0;
        }
        int r;
        CALL_MPI(MPI_Cart_rank, (this->native(), &coords__[0], &r));
        return r;
    }
    inline bool is_null() const
    {
        return (mpi_comm_raw_ == MPI_COMM_NULL);
    }
    inline void barrier() const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Barrier");
#endif
        assert(this->native() != MPI_COMM_NULL);
        CALL_MPI(MPI_Barrier, (this->native()));
    }
    /// Perform the in-place reduction to the root rank.
    template <typename T, op_t mpi_op__ = op_t::sum>
    inline void reduce(T* buffer__, int count__, int root__) const
    {
        if (root__ == rank()) {
            CALL_MPI(MPI_Reduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
                                  op_wrapper<mpi_op__>(), root__, this->native()));
        } else {
            CALL_MPI(MPI_Reduce, (buffer__, NULL, count__, type_wrapper<T>(),
                                  op_wrapper<mpi_op__>(), root__, this->native()));
        }
    }
    /// Perform the in-place reduction to the root rank (non-blocking).
    template <typename T, op_t mpi_op__ = op_t::sum>
    inline void reduce(T* buffer__, int count__, int root__, MPI_Request* req__) const
    {
        if (root__ == rank()) {
            CALL_MPI(MPI_Ireduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
                                   op_wrapper<mpi_op__>(), root__, this->native(), req__));
        } else {
            CALL_MPI(MPI_Ireduce, (buffer__, NULL, count__, type_wrapper<T>(),
                                   op_wrapper<mpi_op__>(), root__, this->native(), req__));
        }
    }
    /// Perform the out-of-place reduction to the root rank.
    template <typename T, op_t mpi_op__ = op_t::sum>
    void reduce(T const* sendbuf__, T* recvbuf__, int count__, int root__) const
    {
        CALL_MPI(MPI_Reduce, (sendbuf__, recvbuf__, count__, type_wrapper<T>(),
                              op_wrapper<mpi_op__>(), root__, this->native()));
    }
    /// Perform the out-of-place reduction to the root rank (non-blocking).
    template <typename T, op_t mpi_op__ = op_t::sum>
    void reduce(T const* sendbuf__, T* recvbuf__, int count__, int root__, MPI_Request* req__) const
    {
        CALL_MPI(MPI_Ireduce, (sendbuf__, recvbuf__, count__, type_wrapper<T>(),
                               op_wrapper<mpi_op__>(), root__, this->native(), req__));
    }
    /// Perform the in-place (the output buffer is used as the input buffer) all-to-all reduction.
    template <typename T, op_t mpi_op__ = op_t::sum>
    inline void allreduce(T* buffer__, int count__) const
    {
        CALL_MPI(MPI_Allreduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
                                 op_wrapper<mpi_op__>(), this->native()));
    }
    /// Perform the in-place all-to-all reduction on the whole vector.
    template <typename T, op_t op__ = op_t::sum>
    inline void allreduce(std::vector<T>& buffer__) const
    {
        allreduce<T, op__>(buffer__.data(), static_cast<int>(buffer__.size()));
    }
    /// Non-blocking in-place all-to-all reduction.
    template <typename T, op_t mpi_op__ = op_t::sum>
    inline void iallreduce(T* buffer__, int count__, MPI_Request* req__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Iallreduce");
#endif
        CALL_MPI(MPI_Iallreduce, (MPI_IN_PLACE, buffer__, count__, type_wrapper<T>(),
                                  op_wrapper<mpi_op__>(), this->native(), req__));
    }
    /// Perform buffer broadcast.
    template <typename T>
    inline void bcast(T* buffer__, int count__, int root__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Bcast");
#endif
        CALL_MPI(MPI_Bcast, (buffer__, count__, type_wrapper<T>(), root__, this->native()));
    }
    /// Broadcast a string from the root rank.
    inline void bcast(std::string& str__, int root__) const
    {
        int sz;
        if (rank() == root__) {
            sz = static_cast<int>(str__.size());
        }
        bcast(&sz, 1, root__);
        char* buf = new char[sz + 1];
        if (rank() == root__) {
            std::copy(str__.c_str(), str__.c_str() + sz + 1, buf);
        }
        bcast(buf, sz + 1, root__);
        str__ = std::string(buf);
        delete[] buf;
    }
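    /* Usage sketch (illustrative): broadcast a string chosen on rank 0 to all other ranks.

           std::string fname;
           if (comm.rank() == 0) {
               fname = "input.json";
           }
           comm.bcast(fname, 0);
    */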
    /// In-place MPI_Allgatherv.
    template <typename T>
    void allgather(T* buffer__, int const* recvcounts__, int const* displs__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Allgatherv");
#endif
        CALL_MPI(MPI_Allgatherv, (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, buffer__, recvcounts__, displs__,
                                  type_wrapper<T>(), this->native()));
    }
    /// Out-of-place MPI_Allgatherv.
    template <typename T>
    void allgather(T const* sendbuf__, int sendcount__, T* recvbuf__, int const* recvcounts__,
                   int const* displs__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Allgatherv");
#endif
        CALL_MPI(MPI_Allgatherv, (sendbuf__, sendcount__, type_wrapper<T>(), recvbuf__, recvcounts__,
                                  displs__, type_wrapper<T>(), this->native()));
    }
    /// Out-of-place MPI_Allgatherv; per-rank counts and displacements are exchanged first.
    template <typename T>
    void allgather(T const* sendbuf__, T* recvbuf__, int count__, int displs__) const
    {
        std::vector<int> v(size() * 2);
        v[2 * rank()]     = count__;
        v[2 * rank() + 1] = displs__;

        CALL_MPI(MPI_Allgather,
                 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>(), this->native()));

        std::vector<int> counts(size());
        std::vector<int> displs(size());
        for (int i = 0; i < size(); i++) {
            counts[i] = v[2 * i];
            displs[i] = v[2 * i + 1];
        }

        CALL_MPI(MPI_Allgatherv, (sendbuf__, count__, type_wrapper<T>(), recvbuf__, counts.data(),
                                  displs.data(), type_wrapper<T>(), this->native()));
    }
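    /* Usage sketch (illustrative; "nloc", "offset" and "ntot" are hypothetical per-rank block
       size, global offset and total size): each rank passes only its local chunk; the counts
       and displacements of all ranks are collected internally before MPI_Allgatherv.

           std::vector<double> local(nloc), global(ntot);
           comm.allgather(local.data(), global.data(), nloc, offset);
    */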
    /// In-place MPI_Allgatherv.
    template <typename T>
    void allgather(T* buffer__, int count__, int displs__) const
    {
        std::vector<int> v(size() * 2);
        v[2 * rank()]     = count__;
        v[2 * rank() + 1] = displs__;

        CALL_MPI(MPI_Allgather,
                 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>(), this->native()));

        std::vector<int> counts(size());
        std::vector<int> displs(size());
        for (int i = 0; i < size(); i++) {
            counts[i] = v[2 * i];
            displs[i] = v[2 * i + 1];
        }

        allgather(buffer__, counts.data(), displs.data());
    }
    template <typename T>
    void send(T const* buffer__, int count__, int dest__, int tag__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Send");
#endif
        CALL_MPI(MPI_Send, (buffer__, count__, type_wrapper<T>(), dest__, tag__, this->native()));
    }
    template <typename T>
    Request isend(T const* buffer__, int count__, int dest__, int tag__) const
    {
        Request req;
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Isend");
#endif
        CALL_MPI(MPI_Isend, (buffer__, count__, type_wrapper<T>(), dest__, tag__, this->native(),
                             &req.handler()));
        return req;
    }
    template <typename T>
    void recv(T* buffer__, int count__, int source__, int tag__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Recv");
#endif
        CALL_MPI(MPI_Recv,
                 (buffer__, count__, type_wrapper<T>(), source__, tag__, this->native(), MPI_STATUS_IGNORE));
    }
    template <typename T>
    Request irecv(T* buffer__, int count__, int source__, int tag__) const
    {
        Request req;
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Irecv");
#endif
        CALL_MPI(MPI_Irecv, (buffer__, count__, type_wrapper<T>(), source__, tag__, this->native(),
                             &req.handler()));
        return req;
    }
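    /* Usage sketch (illustrative): non-blocking ring exchange using the Request objects
       returned by isend()/irecv().

           int n = 16;
           std::vector<int> out(n, comm.rank()), in(n);
           auto sreq = comm.isend(out.data(), n, (comm.rank() + 1) % comm.size(), 0);
           auto rreq = comm.irecv(in.data(), n, (comm.rank() + comm.size() - 1) % comm.size(), 0);
           sreq.wait();
           rreq.wait();
    */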
    /// Gather data on a given rank.
    template <typename T>
    void gather(T const* sendbuf__, T* recvbuf__, int const* recvcounts__, int const* displs__,
                int root__) const
    {
        int sendcount = recvcounts__[rank()];
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Gatherv");
#endif
        CALL_MPI(MPI_Gatherv, (sendbuf__, sendcount, type_wrapper<T>(), recvbuf__, recvcounts__, displs__,
                               type_wrapper<T>(), root__, this->native()));
    }
    /// Gather data on a given rank; per-rank counts and offsets are exchanged first.
    template <typename T>
    void gather(T const* sendbuf__, T* recvbuf__, int offset__, int count__, int root__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Gatherv");
#endif
        std::vector<int> v(size() * 2);
        v[2 * rank()]     = count__;
        v[2 * rank() + 1] = offset__;

        CALL_MPI(MPI_Allgather,
                 (MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, v.data(), 2, type_wrapper<int>(), this->native()));

        std::vector<int> counts(size());
        std::vector<int> offsets(size());
        for (int i = 0; i < size(); i++) {
            counts[i]  = v[2 * i];
            offsets[i] = v[2 * i + 1];
        }

        CALL_MPI(MPI_Gatherv, (sendbuf__, count__, type_wrapper<T>(), recvbuf__, counts.data(),
                               offsets.data(), type_wrapper<T>(), root__, this->native()));
    }
    /// Scatter data from a given rank.
    template <typename T>
    void scatter(T const* sendbuf__, T* recvbuf__, int const* sendcounts__, int const* displs__,
                 int root__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Scatterv");
#endif
        int recvcount = sendcounts__[rank()];
        CALL_MPI(MPI_Scatterv, (sendbuf__, sendcounts__, displs__, type_wrapper<T>(), recvbuf__, recvcount,
                                type_wrapper<T>(), root__, this->native()));
    }
    template <typename T>
    void alltoall(T const* sendbuf__, int sendcounts__, T* recvbuf__, int recvcounts__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Alltoall");
#endif
        CALL_MPI(MPI_Alltoall, (sendbuf__, sendcounts__, type_wrapper<T>(), recvbuf__, recvcounts__,
                                type_wrapper<T>(), this->native()));
    }
    template <typename T>
    void alltoall(T const* sendbuf__, int const* sendcounts__, int const* sdispls__, T* recvbuf__,
                  int const* recvcounts__, int const* rdispls__) const
    {
#if defined(__PROFILE_MPI)
        PROFILE("MPI_Alltoallv");
#endif
        CALL_MPI(MPI_Alltoallv, (sendbuf__, sendcounts__, sdispls__, type_wrapper<T>(), recvbuf__,
                                 recvcounts__, rdispls__, type_wrapper<T>(), this->native()));
    }
Member summary:

Communicator
    MPI communicator wrapper.
Communicator()
    Default constructor.
Communicator(MPI_Comm mpi_comm__)
    Constructor for existing communicator.
Communicator(std::shared_ptr<MPI_Comm> comm__)
    Constructor for new communicator.
MPI_Comm mpi_comm_raw_
    Raw MPI communicator.
std::shared_ptr<MPI_Comm> mpi_comm_
    Smart pointer to allocated MPI communicator.
int rank_
    Store communicator's rank.
int size_
    Store communicator's size.
MPI_Comm native() const
    Return the native raw MPI communicator handle.
int rank() const
    Rank of MPI process inside communicator.
int size() const
    Size of the communicator (number of ranks).
int cart_rank(std::vector<int> const& coords__) const
    Rank of MPI process inside communicator with associated Cartesian partitioning.
static void initialize(int required__)
    MPI initialization.
static void finalize()
    MPI shutdown.
static Communicator const& map_fcomm(int fcomm__)
    Mapping between Fortran and SIRIUS MPI communicators.
void bcast(T* buffer__, int count__, int root__) const
    Perform buffer broadcast.
void allreduce(T* buffer__, int count__) const
    Perform the in-place (the output buffer is used as the input buffer) all-to-all reduction.
void allreduce(std::vector<T>& buffer__) const
    Perform the in-place (the output buffer is used as the input buffer) all-to-all reduction.
void allgather(T* buffer__, int const* recvcounts__, int const* displs__) const
    In-place MPI_Allgatherv.
void allgather(T* buffer__, int count__, int displs__) const
    In-place MPI_Allgatherv.
void allgather(T const* sendbuf__, int sendcount__, T* recvbuf__, int const* recvcounts__, int const* displs__) const
    Out-of-place MPI_Allgatherv.
void gather(T const* sendbuf__, T* recvbuf__, int offset__, int count__, int root__) const
    Gather data on a given rank.
op_t
    Type of MPI reduction.
void copy(T* target__, T const* source__, size_t n__)
    Copy memory inside a device.
int get_device_id(int num_devices__)
    Get GPU device id associated with the current rank.
int num_ranks_per_node()
    Get number of ranks per node.
Namespace of the SIRIUS library.
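A minimal usage sketch (not part of the header) is shown below; it assumes the header is
available as "communicator.hpp" and that the Communicator class lives in the sirius namespace.

#include <cstdio>
#include <vector>
#include "communicator.hpp"

int main()
{
    // request thread support and set up MPI (MPI_THREAD_FUNNELED is a standard MPI constant)
    sirius::Communicator::initialize(MPI_THREAD_FUNNELED);
    {
        auto const& comm = sirius::Communicator::world();

        // each rank contributes a vector of ones; the in-place allreduce sums them element-wise
        std::vector<double> x(4, 1.0);
        comm.allreduce(x.data(), static_cast<int>(x.size()));

        if (comm.rank() == 0) {
            std::printf("x[0] = %f on %i ranks\n", x[0], comm.size());
        }
    }
    sirius::Communicator::finalize();
    return 0;
}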