Skip to content

Commit

Permalink
When each node has only one process, use numpy arrays instead of shared memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
tskisner committed Dec 15, 2022
1 parent e2dc425 commit db3d31c
Showing 1 changed file with 107 additions and 92 deletions.
199 changes: 107 additions & 92 deletions pshmem/shmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,103 +161,118 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
self.data = None

if self._n > 0:
# First rank on each node creates the buffer
if self._noderank == 0:
try:
self._shmem = posix_ipc.SharedMemory(
self._name,
posix_ipc.O_CREX,
size=int(nbytes),
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed allocation of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
raise
try:
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed MMap of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
# Try to free the shared memory object
if self._nodeprocs == 1:
# There is only one process on a node, so we use a standard
# numpy array rather than wrapped shared memory. This helps
# reduce the total number of open files, which is limited by
# the kernel.
self._flat = np.zeros(
self._n,
dtype=self._dtype,
)
# Wrap
self.data = self._flat.reshape(self._shape)
else:
# First rank on each node creates the buffer
if self._noderank == 0:
try:
self._shmem = posix_ipc.SharedMemory(
self._name,
posix_ipc.O_CREX,
size=int(nbytes),
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed allocation of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
raise
try:
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed MMap of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
# Try to free the shared memory object
try:
self._shmem.close_fd()
self._shmem.unlink()
except Exception as eclose:
pass
raise

# Wait for that to be created
if self._nodecomm is not None:
self._nodecomm.barrier()

# Other ranks on the node attach
if self._noderank != 0:
try:
self._shmem = posix_ipc.SharedMemory(self._name)
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to attach buffer of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
raise

# Wait for other processes to attach
if self._nodecomm is not None:
self._nodecomm.barrier()

# Now that all processes have mmap'ed the shared memory we can
# close the shared memory handle
self._shmem.close_fd()

# Wait for all processes to close file handle
if self._nodecomm is not None:
self._nodecomm.barrier()

# One process requests the file to be deleted, but this will not
# actually happen until all processes release their mmap.
if self._noderank == 0:
try:
self._shmem.close_fd()
self._shmem.unlink()
except Exception as eclose:
pass
raise

# Wait for that to be created
if self._nodecomm is not None:
self._nodecomm.barrier()

# Other ranks on the node attach
if self._noderank != 0:
try:
self._shmem = posix_ipc.SharedMemory(self._name)
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to attach buffer of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
raise

# Wait for other processes to attach
if self._nodecomm is not None:
self._nodecomm.barrier()
except posix_ipc.ExistentialError:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to unlink shared memory"
msg += ": {}".format(e)
print(msg, flush=True)
raise

# Create a numpy array which acts as a view of the buffer.
self._flat = np.ndarray(
self._n,
dtype=self._dtype,
buffer=self._shmap,
)
# Initialize to zero.
if self._noderank == 0:
self._flat[:] = 0

# Now that all processes have mmap'ed the shared memory we can
# close the shared memory handle
self._shmem.close_fd()
# Wrap
self.data = self._flat.reshape(self._shape)

# Wait for all processes to close file handle
if self._nodecomm is not None:
self._nodecomm.barrier()

# One process requests the file to be deleted, but this will not
# actually happen until all processes release their mmap.
if self._noderank == 0:
try:
self._shmem.unlink()
except posix_ipc.ExistentialError:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to unlink shared memory"
msg += ": {}".format(e)
print(msg, flush=True)
raise

# Create a numpy array which acts as a view of the buffer.
self._flat = np.ndarray(
self._shape,
dtype=self._dtype,
buffer=self._shmap,
)
self.data = self._flat.reshape(self._shape)

# Initialize to zero.
if self._noderank == 0:
self._flat[:] = 0

def __del__(self):
self.close()
Expand Down

0 comments on commit db3d31c

Please sign in to comment.