Skip to content

Commit

Permalink
move buffer clearing argument to class constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
DomFijan committed Aug 22, 2024
1 parent afc59e6 commit 617f984
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 32 deletions.
10 changes: 7 additions & 3 deletions hoomd/GSDDequeWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr<SystemDefinition> sysdef,
int queue_size,
std::string mode,
bool write_at_init,
bool clear_whole_buffer_after_dump,
uint64_t timestep)
: GSDDumpWriter(sysdef, trigger, fname, group, mode), m_queue_size(queue_size)
{
setLogWriter(logger);
bool file_empty = true;
m_clear_whole_buffer_after_dump = true;
#ifdef ENABLE_MPI
if (m_sysdef->isDomainDecomposed())
{
Expand All @@ -42,9 +44,10 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr<SystemDefinition> sysdef,
else
{
analyze(timestep);
dump(0, -1, true);
dump(0, -1);
}
}
m_clear_whole_buffer_after_dump = clear_whole_buffer_after_dump;
}

void GSDDequeWriter::analyze(uint64_t timestep)
Expand All @@ -59,7 +62,7 @@ void GSDDequeWriter::analyze(uint64_t timestep)
}
}

void GSDDequeWriter::dump(long int start, long int end, bool clear_entire_buffer)
void GSDDequeWriter::dump(long int start, long int end)
{
auto buffer_length = static_cast<long int>(m_frame_queue.size());
if (end > buffer_length)
Expand All @@ -85,7 +88,7 @@ void GSDDequeWriter::dump(long int start, long int end, bool clear_entire_buffer
{
write(m_frame_queue[i], m_log_queue[i]);
}
if (clear_entire_buffer)
if (m_clear_whole_buffer_after_dump)
{
m_frame_queue.clear();
m_log_queue.clear();
Expand Down Expand Up @@ -136,6 +139,7 @@ void export_GSDDequeWriter(pybind11::module& m)
int,
std::string,
bool,
bool,
uint64_t>())
.def_property("max_burst_size",
&GSDDequeWriter::getMaxQueueSize,
Expand Down
4 changes: 3 additions & 1 deletion hoomd/GSDDequeWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ class PYBIND11_EXPORT GSDDequeWriter : public GSDDumpWriter
int queue_size,
std::string mode,
bool write_on_init,
bool clear_whole_buffer_after_dump,
uint64_t timestep);
~GSDDequeWriter() = default;

void analyze(uint64_t timestep) override;

void dump(long int start, long int end, bool clear_entire_buffer);
void dump(long int start, long int end);

int getMaxQueueSize() const;
void setMaxQueueSize(int new_max_size);
Expand All @@ -40,6 +41,7 @@ class PYBIND11_EXPORT GSDDequeWriter : public GSDDumpWriter

protected:
int m_queue_size;
bool m_clear_whole_buffer_after_dump;
std::deque<GSDDumpWriter::GSDFrame> m_frame_queue;
std::deque<pybind11::dict> m_log_queue;
};
Expand Down
32 changes: 18 additions & 14 deletions hoomd/md/pytest/test_burst_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@ def test_burst_dump_with_clear_buffer(sim, tmp_path, clear_entire_buffer):
start_frame = 1
end_frame = 3
burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
burst_writer = hoomd.write.Burst(trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=4,
write_at_start=True)
burst_writer = hoomd.write.Burst(
trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=4,
write_at_start=True,
clear_whole_buffer_after_dump=clear_entire_buffer)
sim.operations.writers.append(burst_writer)
sim.run(12)
burst_writer.flush()
Expand All @@ -198,7 +200,7 @@ def test_burst_dump_with_clear_buffer(sim, tmp_path, clear_entire_buffer):
# First frame is always written
assert len(traj) == 1

burst_writer.dump(start_frame, end_frame, clear_entire_buffer)
burst_writer.dump(start_frame, end_frame)
burst_writer.flush()
dumped_frames = [0, 7, 9]
if sim.device.communicator.rank == 0:
Expand Down Expand Up @@ -296,12 +298,14 @@ def test_write_burst_log(sim, tmp_path):
def test_burst_dump_empty_buffer(sim, tmp_path, clear_entire_buffer):
filename = tmp_path / "temporary_test_file.gsd"
burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
burst_writer = hoomd.write.Burst(trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=3,
write_at_start=True)
burst_writer = hoomd.write.Burst(
trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=3,
write_at_start=True,
clear_whole_buffer_after_dump=clear_entire_buffer)
sim.operations.writers.append(burst_writer)
sim.run(8)
burst_writer.flush()
Expand All @@ -311,7 +315,7 @@ def test_burst_dump_empty_buffer(sim, tmp_path, clear_entire_buffer):
# First frame is always written
assert len(traj) == 1

burst_writer.dump(1, 2, clear_entire_buffer=clear_entire_buffer)
burst_writer.dump(1, 2)
burst_writer.flush()
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
Expand Down
36 changes: 22 additions & 14 deletions hoomd/write/gsd_burst.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ class Burst(GSD):
.. code-block:: python
write_at_start = burst.write_at_start
clear_whole_buffer_after_dump (bool): When ``True`` the buffer is
emptied after calling `dump` each time. If ``False`` only frames in the
buffer until the end frame will be deleted. Defaults to ``True``.
.. rubric:: Example:
.. code-block:: python
burst.clear_buffer_after_dump = False
"""

def __init__(self,
Expand All @@ -90,7 +100,8 @@ def __init__(self,
dynamic=None,
logger=None,
max_burst_size=-1,
write_at_start=False):
write_at_start=False,
clear_whole_buffer_after_dump=True):
super().__init__(trigger=trigger,
filename=filename,
filter=filter,
Expand All @@ -102,19 +113,19 @@ def __init__(self,
ParameterDict(max_burst_size=int, write_at_start=bool))
self._param_dict.update({
"max_burst_size": max_burst_size,
"write_at_start": write_at_start
"write_at_start": write_at_start,
"clear_whole_buffer_after_dump": clear_whole_buffer_after_dump
})

def _attach_hook(self):
sim = self._simulation
self._cpp_obj = _hoomd.GSDDequeWriter(sim.state._cpp_sys_def,
self.trigger, self.filename,
sim.state._get_group(self.filter),
self.logger, self.max_burst_size,
self.mode, self.write_at_start,
sim.timestep)

def dump(self, start=0, end=-1, clear_entire_buffer=True):
self._cpp_obj = _hoomd.GSDDequeWriter(
sim.state._cpp_sys_def, self.trigger, self.filename,
sim.state._get_group(self.filter), self.logger, self.max_burst_size,
self.mode, self.write_at_start, self.clear_whole_buffer_after_dump,
sim.timestep)

def dump(self, start=0, end=-1):
"""Write stored frames in range to the file and empties the buffer.
This method alllows for custom writing of frames at user specified
Expand All @@ -124,9 +135,6 @@ def dump(self, start=0, end=-1, clear_entire_buffer=True):
start (int): The first frame to write. Defaults to 0.
end (int): The last frame to write.
Defaults to -1 (last frame).
clear_entire_buffer (bool): When ``True`` the buffer is emptied
after writing. If ``False`` only frames in the buffer until end
frame will be deleted. Defaults to ``True``.
.. rubric:: Example:
Expand All @@ -135,7 +143,7 @@ def dump(self, start=0, end=-1, clear_entire_buffer=True):
burst.dump()
"""
if self._attached:
self._cpp_obj.dump(start, end, clear_entire_buffer)
self._cpp_obj.dump(start, end)

def __len__(self):
"""Get the current length of the internal frame buffer.
Expand Down

0 comments on commit 617f984

Please sign in to comment.