API Reference

Utilities

has_cuda_support

mpi4jax.has_cuda_support() → bool

Returns True if mpi4jax was built with CUDA support and can be used with GPU-based JAX arrays, and False otherwise.

Communication primitives

allgather

mpi4jax.allgather(x, *, comm=None, token=None)

Perform an allgather operation.

Warning

x must have the same shape and dtype on all processes.

Parameters:
  • x – Array or scalar input to send.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
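
The result layout can be illustrated with a single-process NumPy model. This is only a sketch and does not call mpi4jax or MPI: model_allgather is a hypothetical helper, inputs[r] stands in for the x held by rank r, and the (nproc, *input_shape) result shape is an assumption taken from the shape convention documented for gather.

```python
import numpy as np

def model_allgather(inputs):
    """Hypothetical single-process model; inputs[r] is the array on rank r."""
    shapes = {np.shape(x) for x in inputs}
    dtypes = {np.asarray(x).dtype for x in inputs}
    if len(shapes) != 1 or len(dtypes) != 1:
        # Mirrors the warning above: same shape and dtype on all processes.
        raise ValueError("x must have the same shape and dtype on all processes")
    # Assumed result shape: (nproc, *input_shape), identical on every rank.
    gathered = np.stack([np.asarray(x) for x in inputs])
    return [gathered.copy() for _ in inputs]

# Four simulated ranks, each holding a length-3 vector filled with its rank.
outputs = model_allgather([np.full(3, rank, dtype=np.float32) for rank in range(4)])
```

In a real MPI program, each rank would instead make one call following the signature above, such as mpi4jax.allgather(x, comm=comm).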

allreduce

mpi4jax.allreduce(x, op, *, comm=None, token=None)

Perform an allreduce operation.

Note

This primitive can be differentiated via jax.grad() and related functions if op is mpi4py.MPI.SUM.

Parameters:
  • x – Array or scalar input.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Result of the allreduce operation.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
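
As a rough illustration of the semantics for op = mpi4py.MPI.SUM, the sketch below models the reduction in a single process with NumPy. It does not call mpi4jax; model_allreduce_sum is a hypothetical helper and inputs[r] stands in for the x held by rank r.

```python
import numpy as np

def model_allreduce_sum(inputs):
    """Hypothetical single-process model of an allreduce with a SUM operator."""
    # Element-wise sum over the simulated ranks.
    total = np.sum(np.stack(inputs), axis=0)
    # Every rank receives the same reduced array.
    return [total.copy() for _ in inputs]

# Three simulated ranks holding [1, 2], [2, 4] and [3, 6].
per_rank = [np.array([1.0, 2.0]) * (rank + 1) for rank in range(3)]
results = model_allreduce_sum(per_rank)
```

Under this model the result has the same shape as each input, which is the main difference from the gather-style primitives.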

alltoall

mpi4jax.alltoall(x, *, comm=None, token=None)

Perform an alltoall operation.

Parameters:
  • x – Array input to send. Its first axis must have size nproc (the number of processes in comm).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
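
The exchange pattern can be sketched with a single-process NumPy model, assuming the usual MPI all-to-all convention that slice r of each rank's input is delivered to rank r. The helper model_alltoall is hypothetical and does not call mpi4jax.

```python
import numpy as np

def model_alltoall(inputs):
    """Hypothetical single-process model; inputs[s] is the array on rank s."""
    nproc = len(inputs)
    for x in inputs:
        if x.shape[0] != nproc:
            raise ValueError("first axis must have size nproc")
    # Rank r receives slice r from every rank s, stacked in order of s.
    return [np.stack([inputs[s][r] for s in range(nproc)]) for r in range(nproc)]

nproc = 3
# Each simulated rank holds a distinct (nproc, 2) array.
inputs = [np.arange(nproc * 2).reshape(nproc, 2) + 100 * rank for rank in range(nproc)]
outputs = model_alltoall(inputs)
```

In this model outputs[r][s] equals inputs[s][r], so the operation acts like a transpose over the (rank, first axis) pair.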

barrier

mpi4jax.barrier(*, comm=None, token=None)

Perform a barrier operation.

Parameters:
  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • A new, modified token that depends on this operation.

Return type:

Token

bcast

mpi4jax.bcast(x, root, *, comm=None, token=None)

Perform a bcast (broadcast) operation.

Warning

Unlike mpi4py’s bcast, this returns a new array with the received data.

Parameters:
  • x – Array or scalar input. Data is only read on the root process. On non-root processes, this is used to determine the shape and dtype of the result.

  • root (int) – Rank of the root MPI process, which acts as the source.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
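
The role of x on non-root processes can be sketched with a single-process NumPy model. The helper model_bcast is hypothetical and does not call mpi4jax; raising an error on a shape or dtype mismatch is an assumption of the model, not documented behavior.

```python
import numpy as np

def model_bcast(inputs, root):
    """Hypothetical single-process model; inputs[r] is the x passed on rank r."""
    data = inputs[root]
    for x in inputs:
        # Non-root inputs only act as shape/dtype templates (model assumption:
        # a mismatch is treated as an error).
        if x.shape != data.shape or x.dtype != data.dtype:
            raise ValueError("x must match the root's shape and dtype")
    # Each rank gets a new array; its own input is left untouched.
    return [data.copy() for _ in inputs]

inputs = [np.zeros(2, dtype=np.int32) for _ in range(3)]
inputs[1] = np.array([7, 8], dtype=np.int32)  # only the root holds real data
outputs = model_bcast(inputs, root=1)
```

The inputs on non-root ranks still hold zeros afterwards, which matches the warning that a new array is returned.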

gather

mpi4jax.gather(x, root, *, comm=None, token=None)

Perform a gather operation.

Warning

x must have the same shape and dtype on all processes.

Warning

The shape of the returned data varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, the output is identical to the input.

Parameters:
  • x – Array or scalar input to send.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data on the root process; on all other processes, the unmodified input.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
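
The rank-dependent output shape described in the warning can be made concrete with a single-process NumPy model. The helper model_gather is hypothetical and does not call mpi4jax; inputs[r] stands in for the x held by rank r.

```python
import numpy as np

def model_gather(inputs, root):
    """Hypothetical single-process model of the per-rank gather results."""
    stacked = np.stack(inputs)  # shape (nproc, *input_shape)
    # Root receives the stacked data; every other rank gets its input back.
    return [stacked if rank == root else x for rank, x in enumerate(inputs)]

inputs = [np.array([rank, rank], dtype=np.float32) for rank in range(4)]
outputs = model_gather(inputs, root=0)
```

Code that runs on all ranks therefore has to branch on the rank before indexing the result along the first axis.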

recv

mpi4jax.recv(x, source=-1, *, tag=-1, comm=None, status=None, token=None)

Perform a recv (receive) operation.

Warning

Unlike mpi4py’s recv, this returns a new array with the received data.

Parameters:
  • x – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object, can be used for introspection.

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]

reduce

mpi4jax.reduce(x, op, root, *, comm=None, token=None)

Perform a reduce operation.

Parameters:
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Result of the reduce operation on the root process; on all other processes, the unmodified input.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]

scan

mpi4jax.scan(x, op, *, comm=None, token=None)

Perform a scan operation.

Parameters:
  • x – Array or scalar input to send.

  • op (mpi4py.MPI.Op) – The reduction operator (e.g. mpi4py.MPI.SUM).

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Result of the scan operation.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
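
Assuming this primitive follows the inclusive semantics of MPI_Scan, rank r receives the reduction over ranks 0 through r. The single-process NumPy sketch below models this for a SUM operator; model_scan_sum is a hypothetical helper and does not call mpi4jax.

```python
import numpy as np

def model_scan_sum(inputs):
    """Hypothetical single-process model of an inclusive scan with SUM."""
    # Running sum over the rank axis: row r combines ranks 0..r.
    prefix = np.cumsum(np.stack(inputs), axis=0)
    return [prefix[rank] for rank in range(len(inputs))]

# Four simulated ranks holding [1, 10], [2, 20], [3, 30] and [4, 40].
inputs = [np.array([1, 10]) * (rank + 1) for rank in range(4)]
outputs = model_scan_sum(inputs)
```

Under this assumption, rank 0 gets its own input back and the last rank gets the same value an allreduce would produce.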

scatter

mpi4jax.scatter(x, root, *, comm=None, token=None)

Perform a scatter operation.

Warning

Unlike mpi4py’s scatter, this returns a new array with the received data.

Warning

The expected shape of the first input varies between ranks. On the root process, it is (nproc, *input_shape). On all other processes, it is input_shape.

Parameters:
  • x – Array or scalar input with the correct shape and dtype. On the root process, this contains the data to send, and its first axis must have size nproc. On non-root processes, this may contain arbitrary data and will not be overwritten.

  • root (int) – Rank of the root MPI process.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
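
The rank-dependent input shape can be illustrated with a single-process NumPy model. The helper model_scatter is hypothetical and does not call mpi4jax; inputs[r] stands in for the x passed on rank r, and only the root entry carries data.

```python
import numpy as np

def model_scatter(inputs, root):
    """Hypothetical single-process model of the per-rank scatter results."""
    nproc = len(inputs)
    send = inputs[root]
    if send.shape[0] != nproc:
        raise ValueError("on root, the first axis of x must have size nproc")
    # Rank r receives slice r of the root's array; non-root inputs are ignored.
    return [send[rank].copy() for rank in range(nproc)]

nproc, root = 3, 0
inputs = [np.zeros(2) for _ in range(nproc)]  # templates of shape input_shape
inputs[root] = np.arange(nproc * 2, dtype=float).reshape(nproc, 2)  # (nproc, *input_shape)
outputs = model_scatter(inputs, root)
```

Every rank, including the root, ends up with an array of shape input_shape.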

send

mpi4jax.send(x, dest, *, tag=0, comm=None, token=None)

Perform a send operation.

Parameters:
  • x – Array or scalar input to send.

  • dest (int) – Rank of the destination MPI process.

  • tag (int) – Tag of this message.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • A new, modified token that depends on this operation.

Return type:

Token

sendrecv

mpi4jax.sendrecv(sendbuf, recvbuf, source, dest, *, sendtag=0, recvtag=-1, comm=None, status=None, token=None)

Perform a sendrecv operation.

Warning

Unlike mpi4py’s sendrecv, this returns a new array with the received data.

Parameters:
  • sendbuf – Array or scalar input to send.

  • recvbuf – Array or scalar input with the correct shape and dtype. This can contain arbitrary data and will not be overwritten.

  • source (int) – Rank of the source MPI process.

  • dest (int) – Rank of the destination MPI process.

  • sendtag (int) – Tag of this message for sending.

  • recvtag (int) – Tag of this message for receiving.

  • comm (mpi4py.MPI.Comm) – The MPI communicator to use (defaults to a clone of COMM_WORLD).

  • status (mpi4py.MPI.Status) – Status object, can be used for introspection.

  • token (Token) – XLA token to use to ensure correct execution order. If not given, a new token is generated.

Returns:

  • Received data.

  • A new, modified token that depends on this operation.

Return type:

Tuple[DeviceArray, Token]
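
A common use of a combined send and receive is a ring shift, where each rank sends to its right neighbor and receives from its left neighbor. The single-process NumPy sketch below models that pattern; model_ring_sendrecv is a hypothetical helper and does not call mpi4jax.

```python
import numpy as np

def model_ring_sendrecv(sendbufs):
    """Hypothetical single-process model; sendbufs[r] is sent by rank r."""
    nproc = len(sendbufs)
    received = [None] * nproc
    for rank, buf in enumerate(sendbufs):
        dest = (rank + 1) % nproc  # right neighbor in the ring
        # The destination gets a new array; the send buffer is unchanged.
        received[dest] = buf.copy()
    return received

sendbufs = [np.array([rank]) for rank in range(4)]
received = model_ring_sendrecv(sendbufs)
```

In a real MPI program, each rank would make one call following the signature above, with source=(rank - 1) % nproc and dest=(rank + 1) % nproc, passing a recvbuf of matching shape and dtype.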