mpi4py implements a large part of the MPI interface and makes it easy to pass Python data structures between processes and to write multi-process Python programs.

https://mpi4py.readthedocs.io/en/stable/tutorial.html#running-python-scripts-with-mpi

When the objects being communicated are generic Python objects, the lowercase methods are used, such as Comm.send, Comm.recv, and Comm.scatter. The object to send is passed as an argument to the communication call, and the received object is the return value.

When the data being communicated is buffer-like, the methods start with an uppercase letter, e.g. Comm.Send, Comm.Recv, Comm.Bcast, Comm.Scatter, and Comm.Gather. The buffer is passed together with its datatype (and optionally a count) in list form, such as [data, MPI.DOUBLE] or [data, count, MPI.DOUBLE]. For the vector collective operations (Comm.Scatterv, Comm.Gatherv, etc.), the argument must be specified as [data, count, displ, datatype], where count and displ are sequences of integers.

Programs written with mpi4py are usually launched in the form mpiexec -n 4 python script.py.

1 Communication

Some examples of communicating generic Python objects:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Point to Point-----------------------------
if rank == 0:
    data = {'a': 7, 'b': 3.14}
    comm.send(data, dest=1, tag=11)
elif rank == 1:
    data = comm.recv(source=0, tag=11)

if rank == 0:
    data = {'a': 7, 'b': 3.14}
    req = comm.isend(data, dest=1, tag=11)
    req.wait()
elif rank == 1:
    req = comm.irecv(source=0, tag=11)
    data = req.wait()

# Collective----------------------------------
if rank == 0:
    data = {'key1': [7, 2.72, 2+3j],
            'key2': ('abc', 'xyz')}
else:
    data = None
data = comm.bcast(data, root=0)

if rank == 0:
    data = [(i+1)**2 for i in range(size)]
else:
    data = None
data = comm.scatter(data, root=0)

NumPy arrays are buffer-like data, so the receive buffer must be allocated in advance, and the methods starting with an uppercase letter must be used:

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Point to Point---------------------------------------
# passing MPI datatypes explicitly
if rank == 0:
    data = np.arange(1000, dtype='i')
    comm.Send([data, MPI.INT], dest=1, tag=77)
elif rank == 1:
    data = np.empty(1000, dtype='i')
    comm.Recv([data, MPI.INT], source=0, tag=77)

# automatic MPI datatype discovery
if rank == 0:
    data = np.arange(100, dtype=np.float64)
    comm.Send(data, dest=1, tag=13)
elif rank == 1:
    data = np.empty(100, dtype=np.float64)
    comm.Recv(data, source=0, tag=13)

# Collective-------------------------------------------
if rank == 0:
    data = np.arange(100, dtype='i')
else:
    data = np.empty(100, dtype='i')
comm.Bcast(data, root=0)
for i in range(100):
    assert data[i] == i

sendbuf = None
if rank == 0:
    sendbuf = np.empty([size, 100], dtype='i')
    sendbuf.T[:, :] = range(size)
recvbuf = np.empty(100, dtype='i')
comm.Scatter(sendbuf, recvbuf, root=0)

sendbuf = np.zeros(100, dtype='i') + rank
recvbuf = None
if rank == 0:
    recvbuf = np.empty([size, 100], dtype='i')
comm.Gather(sendbuf, recvbuf, root=0)

2 I/O

One point worth clarifying: Write_at_all is the collective counterpart of Write_at (MPI_File_write_at_all). It must be called by every process in the communicator, which lets the MPI-IO layer optimize the access pattern, but each process still writes its own buffer at its own explicit offset; the processes are not required to write the same content.

from mpi4py import MPI
import numpy as np

amode = MPI.MODE_WRONLY | MPI.MODE_CREATE
comm = MPI.COMM_WORLD
fh = MPI.File.Open(comm, "./datafile.contig", amode)

buffer = np.empty(10, dtype=np.int64)  # np.int was removed in NumPy 1.20+
buffer[:] = comm.Get_rank()

offset = comm.Get_rank() * buffer.nbytes
fh.Write_at_all(offset, buffer)

fh.Close()