diff --git a/datadog/dogstatsd/base.py b/datadog/dogstatsd/base.py
index e0c8d357..2e0c811d 100644
--- a/datadog/dogstatsd/base.py
+++ b/datadog/dogstatsd/base.py
@@ -105,17 +105,30 @@ def pre_fork():
         c.pre_fork()
 
 
-def post_fork():
-    """Restore all client instances after a fork.
+def post_fork_parent():
+    """Restore all client instances in the parent process after a fork.
 
     If SUPPORTS_FORKING is true, this will be called automatically after os.fork().
     """
     for c in _instances:
-        c.post_fork()
+        c.post_fork_parent()
 
 
+def post_fork_child():
+    """Restore all client instances in the child process after a fork.
+
+    If SUPPORTS_FORKING is true, this will be called automatically after os.fork().
+    """
+    for c in _instances:
+        c.post_fork_child()
+
+
 if SUPPORTS_FORKING:
-    os.register_at_fork(before=pre_fork, after_in_child=post_fork, after_in_parent=post_fork)  # type: ignore
+    os.register_at_fork(  # type: ignore
+        before=pre_fork,
+        after_in_child=post_fork_child,
+        after_in_parent=post_fork_parent,
+    )
 
 
 # pylint: disable=useless-object-inheritance,too-many-instance-attributes
@@ -1397,29 +1410,49 @@ def wait_for_pending(self):
     def pre_fork(self):
         """Prepare client for a process fork.
 
-        Flush any pending payloads, stop all background threads and
-        close the connection. Once the function returns.
+        Flush any pending payloads and stop all background threads.
 
         The client should not be used from this point until
-        post_fork() is called.
+        state is restored by calling post_fork_parent() or
+        post_fork_child().
         """
-        log.debug("[%d] pre_fork for %s", os.getpid(), self)
 
-        self._forking = True
+        # Hold the config lock across fork. This will make sure that
+        # we don't fork in the middle of the concurrent modification
+        # of the client's settings. Data protected by other locks may
+        # be left in inconsistent state in the child process, which we
+        # will clean up in post_fork_child.
 
-        with self._config_lock:
-            self._stop_flush_thread()
-            self._stop_sender_thread()
-            self.close_socket()
+        self._config_lock.acquire()
+        self._stop_flush_thread()
+        self._stop_sender_thread()
 
-    def post_fork(self):
-        """Restore the client state after a fork."""
+    def post_fork_parent(self):
+        """Restore the client state after a fork in the parent process."""
+        try:
+            self._start_flush_thread(self._flush_interval)
+            self._start_sender_thread()
+        finally:
+            # Always release the lock taken in pre_fork(), even if a thread fails to start.
+            self._config_lock.release()
 
-        log.debug("[%d] post_fork for %s", os.getpid(), self)
+    def post_fork_child(self):
+        """Restore the client state after a fork in the child process."""
+        self._config_lock.release()
 
-        self.close_socket()
+        # Discard the locks that could have been locked at the time
+        # when we forked. This may cause inconsistent internal state,
+        # which we will fix in the next steps.
+        self._socket_lock = Lock()
+        self._buffer_lock = RLock()
 
-        self._forking = False
+        # Reset the buffer so we don't send metrics from the parent
+        # process. Also makes sure buffer properties are consistent.
+        self._reset_buffer()
+        # Execute the socket_path setter to reconcile transport and
+        # payload size properties with respect to the socket_path value.
+        self.socket_path = self.socket_path
+        self.close_socket()
 
         with self._config_lock:
             self._start_flush_thread(self._flush_interval)
diff --git a/tests/integration/dogstatsd/test_statsd_fork.py b/tests/integration/dogstatsd/test_statsd_fork.py
index 5b19f37b..c856376e 100644
--- a/tests/integration/dogstatsd/test_statsd_fork.py
+++ b/tests/integration/dogstatsd/test_statsd_fork.py
@@ -1,6 +1,7 @@
 import os
 import itertools
 import socket
+import threading
 
 import pytest
 
@@ -31,7 +32,7 @@ def inner(*args, **kwargs):
         return inner
 
     statsd.pre_fork = track(statsd.pre_fork)
-    statsd.post_fork = track(statsd.post_fork)
+    statsd.post_fork_parent = track(statsd.post_fork_parent)
 
     pid = os.fork()
     if pid == 0:
@@ -41,3 +42,55 @@ def inner(*args, **kwargs):
     os.waitpid(pid, 0)
 
     assert len(tracker) == 2
+
+
+def sender_a(statsd, running):
+    """Submit a gauge in a tight loop until running[0] is cleared."""
+    while running[0]:
+        statsd.gauge("spam", 1)
+
+
+def sender_b(statsd, running):
+    """Like sender_a, but submit each gauge inside the client's context manager."""
+    while running[0]:
+        with statsd:
+            statsd.gauge("spam", 1)
+
+
+@pytest.mark.parametrize(
+    "disable_background_sender, disable_buffering, sender",
+    list(itertools.product([True, False], [True, False], [sender_a, sender_b])),
+)
+def test_fork_with_thread(disable_background_sender, disable_buffering, sender):
+    """Forking must succeed while another thread is actively submitting metrics."""
+    if not SUPPORTS_FORKING:
+        pytest.skip("os.register_at_fork is required for this test")
+
+    statsd = DogStatsd(
+        telemetry_min_flush_interval=0,
+        disable_background_sender=disable_background_sender,
+        disable_buffering=disable_buffering,
+    )
+
+    # Keep the thread handle under its own name: rebinding `sender` would
+    # shadow the parametrized target function and start a no-op thread.
+    sender_thread = None
+    try:
+        sender_running = [True]
+        sender_thread = threading.Thread(target=sender, args=(statsd, sender_running))
+        sender_thread.daemon = True
+        sender_thread.start()
+
+        pid = os.fork()
+        if pid == 0:
+            os._exit(42)
+
+        assert pid > 0
+        (_, status) = os.waitpid(pid, 0)
+
+        assert os.WEXITSTATUS(status) == 42
+    finally:
+        if sender_thread is not None:
+            sender_running[0] = False
+            sender_thread.join()
+        statsd.stop()
diff --git a/tests/integration/dogstatsd/test_statsd_sender.py b/tests/integration/dogstatsd/test_statsd_sender.py
index d3a56860..55710c17 100644
--- a/tests/integration/dogstatsd/test_statsd_sender.py
+++ b/tests/integration/dogstatsd/test_statsd_sender.py
@@ -76,7 +76,7 @@ def test_fork_hooks(disable_background_sender, disable_buffering):
     assert statsd._queue is None or statsd._queue.empty()
     assert len(statsd._buffer) == 0
 
-    statsd.post_fork()
+    statsd.post_fork_parent()
 
     assert disable_buffering or statsd._flush_thread.is_alive()
     assert disable_background_sender or statsd._sender_thread.is_alive()
diff --git a/tests/unit/dogstatsd/test_statsd.py b/tests/unit/dogstatsd/test_statsd.py
index a5fe7a0e..8998ac60 100644
--- a/tests/unit/dogstatsd/test_statsd.py
+++ b/tests/unit/dogstatsd/test_statsd.py
@@ -2021,7 +2021,7 @@ def inner():
             # Statsd should survive this sequence of events
             statsd.pre_fork()
             statsd.get_socket()
-            statsd.post_fork()
+            statsd.post_fork_parent()
         t = Thread(target=inner)
         t.daemon = True
         t.start()