diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 7cc6b9987f..779a2319da 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -1,6 +1,7 @@ """ Zeromq transport classes """ + import errno import hashlib import logging @@ -178,7 +179,7 @@ def __init__(self, opts, io_loop, **kwargs): self._socket.setsockopt(zmq.IPV4ONLY, 0) if HAS_ZMQ_MONITOR and self.opts["zmq_monitor"]: - self._monitor = ZeroMQSocketMonitor(self._socket) + self._monitor = ZeroMQSocketMonitor(self._socket, opts) self._monitor.start_io_loop(self.io_loop) def close(self): @@ -211,6 +212,12 @@ def connect(self, publish_port, connect_callback=None, disconnect_callback=None) self.master_pub, ) log.debug("%r connecting to %s", self, self.master_pub) + if ( + hasattr(self, "_monitor") + and self._monitor is not None + and disconnect_callback is not None + ): + self._monitor.disconnect_callback = disconnect_callback self._socket.connect(self.master_pub) connect_callback(True) @@ -634,7 +641,7 @@ def mark_future(msg): class ZeroMQSocketMonitor: __EVENT_MAP = None - def __init__(self, socket): + def __init__(self, socket, opts=None): """ Create ZMQ monitor sockets @@ -644,6 +651,11 @@ def __init__(self, socket): self._socket = socket self._monitor_socket = self._socket.get_monitor_socket() self._monitor_stream = None + self.disconnect_callback = None + self.disconnect_on_retry = None + self._connect_retry = None + if opts is not None: + self.disconnect_on_retry = opts.get("zmq_disconnect_on_retry", 10) def start_io_loop(self, io_loop): log.trace("Event monitor start!") @@ -680,6 +692,21 @@ def monitor_callback(self, msg): log.debug("ZeroMQ event: %s", evt) if evt["event"] == zmq.EVENT_MONITOR_STOPPED: self.stop() + elif evt["event"] == zmq.EVENT_DISCONNECTED: + if self.disconnect_callback is not None: + self.disconnect_callback() + elif evt["event"] == zmq.EVENT_CONNECT_RETRIED: + if ( + self.disconnect_on_retry is not None + and self.disconnect_callback is not None + ): + if self._connect_retry is None: + self._connect_retry = self.disconnect_on_retry + self._connect_retry -= 1 + if self._connect_retry <= 0: + log.debug("Calling disconnect callback as number of retries reached.") + self.disconnect_callback() + self._connect_retry = self.disconnect_on_retry def stop(self): if self._socket is None: