Ginkgo documentation, generated from a branch based on main. Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
mpi.hpp
1// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
6#define GKO_PUBLIC_CORE_BASE_MPI_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13#include <ginkgo/config.hpp>
14#include <ginkgo/core/base/exception.hpp>
15#include <ginkgo/core/base/exception_helpers.hpp>
16#include <ginkgo/core/base/executor.hpp>
17#include <ginkgo/core/base/types.hpp>
18#include <ginkgo/core/base/utils_helper.hpp>
19
20
21#if GINKGO_BUILD_MPI
22
23
24#include <mpi.h>
25
26
27namespace gko {
28namespace experimental {
35namespace mpi {
36
37
41inline constexpr bool is_gpu_aware()
42{
43#if GINKGO_HAVE_GPU_AWARE_MPI
44 return true;
45#else
46 return false;
47#endif
48}
49
50
58int map_rank_to_device_id(MPI_Comm comm, int num_devices);
59
60
61#define GKO_REGISTER_MPI_TYPE(input_type, mpi_type) \
62 template <> \
63 struct type_impl<input_type> { \
64 static MPI_Datatype get_type() { return mpi_type; } \
65 }
66
75template <typename T>
76struct type_impl {};
77
78
79GKO_REGISTER_MPI_TYPE(char, MPI_CHAR);
80GKO_REGISTER_MPI_TYPE(unsigned char, MPI_UNSIGNED_CHAR);
81GKO_REGISTER_MPI_TYPE(unsigned, MPI_UNSIGNED);
82GKO_REGISTER_MPI_TYPE(int, MPI_INT);
83GKO_REGISTER_MPI_TYPE(unsigned short, MPI_UNSIGNED_SHORT);
84GKO_REGISTER_MPI_TYPE(unsigned long, MPI_UNSIGNED_LONG);
85GKO_REGISTER_MPI_TYPE(long, MPI_LONG);
86GKO_REGISTER_MPI_TYPE(long long, MPI_LONG_LONG_INT);
87GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
88GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
89GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
90GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
91GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
92GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
93
94
101class contiguous_type {
102public:
109 contiguous_type(int count, MPI_Datatype old_type) : type_(MPI_DATATYPE_NULL)
110 {
111 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
112 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
113 }
114
118 contiguous_type() : type_(MPI_DATATYPE_NULL) {}
119
123 contiguous_type(const contiguous_type&) = delete;
124
128 contiguous_type& operator=(const contiguous_type&) = delete;
129
135 contiguous_type(contiguous_type&& other) noexcept : type_(MPI_DATATYPE_NULL)
136 {
137 *this = std::move(other);
138 }
139
147 contiguous_type& operator=(contiguous_type&& other) noexcept
148 {
149 if (this != &other) {
150 this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
151 }
152 return *this;
153 }
154
158 ~contiguous_type()
159 {
160 if (type_ != MPI_DATATYPE_NULL) {
161 MPI_Type_free(&type_);
162 }
163 }
164
170 MPI_Datatype get() const { return type_; }
171
172private:
173 MPI_Datatype type_;
174};
175
176
181enum class thread_type {
182 serialized = MPI_THREAD_SERIALIZED,
183 funneled = MPI_THREAD_FUNNELED,
184 single = MPI_THREAD_SINGLE,
185 multiple = MPI_THREAD_MULTIPLE
186};
187
188
198class environment {
199public:
200 static bool is_finalized()
201 {
202 int flag = 0;
203 GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
204 return flag;
205 }
206
207 static bool is_initialized()
208 {
209 int flag = 0;
210 GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
211 return flag;
212 }
213
219 int get_provided_thread_support() const { return provided_thread_support_; }
220
229 environment(int& argc, char**& argv,
230 const thread_type thread_t = thread_type::serialized)
231 {
232 this->required_thread_support_ = static_cast<int>(thread_t);
233 GKO_ASSERT_NO_MPI_ERRORS(
234 MPI_Init_thread(&argc, &argv, this->required_thread_support_,
235 &(this->provided_thread_support_)));
236 }
237
241 ~environment() { MPI_Finalize(); }
242
243 environment(const environment&) = delete;
244 environment(environment&&) = delete;
245 environment& operator=(const environment&) = delete;
246 environment& operator=(environment&&) = delete;
247
248private:
249 int required_thread_support_;
250 int provided_thread_support_;
251};
252
253
254namespace {
255
256
261class comm_deleter {
262public:
263 using pointer = MPI_Comm*;
264 void operator()(pointer comm) const
265 {
266 GKO_ASSERT(*comm != MPI_COMM_NULL);
267 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
268 delete comm;
269 }
270};
271
272
273} // namespace
274
275
279struct status {
283 status() : status_(MPI_Status{}) {}
284
290 MPI_Status* get() { return &this->status_; }
291
302 template <typename T>
303 int get_count(const T* data) const
304 {
305 int count;
306 MPI_Get_count(&status_, type_impl<T>::get_type(), &count);
307 return count;
308 }
309
310private:
311 MPI_Status status_;
312};
313
314
319class request {
320public:
325 request() : req_(MPI_REQUEST_NULL) {}
326
327 request(const request&) = delete;
328
329 request& operator=(const request&) = delete;
330
331 request(request&& o) noexcept { *this = std::move(o); }
332
333 request& operator=(request&& o) noexcept
334 {
335 if (this != &o) {
336 this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
337 }
338 return *this;
339 }
340
341 ~request()
342 {
343 if (req_ != MPI_REQUEST_NULL) {
344 if (MPI_Request_free(&req_) != MPI_SUCCESS) {
345 std::terminate(); // since we can't throw in destructors, we
346 // have to terminate the program
347 }
348 }
349 }
350
356 MPI_Request* get() { return &this->req_; }
357
364 status wait()
365 {
366 status status;
367 GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_, status.get()));
368 return status;
369 }
370
371
372private:
373 MPI_Request req_;
374};
375
376
384inline std::vector<status> wait_all(std::vector<request>& req)
385{
386 std::vector<status> stat;
387 for (std::size_t i = 0; i < req.size(); ++i) {
388 stat.emplace_back(req[i].wait());
389 }
390 return stat;
391}
392
393
408class communicator {
409public:
420 communicator(const MPI_Comm& comm, bool force_host_buffer = false)
421 : comm_(), force_host_buffer_(force_host_buffer)
422 {
423 this->comm_.reset(new MPI_Comm(comm));
424 }
425
434 communicator(const MPI_Comm& comm, int color, int key)
435 {
436 MPI_Comm comm_out;
437 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
438 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
439 }
440
449 communicator(const communicator& comm, int color, int key)
450 {
451 MPI_Comm comm_out;
452 GKO_ASSERT_NO_MPI_ERRORS(
453 MPI_Comm_split(comm.get(), color, key, &comm_out));
454 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
455 }
456
462 const MPI_Comm& get() const { return *(this->comm_.get()); }
463
464 bool force_host_buffer() const { return force_host_buffer_; }
465
471 int size() const { return get_num_ranks(); }
472
478 int rank() const { return get_my_rank(); };
479
485 int node_local_rank() const { return get_node_local_rank(); };
486
492 bool operator==(const communicator& rhs) const
493 {
494 return compare(rhs.get());
495 }
496
502 bool operator!=(const communicator& rhs) const { return !(*this == rhs); }
503
508 void synchronize() const
509 {
510 GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->get()));
511 }
512
526 template <typename SendType>
527 void send(std::shared_ptr<const Executor> exec, const SendType* send_buffer,
528 const int send_count, const int destination_rank,
529 const int send_tag) const
530 {
531 auto guard = exec->get_scoped_device_id_guard();
532 GKO_ASSERT_NO_MPI_ERRORS(
533 MPI_Send(send_buffer, send_count, type_impl<SendType>::get_type(),
534 destination_rank, send_tag, this->get()));
535 }
536
553 template <typename SendType>
554 request i_send(std::shared_ptr<const Executor> exec,
555 const SendType* send_buffer, const int send_count,
556 const int destination_rank, const int send_tag) const
557 {
558 auto guard = exec->get_scoped_device_id_guard();
559 request req;
560 GKO_ASSERT_NO_MPI_ERRORS(
561 MPI_Isend(send_buffer, send_count, type_impl<SendType>::get_type(),
562 destination_rank, send_tag, this->get(), req.get()));
563 return req;
564 }
565
581 template <typename RecvType>
582 status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
583 const int recv_count, const int source_rank,
584 const int recv_tag) const
585 {
586 auto guard = exec->get_scoped_device_id_guard();
587 status st;
588 GKO_ASSERT_NO_MPI_ERRORS(
589 MPI_Recv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
590 source_rank, recv_tag, this->get(), st.get()));
591 return st;
592 }
593
609 template <typename RecvType>
610 request i_recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
611 const int recv_count, const int source_rank,
612 const int recv_tag) const
613 {
614 auto guard = exec->get_scoped_device_id_guard();
615 request req;
616 GKO_ASSERT_NO_MPI_ERRORS(
617 MPI_Irecv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
618 source_rank, recv_tag, this->get(), req.get()));
619 return req;
620 }
621
634 template <typename BroadcastType>
635 void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
636 int count, int root_rank) const
637 {
638 auto guard = exec->get_scoped_device_id_guard();
639 GKO_ASSERT_NO_MPI_ERRORS(MPI_Bcast(buffer, count,
640 type_impl<BroadcastType>::get_type(),
641 root_rank, this->get()));
642 }
643
659 template <typename BroadcastType>
660 request i_broadcast(std::shared_ptr<const Executor> exec,
661 BroadcastType* buffer, int count, int root_rank) const
662 {
663 auto guard = exec->get_scoped_device_id_guard();
664 request req;
665 GKO_ASSERT_NO_MPI_ERRORS(
666 MPI_Ibcast(buffer, count, type_impl<BroadcastType>::get_type(),
667 root_rank, this->get(), req.get()));
668 return req;
669 }
670
685 template <typename ReduceType>
686 void reduce(std::shared_ptr<const Executor> exec,
687 const ReduceType* send_buffer, ReduceType* recv_buffer,
688 int count, MPI_Op operation, int root_rank) const
689 {
690 auto guard = exec->get_scoped_device_id_guard();
691 GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
692 type_impl<ReduceType>::get_type(),
693 operation, root_rank, this->get()));
694 }
695
712 template <typename ReduceType>
713 request i_reduce(std::shared_ptr<const Executor> exec,
714 const ReduceType* send_buffer, ReduceType* recv_buffer,
715 int count, MPI_Op operation, int root_rank) const
716 {
717 auto guard = exec->get_scoped_device_id_guard();
718 request req;
719 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
720 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
721 operation, root_rank, this->get(), req.get()));
722 return req;
723 }
724
738 template <typename ReduceType>
739 void all_reduce(std::shared_ptr<const Executor> exec,
740 ReduceType* recv_buffer, int count, MPI_Op operation) const
741 {
742 auto guard = exec->get_scoped_device_id_guard();
743 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
744 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
745 operation, this->get()));
746 }
747
763 template <typename ReduceType>
764 request i_all_reduce(std::shared_ptr<const Executor> exec,
765 ReduceType* recv_buffer, int count,
766 MPI_Op operation) const
767 {
768 auto guard = exec->get_scoped_device_id_guard();
769 request req;
770 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
771 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
772 operation, this->get(), req.get()));
773 return req;
774 }
775
790 template <typename ReduceType>
791 void all_reduce(std::shared_ptr<const Executor> exec,
792 const ReduceType* send_buffer, ReduceType* recv_buffer,
793 int count, MPI_Op operation) const
794 {
795 auto guard = exec->get_scoped_device_id_guard();
796 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
797 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
798 operation, this->get()));
799 }
800
817 template <typename ReduceType>
818 request i_all_reduce(std::shared_ptr<const Executor> exec,
819 const ReduceType* send_buffer, ReduceType* recv_buffer,
820 int count, MPI_Op operation) const
821 {
822 auto guard = exec->get_scoped_device_id_guard();
823 request req;
824 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
825 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
826 operation, this->get(), req.get()));
827 return req;
828 }
829
846 template <typename SendType, typename RecvType>
847 void gather(std::shared_ptr<const Executor> exec,
848 const SendType* send_buffer, const int send_count,
849 RecvType* recv_buffer, const int recv_count,
850 int root_rank) const
851 {
852 auto guard = exec->get_scoped_device_id_guard();
853 GKO_ASSERT_NO_MPI_ERRORS(
854 MPI_Gather(send_buffer, send_count, type_impl<SendType>::get_type(),
855 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
856 root_rank, this->get()));
857 }
858
878 template <typename SendType, typename RecvType>
879 request i_gather(std::shared_ptr<const Executor> exec,
880 const SendType* send_buffer, const int send_count,
881 RecvType* recv_buffer, const int recv_count,
882 int root_rank) const
883 {
884 auto guard = exec->get_scoped_device_id_guard();
885 request req;
886 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
887 send_buffer, send_count, type_impl<SendType>::get_type(),
888 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
889 this->get(), req.get()));
890 return req;
891 }
892
911 template <typename SendType, typename RecvType>
912 void gather_v(std::shared_ptr<const Executor> exec,
913 const SendType* send_buffer, const int send_count,
914 RecvType* recv_buffer, const int* recv_counts,
915 const int* displacements, int root_rank) const
916 {
917 auto guard = exec->get_scoped_device_id_guard();
918 GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
919 send_buffer, send_count, type_impl<SendType>::get_type(),
920 recv_buffer, recv_counts, displacements,
921 type_impl<RecvType>::get_type(), root_rank, this->get()));
922 }
923
944 template <typename SendType, typename RecvType>
945 request i_gather_v(std::shared_ptr<const Executor> exec,
946 const SendType* send_buffer, const int send_count,
947 RecvType* recv_buffer, const int* recv_counts,
948 const int* displacements, int root_rank) const
949 {
950 auto guard = exec->get_scoped_device_id_guard();
951 request req;
952 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
953 send_buffer, send_count, type_impl<SendType>::get_type(),
954 recv_buffer, recv_counts, displacements,
955 type_impl<RecvType>::get_type(), root_rank, this->get(),
956 req.get()));
957 return req;
958 }
959
975 template <typename SendType, typename RecvType>
976 void all_gather(std::shared_ptr<const Executor> exec,
977 const SendType* send_buffer, const int send_count,
978 RecvType* recv_buffer, const int recv_count) const
979 {
980 auto guard = exec->get_scoped_device_id_guard();
981 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
982 send_buffer, send_count, type_impl<SendType>::get_type(),
983 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
984 this->get()));
985 }
986
1005 template <typename SendType, typename RecvType>
1006 request i_all_gather(std::shared_ptr<const Executor> exec,
1007 const SendType* send_buffer, const int send_count,
1008 RecvType* recv_buffer, const int recv_count) const
1009 {
1010 auto guard = exec->get_scoped_device_id_guard();
1011 request req;
1012 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
1013 send_buffer, send_count, type_impl<SendType>::get_type(),
1014 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1015 this->get(), req.get()));
1016 return req;
1017 }
1018
1034 template <typename SendType, typename RecvType>
1035 void scatter(std::shared_ptr<const Executor> exec,
1036 const SendType* send_buffer, const int send_count,
1037 RecvType* recv_buffer, const int recv_count,
1038 int root_rank) const
1039 {
1040 auto guard = exec->get_scoped_device_id_guard();
1041 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
1042 send_buffer, send_count, type_impl<SendType>::get_type(),
1043 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1044 this->get()));
1045 }
1046
1065 template <typename SendType, typename RecvType>
1066 request i_scatter(std::shared_ptr<const Executor> exec,
1067 const SendType* send_buffer, const int send_count,
1068 RecvType* recv_buffer, const int recv_count,
1069 int root_rank) const
1070 {
1071 auto guard = exec->get_scoped_device_id_guard();
1072 request req;
1073 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
1074 send_buffer, send_count, type_impl<SendType>::get_type(),
1075 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1076 this->get(), req.get()));
1077 return req;
1078 }
1079
1098 template <typename SendType, typename RecvType>
1099 void scatter_v(std::shared_ptr<const Executor> exec,
1100 const SendType* send_buffer, const int* send_counts,
1101 const int* displacements, RecvType* recv_buffer,
1102 const int recv_count, int root_rank) const
1103 {
1104 auto guard = exec->get_scoped_device_id_guard();
1105 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
1106 send_buffer, send_counts, displacements,
1107 type_impl<SendType>::get_type(), recv_buffer, recv_count,
1108 type_impl<RecvType>::get_type(), root_rank, this->get()));
1109 }
1110
1131 template <typename SendType, typename RecvType>
1132 request i_scatter_v(std::shared_ptr<const Executor> exec,
1133 const SendType* send_buffer, const int* send_counts,
1134 const int* displacements, RecvType* recv_buffer,
1135 const int recv_count, int root_rank) const
1136 {
1137 auto guard = exec->get_scoped_device_id_guard();
1138 request req;
1139 GKO_ASSERT_NO_MPI_ERRORS(
1140 MPI_Iscatterv(send_buffer, send_counts, displacements,
1141 type_impl<SendType>::get_type(), recv_buffer,
1142 recv_count, type_impl<RecvType>::get_type(),
1143 root_rank, this->get(), req.get()));
1144 return req;
1145 }
1146
1163 template <typename RecvType>
1164 void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
1165 const int recv_count) const
1166 {
1167 auto guard = exec->get_scoped_device_id_guard();
1168 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1169 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1170 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1171 this->get()));
1172 }
1173
1192 template <typename RecvType>
1193 request i_all_to_all(std::shared_ptr<const Executor> exec,
1194 RecvType* recv_buffer, const int recv_count) const
1195 {
1196 auto guard = exec->get_scoped_device_id_guard();
1197 request req;
1198 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1199 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1200 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1201 this->get(), req.get()));
1202 return req;
1203 }
1204
1221 template <typename SendType, typename RecvType>
1222 void all_to_all(std::shared_ptr<const Executor> exec,
1223 const SendType* send_buffer, const int send_count,
1224 RecvType* recv_buffer, const int recv_count) const
1225 {
1226 auto guard = exec->get_scoped_device_id_guard();
1227 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1228 send_buffer, send_count, type_impl<SendType>::get_type(),
1229 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1230 this->get()));
1231 }
1232
1251 template <typename SendType, typename RecvType>
1252 request i_all_to_all(std::shared_ptr<const Executor> exec,
1253 const SendType* send_buffer, const int send_count,
1254 RecvType* recv_buffer, const int recv_count) const
1255 {
1256 auto guard = exec->get_scoped_device_id_guard();
1257 request req;
1258 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1259 send_buffer, send_count, type_impl<SendType>::get_type(),
1260 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1261 this->get(), req.get()));
1262 return req;
1263 }
1264
1284 template <typename SendType, typename RecvType>
1285 void all_to_all_v(std::shared_ptr<const Executor> exec,
1286 const SendType* send_buffer, const int* send_counts,
1287 const int* send_offsets, RecvType* recv_buffer,
1288 const int* recv_counts, const int* recv_offsets) const
1289 {
1290 this->all_to_all_v(std::move(exec), send_buffer, send_counts,
1291 send_offsets, type_impl<SendType>::get_type(),
1292 recv_buffer, recv_counts, recv_offsets,
1293 type_impl<RecvType>::get_type());
1294 }
1295
1311 void all_to_all_v(std::shared_ptr<const Executor> exec,
1312 const void* send_buffer, const int* send_counts,
1313 const int* send_offsets, MPI_Datatype send_type,
1314 void* recv_buffer, const int* recv_counts,
1315 const int* recv_offsets, MPI_Datatype recv_type) const
1316 {
1317 auto guard = exec->get_scoped_device_id_guard();
1318 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
1319 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1320 recv_counts, recv_offsets, recv_type, this->get()));
1321 }
1322
1342 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1343 const void* send_buffer, const int* send_counts,
1344 const int* send_offsets, MPI_Datatype send_type,
1345 void* recv_buffer, const int* recv_counts,
1346 const int* recv_offsets,
1347 MPI_Datatype recv_type) const
1348 {
1349 auto guard = exec->get_scoped_device_id_guard();
1350 request req;
1351 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
1352 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1353 recv_counts, recv_offsets, recv_type, this->get(), req.get()));
1354 return req;
1355 }
1356
1377 template <typename SendType, typename RecvType>
1378 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1379 const SendType* send_buffer, const int* send_counts,
1380 const int* send_offsets, RecvType* recv_buffer,
1381 const int* recv_counts,
1382 const int* recv_offsets) const
1383 {
1384 return this->i_all_to_all_v(
1385 std::move(exec), send_buffer, send_counts, send_offsets,
1386 type_impl<SendType>::get_type(), recv_buffer, recv_counts,
1387 recv_offsets, type_impl<RecvType>::get_type());
1388 }
1389
1404 template <typename ScanType>
1405 void scan(std::shared_ptr<const Executor> exec, const ScanType* send_buffer,
1406 ScanType* recv_buffer, int count, MPI_Op operation) const
1407 {
1408 auto guard = exec->get_scoped_device_id_guard();
1409 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
1410 type_impl<ScanType>::get_type(),
1411 operation, this->get()));
1412 }
1413
1430 template <typename ScanType>
1431 request i_scan(std::shared_ptr<const Executor> exec,
1432 const ScanType* send_buffer, ScanType* recv_buffer,
1433 int count, MPI_Op operation) const
1434 {
1435 auto guard = exec->get_scoped_device_id_guard();
1436 request req;
1437 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
1438 type_impl<ScanType>::get_type(),
1439 operation, this->get(), req.get()));
1440 return req;
1441 }
1442
1443private:
1444 std::shared_ptr<MPI_Comm> comm_;
1445 bool force_host_buffer_;
1446
1447 int get_my_rank() const
1448 {
1449 int my_rank = 0;
1450 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(get(), &my_rank));
1451 return my_rank;
1452 }
1453
1454 int get_node_local_rank() const
1455 {
1456 MPI_Comm local_comm;
1457 int rank;
1458 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
1459 this->get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
1460 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &rank));
1461 MPI_Comm_free(&local_comm);
1462 return rank;
1463 }
1464
1465 int get_num_ranks() const
1466 {
1467 int size = 1;
1468 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->get(), &size));
1469 return size;
1470 }
1471
1472 bool compare(const MPI_Comm& other) const
1473 {
1474 int flag;
1475 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), other, &flag));
1476 return flag == MPI_IDENT;
1477 }
1478};
1479
1480
1485bool requires_host_buffer(const std::shared_ptr<const Executor>& exec,
1486 const communicator& comm);
1487
1488
1494inline double get_walltime() { return MPI_Wtime(); }
1495
1496
1505template <typename ValueType>
1506class window {
1507public:
1511 enum class create_type { allocate = 1, create = 2, dynamic_create = 3 };
1512
1516 enum class lock_type { shared = 1, exclusive = 2 };
1517
1521 window() : window_(MPI_WIN_NULL) {}
1522
1523 window(const window& other) = delete;
1524
1525 window& operator=(const window& other) = delete;
1526
1533 window(window&& other) : window_{std::exchange(other.window_, MPI_WIN_NULL)}
1534 {}
1535
1542 window& operator=(window&& other)
1543 {
1544 window_ = std::exchange(other.window_, MPI_WIN_NULL);
1545 }
1546
1559 window(std::shared_ptr<const Executor> exec, ValueType* base, int num_elems,
1560 const communicator& comm, const int disp_unit = sizeof(ValueType),
1561 MPI_Info input_info = MPI_INFO_NULL,
1562 create_type c_type = create_type::create)
1563 {
1564 auto guard = exec->get_scoped_device_id_guard();
1565 unsigned size = num_elems * sizeof(ValueType);
1566 if (c_type == create_type::create) {
1567 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
1568 base, size, disp_unit, input_info, comm.get(), &this->window_));
1569 } else if (c_type == create_type::dynamic_create) {
1570 GKO_ASSERT_NO_MPI_ERRORS(
1571 MPI_Win_create_dynamic(input_info, comm.get(), &this->window_));
1572 } else if (c_type == create_type::allocate) {
1573 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
1574 size, disp_unit, input_info, comm.get(), base, &this->window_));
1575 } else {
1576 GKO_NOT_IMPLEMENTED;
1577 }
1578 }
1579
1585 MPI_Win get_window() const { return this->window_; }
1586
1593 void fence(int assert = 0) const
1594 {
1595 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
1596 }
1597
1606 void lock(int rank, lock_type lock_t = lock_type::shared,
1607 int assert = 0) const
1608 {
1609 if (lock_t == lock_type::shared) {
1610 GKO_ASSERT_NO_MPI_ERRORS(
1611 MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
1612 } else if (lock_t == lock_type::exclusive) {
1613 GKO_ASSERT_NO_MPI_ERRORS(
1614 MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
1615 } else {
1616 GKO_NOT_IMPLEMENTED;
1617 }
1618 }
1619
1626 void unlock(int rank) const
1627 {
1628 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
1629 }
1630
1637 void lock_all(int assert = 0) const
1638 {
1639 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
1640 }
1641
1646 void unlock_all() const
1647 {
1648 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
1649 }
1650
1657 void flush(int rank) const
1658 {
1659 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
1660 }
1661
1668 void flush_local(int rank) const
1669 {
1670 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
1671 }
1672
1677 void flush_all() const
1678 {
1679 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
1680 }
1681
1686 void flush_all_local() const
1687 {
1688 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
1689 }
1690
1694 void sync() const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }
1695
1699 ~window()
1700 {
1701 if (this->window_ && this->window_ != MPI_WIN_NULL) {
1702 MPI_Win_free(&this->window_);
1703 }
1704 }
1705
1716 template <typename PutType>
1717 void put(std::shared_ptr<const Executor> exec, const PutType* origin_buffer,
1718 const int origin_count, const int target_rank,
1719 const unsigned int target_disp, const int target_count) const
1720 {
1721 auto guard = exec->get_scoped_device_id_guard();
1722 GKO_ASSERT_NO_MPI_ERRORS(
1723 MPI_Put(origin_buffer, origin_count, type_impl<PutType>::get_type(),
1724 target_rank, target_disp, target_count,
1725 type_impl<PutType>::get_type(), this->get_window()));
1726 }
1727
1740 template <typename PutType>
1741 request r_put(std::shared_ptr<const Executor> exec,
1742 const PutType* origin_buffer, const int origin_count,
1743 const int target_rank, const unsigned int target_disp,
1744 const int target_count) const
1745 {
1746 auto guard = exec->get_scoped_device_id_guard();
1747 request req;
1748 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
1749 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1750 target_rank, target_disp, target_count,
1751 type_impl<PutType>::get_type(), this->get_window(), req.get()));
1752 return req;
1753 }
1754
1766 template <typename PutType>
1767 void accumulate(std::shared_ptr<const Executor> exec,
1768 const PutType* origin_buffer, const int origin_count,
1769 const int target_rank, const unsigned int target_disp,
1770 const int target_count, MPI_Op operation) const
1771 {
1772 auto guard = exec->get_scoped_device_id_guard();
1773 GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
1774 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1775 target_rank, target_disp, target_count,
1776 type_impl<PutType>::get_type(), operation, this->get_window()));
1777 }
1778
1792 template <typename PutType>
1793 request r_accumulate(std::shared_ptr<const Executor> exec,
1794 const PutType* origin_buffer, const int origin_count,
1795 const int target_rank, const unsigned int target_disp,
1796 const int target_count, MPI_Op operation) const
1797 {
1798 auto guard = exec->get_scoped_device_id_guard();
1799 request req;
1800 GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
1801 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1802 target_rank, target_disp, target_count,
1803 type_impl<PutType>::get_type(), operation, this->get_window(),
1804 req.get()));
1805 return req;
1806 }
1807
1818 template <typename GetType>
1819 void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1820 const int origin_count, const int target_rank,
1821 const unsigned int target_disp, const int target_count) const
1822 {
1823 auto guard = exec->get_scoped_device_id_guard();
1824 GKO_ASSERT_NO_MPI_ERRORS(
1825 MPI_Get(origin_buffer, origin_count, type_impl<GetType>::get_type(),
1826 target_rank, target_disp, target_count,
1827 type_impl<GetType>::get_type(), this->get_window()));
1828 }
1829
1842 template <typename GetType>
1843 request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1844 const int origin_count, const int target_rank,
1845 const unsigned int target_disp, const int target_count) const
1846 {
1847 auto guard = exec->get_scoped_device_id_guard();
1848 request req;
1849 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
1850 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1851 target_rank, target_disp, target_count,
1852 type_impl<GetType>::get_type(), this->get_window(), req.get()));
1853 return req;
1854 }
1855
1869 template <typename GetType>
1870 void get_accumulate(std::shared_ptr<const Executor> exec,
1871 GetType* origin_buffer, const int origin_count,
1872 GetType* result_buffer, const int result_count,
1873 const int target_rank, const unsigned int target_disp,
1874 const int target_count, MPI_Op operation) const
1875 {
1876 auto guard = exec->get_scoped_device_id_guard();
1877 GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
1878 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1879 result_buffer, result_count, type_impl<GetType>::get_type(),
1880 target_rank, target_disp, target_count,
1881 type_impl<GetType>::get_type(), operation, this->get_window()));
1882 }
1883
1899 template <typename GetType>
1900 request r_get_accumulate(std::shared_ptr<const Executor> exec,
1901 GetType* origin_buffer, const int origin_count,
1902 GetType* result_buffer, const int result_count,
1903 const int target_rank,
1904 const unsigned int target_disp,
1905 const int target_count, MPI_Op operation) const
1906 {
1907 auto guard = exec->get_scoped_device_id_guard();
1908 request req;
1909 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
1910 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1911 result_buffer, result_count, type_impl<GetType>::get_type(),
1912 target_rank, target_disp, target_count,
1913 type_impl<GetType>::get_type(), operation, this->get_window(),
1914 req.get()));
1915 return req;
1916 }
1917
1928 template <typename GetType>
1929 void fetch_and_op(std::shared_ptr<const Executor> exec,
1930 GetType* origin_buffer, GetType* result_buffer,
1931 const int target_rank, const unsigned int target_disp,
1932 MPI_Op operation) const
1933 {
1934 auto guard = exec->get_scoped_device_id_guard();
1935 GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
1936 origin_buffer, result_buffer, type_impl<GetType>::get_type(),
1937 target_rank, target_disp, operation, this->get_window()));
1938 }
1939
1940private:
1941 MPI_Win window_;
1942};
1943
1944
1945} // namespace mpi
1946} // namespace experimental
1947} // namespace gko
1948
1949
1950#endif // GKO_HAVE_MPI
1951
1952
1953#endif // GKO_PUBLIC_CORE_BASE_MPI_HPP_
A thin wrapper of MPI_Comm that supports most MPI calls.
Definition mpi.hpp:408
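A minimal usage sketch, assuming MPI has already been initialized (for example via the environment class documented further below): wrap an existing MPI_Comm non-owningly and query basic information.

// Hedged sketch: non-owning wrapper around a pre-existing communicator.
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
const int my_rank = comm.rank();    // rank of the calling process
const int num_ranks = comm.size();  // number of ranks in the communicator
comm.synchronize();                 // MPI_Barrier on the wrapped communicator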
status recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive data from source rank.
Definition mpi.hpp:582
void scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1099
request i_broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
(Non-blocking) Broadcast data from calling process to all ranks in the communicator
Definition mpi.hpp:660
void gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:847
request i_recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive (Non-blocking, Immediate return) data from source rank.
Definition mpi.hpp:610
request i_scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1132
void all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Communicate data from all ranks to all other ranks (MPI_Alltoall).
Definition mpi.hpp:1222
request i_all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Communicate data from all ranks to all other ranks (MPI_Ialltoall).
Definition mpi.hpp:1252
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1342
bool operator!=(const communicator &rhs) const
Compare two communicator objects for non-equality.
Definition mpi.hpp:502
void scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1035
void synchronize() const
This function is used to synchronize the ranks in the communicator.
Definition mpi.hpp:508
int rank() const
Return the rank of the calling process in the communicator.
Definition mpi.hpp:478
request i_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
(Non-blocking) Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:713
int size() const
Return the size of the communicator (number of ranks).
Definition mpi.hpp:471
void send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Blocking) data from calling process to destination rank.
Definition mpi.hpp:527
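For illustration, a hedged sketch of a blocking point-to-point exchange between rank 0 and rank 1; the host executor, the <vector> include, and the communicator comm are assumed to exist.

auto exec = gko::ReferenceExecutor::create();  // host executor associated with the buffers
std::vector<double> buf(10, comm.rank() == 0 ? 1.0 : 0.0);
if (comm.rank() == 0) {
    comm.send(exec, buf.data(), 10, /*destination_rank=*/1, /*send_tag=*/0);
} else if (comm.rank() == 1) {
    auto st = comm.recv(exec, buf.data(), 10, /*source_rank=*/0, /*recv_tag=*/0);
}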
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1378
request i_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:879
void all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place) Communicate data from all ranks to all other ranks in place (MPI_Alltoall).
Definition mpi.hpp:1164
void all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1285
request i_all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place, non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:764
request i_all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place, Non-blocking) Communicate data from all ranks to all other ranks in place (MPI_Ialltoall).
Definition mpi.hpp:1193
void all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1311
int node_local_rank() const
Return the node local rank of the calling process in the communicator.
Definition mpi.hpp:485
void broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Broadcast data from calling process to all ranks in the communicator.
Definition mpi.hpp:635
const MPI_Comm & get() const
Return the underlying MPI_Comm object.
Definition mpi.hpp:462
communicator(const MPI_Comm &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:434
void all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:739
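A hedged sketch of the in-place variant, assuming comm and a host executor exec as in the earlier snippets: the local value is overwritten with the reduction result on every rank.

double local_value = static_cast<double>(comm.rank());  // any locally computed value
comm.all_reduce(exec, &local_value, 1, MPI_SUM);
// local_value now holds the sum over all ranks, on every rank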
void all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:976
request i_all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:1006
bool operator==(const communicator &rhs) const
Compare two communicator objects for equality.
Definition mpi.hpp:492
void all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:791
request i_gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:945
request i_all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:818
communicator(const MPI_Comm &comm, bool force_host_buffer=false)
Non-owning constructor for an existing communicator of type MPI_Comm.
Definition mpi.hpp:420
request i_scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Performs a scan operation with the given operator.
Definition mpi.hpp:1431
void reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:686
request i_scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1066
void scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Performs a scan operation with the given operator.
Definition mpi.hpp:1405
void gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:912
request i_send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Non-blocking, Immediate return) data from calling process to destination rank.
Definition mpi.hpp:554
communicator(const communicator &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:449
A move-only wrapper for a contiguous MPI_Datatype.
Definition mpi.hpp:101
MPI_Datatype get() const
Access the underlying MPI_Datatype.
Definition mpi.hpp:170
contiguous_type(int count, MPI_Datatype old_type)
Constructs a wrapper for a contiguous MPI_Datatype.
Definition mpi.hpp:109
contiguous_type()
Constructs empty wrapper with MPI_DATATYPE_NULL.
Definition mpi.hpp:118
contiguous_type(const contiguous_type &)=delete
Disallow copying of wrapper type.
contiguous_type(contiguous_type &&other) noexcept
Move constructor, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:135
contiguous_type & operator=(contiguous_type &&other) noexcept
Move assignment, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:147
contiguous_type & operator=(const contiguous_type &)=delete
Disallow copying of wrapper type.
~contiguous_type()
Destructs object by freeing wrapped MPI_Datatype.
Definition mpi.hpp:158
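A minimal sketch of the wrapper: build a committed contiguous datatype of three doubles; the underlying MPI_Datatype is freed automatically when the wrapper is destroyed.

gko::experimental::mpi::contiguous_type vec3(3, MPI_DOUBLE);
MPI_Datatype t = vec3.get();  // usable wherever an MPI_Datatype is expected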
Class that sets up and finalizes the MPI environment.
Definition mpi.hpp:198
~environment()
Call MPI_Finalize at the end of the scope of this class.
Definition mpi.hpp:241
int get_provided_thread_support() const
Return the provided thread support.
Definition mpi.hpp:219
environment(int &argc, char **&argv, const thread_type thread_t=thread_type::serialized)
Call MPI_Init_thread and initialize the MPI environment.
Definition mpi.hpp:229
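Putting the pieces together, a hedged minimal main() that initializes MPI through the RAII environment and wraps MPI_COMM_WORLD:

#include <ginkgo/core/base/mpi.hpp>

int main(int argc, char* argv[])
{
    // RAII setup: MPI_Init_thread is called here, MPI_Finalize when env leaves scope.
    gko::experimental::mpi::environment env(argc, argv);
    gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
    // ... distributed work using comm ...
    return 0;
}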
The request class is a light, move-only wrapper around the MPI_Request handle.
Definition mpi.hpp:319
request()
The default constructor.
Definition mpi.hpp:325
MPI_Request * get()
Get a pointer to the underlying MPI_Request handle.
Definition mpi.hpp:356
status wait()
Allows a rank to wait on a particular request handle.
Definition mpi.hpp:364
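A hedged sketch of the non-blocking pattern, assuming comm, exec, and <vector> as in the earlier snippets: post a receive, overlap local work, then wait on the returned request.

std::vector<double> recv_buf(100);
auto req = comm.i_recv(exec, recv_buf.data(), 100, /*source_rank=*/0, /*recv_tag=*/0);
// ... unrelated local work can overlap with the transfer here ...
auto st = req.wait();  // blocks until the matching send has been received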
This class wraps the MPI_Window class with RAII functionality.
Definition mpi.hpp:1506
void get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data from the target window.
Definition mpi.hpp:1819
request r_put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1741
window()
The default constructor.
Definition mpi.hpp:1521
void get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Get Accumulate data from the target window.
Definition mpi.hpp:1870
void put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1717
~window()
The deleter which calls MPI_Win_free when the window leaves its scope.
Definition mpi.hpp:1699
lock_type
The lock type for passive target synchronization of the windows.
Definition mpi.hpp:1516
window & operator=(window &&other)
The move assignment operator.
Definition mpi.hpp:1542
request r_accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Accumulate data into the target window.
Definition mpi.hpp:1793
request r_get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Get Accumulate data (with handle) from the target window.
Definition mpi.hpp:1900
void fetch_and_op(std::shared_ptr< const Executor > exec, GetType *origin_buffer, GetType *result_buffer, const int target_rank, const unsigned int target_disp, MPI_Op operation) const
Fetch and operate on data from the target window (An optimized version of Get_accumulate).
Definition mpi.hpp:1929
void sync() const
Synchronize the public and private buffers for the window object.
Definition mpi.hpp:1694
void unlock(int rank) const
Close the epoch using MPI_Win_unlock for the window object.
Definition mpi.hpp:1626
void fence(int assert=0) const
The active target synchronization using MPI_Win_fence for the window object.
Definition mpi.hpp:1593
void flush(int rank) const
Flush the existing RDMA operations on the target rank for the calling process for the window object.
Definition mpi.hpp:1657
void unlock_all() const
Close the epoch on all ranks using MPI_Win_unlock_all for the window object.
Definition mpi.hpp:1646
create_type
The create type for the window object.
Definition mpi.hpp:1511
window(std::shared_ptr< const Executor > exec, ValueType *base, int num_elems, const communicator &comm, const int disp_unit=sizeof(ValueType), MPI_Info input_info=MPI_INFO_NULL, create_type c_type=create_type::create)
Create a window object with a given data pointer and type.
Definition mpi.hpp:1559
void accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Accumulate data into the target window.
Definition mpi.hpp:1767
void lock_all(int assert=0) const
Create the epoch on all ranks using MPI_Win_lock_all for the window object.
Definition mpi.hpp:1637
void lock(int rank, lock_type lock_t=lock_type::shared, int assert=0) const
Create an epoch using MPI_Win_lock for the window object.
Definition mpi.hpp:1606
void flush_all_local() const
Flush all the local existing RDMA operations on the calling rank for the window object.
Definition mpi.hpp:1686
window(window &&other)
The move constructor.
Definition mpi.hpp:1533
void flush_local(int rank) const
Flush the existing RDMA operations on the calling rank from the target rank for the window object.
Definition mpi.hpp:1668
MPI_Win get_window() const
Get the underlying window object of MPI_Win type.
Definition mpi.hpp:1585
request r_get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data (with handle) from the target window.
Definition mpi.hpp:1843
void flush_all() const
Flush all the existing RDMA operations for the calling process for the window object.
Definition mpi.hpp:1677
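A hedged sketch of one-sided access with the window wrapper, assuming comm, exec, and at least two ranks: every rank exposes a local buffer, and rank 0 puts data into rank 1's window inside a passive-target epoch.

std::vector<double> base(100, 0.0);
gko::experimental::mpi::window<double> win(exec, base.data(), 100, comm);
win.lock_all();
if (comm.rank() == 0) {
    std::vector<double> origin(100, 1.0);
    // write 100 doubles into rank 1's window starting at displacement 0
    win.put(exec, origin.data(), 100, /*target_rank=*/1, /*target_disp=*/0,
            /*target_count=*/100);
    win.flush(1);
}
win.unlock_all();
comm.synchronize();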
int map_rank_to_device_id(MPI_Comm comm, int num_devices)
Maps each MPI rank to a single device id in a round robin manner.
bool requires_host_buffer(const std::shared_ptr< const Executor > &exec, const communicator &comm)
Checks if the combination of Executor and communicator requires passing MPI buffers from host memory.
double get_walltime()
Return the current wall time via MPI_Wtime.
Definition mpi.hpp:1494
constexpr bool is_gpu_aware()
Return whether GPU-aware functionality is available.
Definition mpi.hpp:41
thread_type
This enum specifies the threading type to be used when creating an MPI environment.
Definition mpi.hpp:181
std::vector< status > wait_all(std::vector< request > &req)
Allows a rank to wait on multiple request handles.
Definition mpi.hpp:384
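A hedged sketch, assuming comm, exec, and buffers send_buf/recv_buf of n doubles each: collect the move-only request handles in a vector and wait on all of them.

std::vector<gko::experimental::mpi::request> reqs;
reqs.push_back(comm.i_send(exec, send_buf.data(), n, /*destination_rank=*/1, /*send_tag=*/0));
reqs.push_back(comm.i_recv(exec, recv_buf.data(), n, /*source_rank=*/1, /*recv_tag=*/0));
auto statuses = gko::experimental::mpi::wait_all(reqs);  // waits on every request in order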
The Ginkgo namespace.
Definition abstract_factory.hpp:20
STL namespace.
The status struct is a light wrapper around the MPI_Status struct.
Definition mpi.hpp:279
int get_count(const T *data) const
Get the count of the number of elements received by the communication call.
Definition mpi.hpp:303
status()
The default constructor.
Definition mpi.hpp:283
MPI_Status * get()
Get a pointer to the underlying MPI_Status object.
Definition mpi.hpp:290
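A hedged sketch, assuming comm, exec, and a recv_buf sized for the largest expected message: the status returned by a blocking receive can report how many elements actually arrived.

auto st = comm.recv(exec, recv_buf.data(), static_cast<int>(recv_buf.size()),
                    /*source_rank=*/0, /*recv_tag=*/0);
int received = st.get_count(recv_buf.data());  // element count, deduced from the buffer type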
A struct that is used to determine the MPI_Datatype of a specified type.
Definition mpi.hpp:76
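A hedged sketch of the mapping used throughout the wrappers above; types not registered in this header can be added with GKO_REGISTER_MPI_TYPE.

MPI_Datatype dt = gko::experimental::mpi::type_impl<double>::get_type();  // MPI_DOUBLE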