Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable dumping only part of the buffer for burst writer. #1870

Merged
merged 27 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions hoomd/GSDDequeWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr<SystemDefinition> sysdef,
int queue_size,
std::string mode,
bool write_at_init,
bool clear_whole_buffer_after_dump,
uint64_t timestep)
: GSDDumpWriter(sysdef, trigger, fname, group, mode), m_queue_size(queue_size)
{
setLogWriter(logger);
bool file_empty = true;
m_clear_whole_buffer_after_dump = true;
#ifdef ENABLE_MPI
if (m_sysdef->isDomainDecomposed())
{
Expand All @@ -42,9 +44,10 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr<SystemDefinition> sysdef,
else
{
analyze(timestep);
dump();
dump(0, -1);
}
}
setClearWholeBufferAfterDump(clear_whole_buffer_after_dump);
}

void GSDDequeWriter::analyze(uint64_t timestep)
Expand All @@ -59,14 +62,42 @@ void GSDDequeWriter::analyze(uint64_t timestep)
}
}

void GSDDequeWriter::dump()
void GSDDequeWriter::dump(long int start, long int end)
{
for (auto i {static_cast<long int>(m_frame_queue.size()) - 1}; i >= 0; --i)
auto buffer_length = static_cast<long int>(m_frame_queue.size());
if (end > buffer_length)
{
throw std::runtime_error("Burst.dump's end index is out of range.");
}
if (start < 0 || start > buffer_length)
{
throw std::runtime_error("Burst.dump's start index is out of range.");
}
long int iterator_start, iterator_end;
if (end < 0)
{
iterator_end = buffer_length - start;
iterator_start = 0;
}
else
{
iterator_end = buffer_length - start;
iterator_start = buffer_length - end;
}
janbridley marked this conversation as resolved.
Show resolved Hide resolved
for (auto i = iterator_end - 1; i >= iterator_start; --i)
{
write(m_frame_queue[i], m_log_queue[i]);
}
janbridley marked this conversation as resolved.
Show resolved Hide resolved
m_frame_queue.clear();
m_log_queue.clear();
if (m_clear_whole_buffer_after_dump)
{
m_frame_queue.clear();
m_log_queue.clear();
}
else
{
m_frame_queue.erase(m_frame_queue.begin() + iterator_start, m_frame_queue.end());
m_log_queue.erase(m_log_queue.begin() + iterator_start, m_log_queue.end());
}
}

int GSDDequeWriter::getMaxQueueSize() const
Expand All @@ -93,6 +124,16 @@ void GSDDequeWriter::setMaxQueueSize(int new_max_size)
}
}

bool GSDDequeWriter::getClearWholeBufferAfterDump() const
{
return m_clear_whole_buffer_after_dump;
}

void GSDDequeWriter::setClearWholeBufferAfterDump(bool clear_whole_buffer_after_dump)
{
m_clear_whole_buffer_after_dump = clear_whole_buffer_after_dump;
}

namespace detail
{
void export_GSDDequeWriter(pybind11::module& m)
Expand All @@ -108,10 +149,14 @@ void export_GSDDequeWriter(pybind11::module& m)
int,
std::string,
bool,
bool,
uint64_t>())
.def_property("max_burst_size",
&GSDDequeWriter::getMaxQueueSize,
&GSDDequeWriter::setMaxQueueSize)
.def_property("clear_whole_buffer_after_dump",
&GSDDequeWriter::getClearWholeBufferAfterDump,
&GSDDequeWriter::setClearWholeBufferAfterDump)
.def("__len__", &GSDDequeWriter::getCurrentQueueSize)
.def("dump", &GSDDequeWriter::dump);
}
Expand Down
6 changes: 5 additions & 1 deletion hoomd/GSDDequeWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,24 @@ class PYBIND11_EXPORT GSDDequeWriter : public GSDDumpWriter
int queue_size,
std::string mode,
bool write_on_init,
bool clear_whole_buffer_after_dump,
uint64_t timestep);
~GSDDequeWriter() = default;

void analyze(uint64_t timestep) override;

void dump();
void dump(long int start, long int end);

int getMaxQueueSize() const;
void setMaxQueueSize(int new_max_size);
bool getClearWholeBufferAfterDump() const;
void setClearWholeBufferAfterDump(bool clear_whole_buffer_after_dump);

size_t getCurrentQueueSize() const;

protected:
int m_queue_size;
bool m_clear_whole_buffer_after_dump;
std::deque<GSDDumpWriter::GSDFrame> m_frame_queue;
std::deque<pybind11::dict> m_log_queue;
};
Expand Down
89 changes: 87 additions & 2 deletions hoomd/md/pytest/test_burst_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def test_len(sim, tmp_path):
assert len(burst_writer) == 0


def test_burst_dump(sim, tmp_path):
@pytest.mark.parametrize("start, end", [(0, -1), (0, 0), (0, 1), (0, 2), (1, 1),
(2, 2), (1, 2), (1, -1), (2, -1)])
def test_burst_dump(sim, tmp_path, start, end):
filename = tmp_path / "temporary_test_file.gsd"

burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
Expand All @@ -164,11 +166,59 @@ def test_burst_dump(sim, tmp_path):
# First frame is always written
assert len(traj) == 1

burst_writer.dump(start=start, end=end)
burst_writer.flush()
dumped_frames = [3, 5, 7]
if sim.device.communicator.rank == 0:
if end == -1:
end = len(dumped_frames)
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert [frame.configuration.step for frame in traj
] == [0] + dumped_frames[start:end]


@pytest.mark.parametrize("clear_entire_buffer", [True, False])
def test_burst_dump_with_clear_buffer(sim, tmp_path, clear_entire_buffer):
filename = tmp_path / "temporary_test_file.gsd"
start_frame = 1
end_frame = 3
burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
burst_writer = hoomd.write.Burst(
trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=4,
write_at_start=True,
clear_whole_buffer_after_dump=clear_entire_buffer)
sim.operations.writers.append(burst_writer)
sim.run(12)
burst_writer.flush()
if sim.device.communicator.rank == 0:
assert Path(filename).exists()
with gsd.hoomd.open(filename, "r") as traj:
# First frame is always written
assert len(traj) == 1

burst_writer.dump(start_frame, end_frame)
burst_writer.flush()
dumped_frames = [0, 7, 9]
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
print([frame.configuration.step for frame in traj])
assert [frame.configuration.step for frame in traj] == dumped_frames

sim.run(4)
burst_writer.dump()
burst_writer.flush()
if clear_entire_buffer:
dumped_frames += [13, 15]
else:
dumped_frames += [11, 13, 15]
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert [frame.configuration.step for frame in traj] == [0, 3, 5, 7]
print([frame.configuration.step for frame in traj])
assert [frame.configuration.step for frame in traj] == dumped_frames


def test_burst_max_size(sim, tmp_path):
Expand Down Expand Up @@ -242,3 +292,38 @@ def test_write_burst_log(sim, tmp_path):
with gsd.hoomd.open(name=filename, mode='r') as traj:
for frame, sim_ke in zip(traj[1:], kinetic_energies):
assert frame.log[key] == sim_ke


@pytest.mark.parametrize("clear_entire_buffer", [True, False])
def test_burst_dump_empty_buffer(sim, tmp_path, clear_entire_buffer):
filename = tmp_path / "temporary_test_file.gsd"
burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
burst_writer = hoomd.write.Burst(
trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=3,
write_at_start=True,
clear_whole_buffer_after_dump=clear_entire_buffer)
sim.operations.writers.append(burst_writer)
sim.run(8)
burst_writer.flush()
if sim.device.communicator.rank == 0:
assert Path(filename).exists()
with gsd.hoomd.open(filename, "r") as traj:
# First frame is always written
assert len(traj) == 1

burst_writer.dump(1, 2)
burst_writer.flush()
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert len(traj) == 2

sim.run(4)
burst_writer.dump()
burst_writer.flush()
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert len(traj) == (4 if clear_entire_buffer else 5)
41 changes: 30 additions & 11 deletions hoomd/write/gsd_burst.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class Burst(GSD):
write_at_start (bool): When ``True`` **and** the file does not exist or
has 0 frames: write one frame with the current state of the system
when `hoomd.Simulation.run` is called. Defaults to ``False``.
clear_whole_buffer_after_dump (bool): When ``True`` the buffer is
emptied after calling `dump` each time. When ``False``, `dump` removes
frames from the buffer unil the ``end`` index. Defaults to ``True``.

Warning:
`Burst` errors when attempting to create a file or writing to one with
Expand Down Expand Up @@ -80,6 +83,16 @@ class Burst(GSD):
.. code-block:: python

write_at_start = burst.write_at_start

clear_whole_buffer_after_dump (bool): When ``True`` the buffer is
emptied after calling `dump` each time. When ``False``, `dump` removes
frames from the buffer unil the ``end`` index.

.. rubric:: Example:

.. code-block:: python

burst.clear_buffer_after_dump = False
"""

def __init__(self,
Expand All @@ -90,7 +103,8 @@ def __init__(self,
dynamic=None,
logger=None,
max_burst_size=-1,
write_at_start=False):
write_at_start=False,
clear_whole_buffer_after_dump=True):
super().__init__(trigger=trigger,
filename=filename,
filter=filter,
Expand All @@ -102,32 +116,37 @@ def __init__(self,
ParameterDict(max_burst_size=int, write_at_start=bool))
self._param_dict.update({
"max_burst_size": max_burst_size,
"write_at_start": write_at_start
"write_at_start": write_at_start,
"clear_whole_buffer_after_dump": clear_whole_buffer_after_dump
})

def _attach_hook(self):
sim = self._simulation
self._cpp_obj = _hoomd.GSDDequeWriter(sim.state._cpp_sys_def,
self.trigger, self.filename,
sim.state._get_group(self.filter),
self.logger, self.max_burst_size,
self.mode, self.write_at_start,
sim.timestep)
self._cpp_obj = _hoomd.GSDDequeWriter(
sim.state._cpp_sys_def, self.trigger, self.filename,
sim.state._get_group(self.filter), self.logger, self.max_burst_size,
self.mode, self.write_at_start, self.clear_whole_buffer_after_dump,
sim.timestep)

def dump(self):
"""Write all currently stored frames to the file and empties the buffer.
def dump(self, start=0, end=-1):
"""Write stored frames in range to the file and empties the buffer.

This method alllows for custom writing of frames at user specified
conditions.

Args:
start (int): The first frame to write. Defaults to 0.
end (int): The last frame to write.
Defaults to -1 (last frame).

.. rubric:: Example:

.. code-block:: python

burst.dump()
"""
if self._attached:
self._cpp_obj.dump()
self._cpp_obj.dump(start, end)

def __len__(self):
"""Get the current length of the internal frame buffer.
Expand Down
Loading