From 452db14a0929cde71edd0a240655dd763d3bcbf2 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Mon, 10 Aug 2020 10:54:25 -0700 Subject: [PATCH] Bind a new socket per-worker We've seen lumpy distributions of requests to workers, which seems to parallel what some other projects have seen [0][1]. The solution (as best we can tell) is to take advantage of the SO_REUSEPORT that eventlet's been setting for us since basically forever [2]. [0] https://lwn.net/Articles/542629/ [1] https://github.com/varnish/hitch/issues/142 [2] https://github.com/eventlet/eventlet/commit/f9a3074a3 Change-Id: I83cdaa2cbd394cbd49609c65bf9c5ed026c55417 --- swift/common/wsgi.py | 348 ++++++++++++++-------------------- test/unit/common/test_wsgi.py | 212 +++++++++------------ 2 files changed, 232 insertions(+), 328 deletions(-) diff --git a/swift/common/wsgi.py b/swift/common/wsgi.py index f9cfd1f2fc..57a71a8f7d 100644 --- a/swift/common/wsgi.py +++ b/swift/common/wsgi.py @@ -636,7 +636,7 @@ class SwiftHttpProxiedProtocol(SwiftHttpProtocol): return environ -def run_server(conf, logger, sock, global_conf=None): +def run_server(conf, logger, sock, global_conf=None, ready_callback=None): # Ensure TZ environment variable exists to avoid stat('/etc/localtime') on # some platforms. This locks in reported times to UTC. os.environ['TZ'] = 'UTC+0' @@ -677,6 +677,8 @@ def run_server(conf, logger, sock, global_conf=None): # header; "Etag" just won't do). 'capitalize_response_headers': False, } + if ready_callback: + ready_callback() try: wsgi.server(sock, app, wsgi_logger, **server_kwargs) except socket.error as err: @@ -689,6 +691,15 @@ class StrategyBase(object): """ Some operations common to all strategy classes. """ + def __init__(self, conf, logger): + self.conf = conf + self.logger = logger + self.signaled_ready = False + + # Each strategy is welcome to track data however it likes, but all + # socket refs should be somewhere in this dict. This allows forked-off + # children to easily drop refs to sibling sockets in post_fork_hook(). + self.tracking_data = {} def post_fork_hook(self): """ @@ -696,7 +707,10 @@ class StrategyBase(object): wsgi server, to perform any initialization such as drop privileges. """ + if not self.signaled_ready: + capture_stdio(self.logger) drop_privileges(self.conf.get('user', 'swift')) + del self.tracking_data # children don't need to track siblings def shutdown_sockets(self): """ @@ -721,12 +735,40 @@ class StrategyBase(object): # on socket objects is provided to toggle it. sock.set_inheritable(False) + def signal_ready(self): + """ + Signal that the server is up and accepting connections. + """ + if self.signaled_ready: + return # Already did it + + # Redirect errors to logger and close stdio. swift-init (for example) + # uses this to know that the service is ready to accept connections. + capture_stdio(self.logger) + + # If necessary, signal an old copy of us that it's okay to shutdown + # its listen sockets now because ours are up and ready to receive + # connections. This is used for seamless reloading using SIGUSR1. + reexec_signal_fd = os.getenv(NOTIFY_FD_ENV_KEY) + if reexec_signal_fd: + reexec_signal_fd = int(reexec_signal_fd) + os.write(reexec_signal_fd, str(os.getpid()).encode('utf8')) + os.close(reexec_signal_fd) + + # Finally, signal systemd (if appropriate) that process started + # properly. 
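As a back-of-the-envelope illustration of what the per-worker sockets rely on, here is a minimal sketch (not Swift's actual get_socket(), and assuming a Linux kernel new enough for SO_REUSEPORT, i.e. 3.9+): every worker binds its own socket to the same port, and the kernel then balances incoming connections across all of them.

    import socket

    def bind_reuseport(bind_ip, port):
        # Each caller gets its own listen socket on the same ip:port;
        # with SO_REUSEPORT set (as eventlet's listen() does, per [2]
        # in the commit message), the kernel load-balances new
        # connections across all such sockets.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind((bind_ip, port))
        sock.listen(128)
        return sock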
+ systemd_notify(logger=self.logger) + + self.signaled_ready = True + class WorkersStrategy(StrategyBase): """ WSGI server management strategy object for a single bind port and listen socket shared by a configured number of forked-off workers. + Tracking data is a map of ``pid -> socket``. + Used in :py:func:`run_wsgi`. :param dict conf: Server configuration dictionary. @@ -735,10 +777,7 @@ class WorkersStrategy(StrategyBase): """ def __init__(self, conf, logger): - self.conf = conf - self.logger = logger - self.sock = None - self.children = [] + super(WorkersStrategy, self).__init__(conf, logger) self.worker_count = config_auto_int_value(conf.get('workers'), CPU_COUNT) @@ -753,18 +792,6 @@ class WorkersStrategy(StrategyBase): return 0.5 - def do_bind_ports(self): - """ - Bind the one listen socket for this strategy. - """ - - try: - self.sock = get_socket(self.conf) - except ConfigFilePortError: - msg = 'bind_port wasn\'t properly set in the config file. ' \ - 'It must be explicitly set to a valid port number.' - return msg - def no_fork_sock(self): """ Return a server listen socket if the server should run in the @@ -773,7 +800,7 @@ class WorkersStrategy(StrategyBase): # Useful for profiling [no forks]. if self.worker_count == 0: - return self.sock + return get_socket(self.conf) def new_worker_socks(self): """ @@ -785,8 +812,8 @@ class WorkersStrategy(StrategyBase): where it will be ignored. """ - while len(self.children) < self.worker_count: - yield self.sock, None + while len(self.tracking_data) < self.worker_count: + yield get_socket(self.conf), None def log_sock_exit(self, sock, _unused): """ @@ -810,7 +837,7 @@ class WorkersStrategy(StrategyBase): self.logger.notice('Started child %s from parent %s', pid, os.getpid()) - self.children.append(pid) + self.tracking_data[pid] = sock def register_worker_exit(self, pid): """ @@ -823,139 +850,22 @@ class WorkersStrategy(StrategyBase): :param int pid: The PID of the worker that exited. """ - if pid in self.children: + sock = self.tracking_data.pop(pid, None) + if sock is None: + self.logger.info('Ignoring wait() result from unknown PID %s', pid) + else: self.logger.error('Removing dead child %s from parent %s', pid, os.getpid()) - self.children.remove(pid) - else: - self.logger.info('Ignoring wait() result from unknown PID %s', pid) + greenio.shutdown_safe(sock) + sock.close() def iter_sockets(self): """ Yields all known listen sockets. """ - if self.sock: - yield self.sock - - -class PortPidState(object): - """ - A helper class for :py:class:`ServersPerPortStrategy` to track listen - sockets and PIDs for each port. - - :param int servers_per_port: The configured number of servers per port. - :param logger: The server's :py:class:`~swift.common.utils.LogAdaptor` - """ - - def __init__(self, servers_per_port, logger): - self.servers_per_port = servers_per_port - self.logger = logger - self.sock_data_by_port = {} - - def sock_for_port(self, port): - """ - :param int port: The port whose socket is desired. - :returns: The bound listen socket for the given port. - """ - - return self.sock_data_by_port[port]['sock'] - - def port_for_sock(self, sock): - """ - :param socket sock: A tracked bound listen socket - :returns: The port the socket is bound to. 
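The WorkersStrategy rewrite above amounts to bookkeeping like the following condensed sketch, where get_socket and serve are injected stand-ins for the real helpers rather than Swift's implementations:

    import os

    tracking_data = {}   # pid -> socket, one fresh listen socket per worker

    def spawn_missing_workers(get_socket, serve, conf, worker_count):
        # get_socket and serve are stand-ins for the real helpers; each
        # child binds its own socket, so siblings never share one.
        while len(tracking_data) < worker_count:
            sock = get_socket(conf)
            pid = os.fork()
            if pid == 0:
                tracking_data.clear()   # children drop refs to sibling socks
                serve(sock)
                os._exit(0)
            tracking_data[pid] = sock

    def register_worker_exit(pid):
        # Idempotent: unknown PIDs are ignored, and a known PID's now
        # orphaned listen socket is closed in the parent.
        sock = tracking_data.pop(pid, None)
        if sock is not None:
            sock.close()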
- """ - - for port, sock_data in self.sock_data_by_port.items(): - if sock_data['sock'] == sock: - return port - - def _pid_to_port_and_index(self, pid): - for port, sock_data in self.sock_data_by_port.items(): - for server_idx, a_pid in enumerate(sock_data['pids']): - if pid == a_pid: - return port, server_idx - - def port_index_pairs(self): - """ - Returns current (port, server index) pairs. - - :returns: A set of (port, server_idx) tuples for currently-tracked - ports, sockets, and PIDs. - """ - - current_port_index_pairs = set() - for port, pid_state in self.sock_data_by_port.items(): - current_port_index_pairs |= set( - (port, i) - for i, pid in enumerate(pid_state['pids']) - if pid is not None) - return current_port_index_pairs - - def track_port(self, port, sock): - """ - Start tracking servers for the given port and listen socket. - - :param int port: The port to start tracking - :param socket sock: The bound listen socket for the port. - """ - - self.sock_data_by_port[port] = { - 'sock': sock, - 'pids': [None] * self.servers_per_port, - } - - def not_tracking(self, port): - """ - Return True if the specified port is not being tracked. - - :param int port: A port to check. - """ - - return port not in self.sock_data_by_port - - def all_socks(self): - """ - Yield all current listen sockets. - """ - - for orphan_data in self.sock_data_by_port.values(): - yield orphan_data['sock'] - - def forget_port(self, port): - """ - Idempotently forget a port, closing the listen socket at most once. - """ - - orphan_data = self.sock_data_by_port.pop(port, None) - if orphan_data: - greenio.shutdown_safe(orphan_data['sock']) - orphan_data['sock'].close() - self.logger.notice('Closing unnecessary sock for port %d', port) - - def add_pid(self, port, index, pid): - self.sock_data_by_port[port]['pids'][index] = pid - - def forget_pid(self, pid): - """ - Idempotently forget a PID. It's okay if the PID is no longer in our - data structure (it could have been removed by the "orphan port" removal - in :py:meth:`new_worker_socks`). - - :param int pid: The PID which exited. - """ - - port_server_idx = self._pid_to_port_and_index(pid) - if port_server_idx is None: - # This method can lose a race with the "orphan port" removal, when - # a ring reload no longer contains a port. So it's okay if we were - # unable to find a (port, server_idx) pair. - return - dead_port, server_idx = port_server_idx - self.logger.error('Removing dead child %d (PID: %s) for port %s', - server_idx, pid, dead_port) - self.sock_data_by_port[dead_port]['pids'][server_idx] = None + for sock in self.tracking_data.values(): + yield sock class ServersPerPortStrategy(StrategyBase): @@ -965,6 +875,8 @@ class ServersPerPortStrategy(StrategyBase): `servers_per_port` integer config setting determines how many workers are run per port. + Tracking data is a map like ``port -> [(pid, socket), ...]``. + Used in :py:func:`run_wsgi`. :param dict conf: Server configuration dictionary. 
@@ -974,12 +886,10 @@ class ServersPerPortStrategy(StrategyBase): """ def __init__(self, conf, logger, servers_per_port): - self.conf = conf - self.logger = logger + super(ServersPerPortStrategy, self).__init__(conf, logger) self.servers_per_port = servers_per_port self.swift_dir = conf.get('swift_dir', '/etc/swift') self.ring_check_interval = int(conf.get('ring_check_interval', 15)) - self.port_pid_state = PortPidState(servers_per_port, logger) bind_ip = conf.get('bind_ip', '0.0.0.0') self.cache = BindPortsCache(self.swift_dir, bind_ip) @@ -990,8 +900,7 @@ class ServersPerPortStrategy(StrategyBase): def _bind_port(self, port): new_conf = self.conf.copy() new_conf['bind_port'] = port - sock = get_socket(new_conf) - self.port_pid_state.track_port(port, sock) + return get_socket(new_conf) def loop_timeout(self): """ @@ -1003,15 +912,6 @@ class ServersPerPortStrategy(StrategyBase): return self.ring_check_interval - def do_bind_ports(self): - """ - Bind one listen socket per unique local storage policy ring port. - """ - - self._reload_bind_ports() - for port in self.bind_ports: - self._bind_port(port) - def no_fork_sock(self): """ This strategy does not support running in the foreground. @@ -1021,8 +921,8 @@ class ServersPerPortStrategy(StrategyBase): def new_worker_socks(self): """ - Yield a sequence of (socket, server_idx) tuples for each server which - should be forked-off and started. + Yield a sequence of (socket, (port, server_idx)) tuples for each server + which should be forked-off and started. Any sockets for "orphaned" ports no longer in any ring will be closed (causing their associated workers to gracefully exit) after all new @@ -1033,11 +933,15 @@ class ServersPerPortStrategy(StrategyBase): """ self._reload_bind_ports() - desired_port_index_pairs = set( + desired_port_index_pairs = { (p, i) for p in self.bind_ports - for i in range(self.servers_per_port)) + for i in range(self.servers_per_port)} - current_port_index_pairs = self.port_pid_state.port_index_pairs() + current_port_index_pairs = { + (p, i) + for p, port_data in self.tracking_data.items() + for i, (pid, sock) in enumerate(port_data) + if pid is not None} if desired_port_index_pairs != current_port_index_pairs: # Orphan ports are ports which had object-server processes running, @@ -1046,36 +950,44 @@ class ServersPerPortStrategy(StrategyBase): orphan_port_index_pairs = current_port_index_pairs - \ desired_port_index_pairs - # Fork off worker(s) for every port who's supposed to have + # Fork off worker(s) for every port that's supposed to have # worker(s) but doesn't missing_port_index_pairs = desired_port_index_pairs - \ current_port_index_pairs for port, server_idx in sorted(missing_port_index_pairs): - if self.port_pid_state.not_tracking(port): - try: - self._bind_port(port) - except Exception as e: - self.logger.critical('Unable to bind to port %d: %s', - port, e) - continue - yield self.port_pid_state.sock_for_port(port), server_idx + try: + sock = self._bind_port(port) + except Exception as e: + self.logger.critical('Unable to bind to port %d: %s', + port, e) + continue + yield sock, (port, server_idx) - for orphan_pair in orphan_port_index_pairs: + for port, idx in orphan_port_index_pairs: # For any port in orphan_port_index_pairs, it is guaranteed # that there should be no listen socket for that port, so we # can close and forget them. 
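The reconciliation above is plain set arithmetic over the new tracking structure. A toy illustration with made-up ports and PIDs:

    servers_per_port = 2
    bind_ports = {6006, 6009}        # from the rings
    tracking_data = {                # port -> [(pid, sock), ...]
        6006: [(88, 'sock-a'), (89, 'sock-b')],
        6007: [(90, 'sock-c'), (91, 'sock-d')],  # port left the rings
    }
    desired = {(p, i) for p in bind_ports for i in range(servers_per_port)}
    current = {(p, i)
               for p, port_data in tracking_data.items()
               for i, (pid, _sock) in enumerate(port_data)
               if pid is not None}
    assert desired - current == {(6009, 0), (6009, 1)}  # fork these
    assert current - desired == {(6007, 0), (6007, 1)}  # close these socks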
- self.port_pid_state.forget_port(orphan_pair[0]) + pid, sock = self.tracking_data[port][idx] + greenio.shutdown_safe(sock) + sock.close() + self.logger.notice( + 'Closing unnecessary sock for port %d (child pid %d)', + port, pid) + self.tracking_data[port][idx] = (None, None) + if all(sock is None + for _pid, sock in self.tracking_data[port]): + del self.tracking_data[port] - def log_sock_exit(self, sock, server_idx): + def log_sock_exit(self, sock, data): """ Log a server's exit. """ - port = self.port_pid_state.port_for_sock(sock) + port, server_idx = data self.logger.notice('Child %d (PID %d, port %d) exiting normally', server_idx, os.getpid(), port) - def register_worker_start(self, sock, server_idx, pid): + def register_worker_start(self, sock, data, pid): """ Called when a new worker is started. @@ -1085,10 +997,12 @@ class ServersPerPortStrategy(StrategyBase): :param int pid: The new worker process' PID """ - port = self.port_pid_state.port_for_sock(sock) + port, server_idx = data self.logger.notice('Started child %d (PID %d) for port %d', server_idx, pid, port) - self.port_pid_state.add_pid(port, server_idx, pid) + if port not in self.tracking_data: + self.tracking_data[port] = [(None, None)] * self.servers_per_port + self.tracking_data[port][server_idx] = (pid, sock) def register_worker_exit(self, pid): """ @@ -1097,15 +1011,22 @@ class ServersPerPortStrategy(StrategyBase): :param int pid: The PID of the worker that exited. """ - self.port_pid_state.forget_pid(pid) + for port_data in self.tracking_data.values(): + for idx, (child_pid, sock) in enumerate(port_data): + if child_pid == pid: + port_data[idx] = (None, None) + greenio.shutdown_safe(sock) + sock.close() + return def iter_sockets(self): """ Yields all known listen sockets. """ - for sock in self.port_pid_state.all_socks(): - yield sock + for port_data in self.tracking_data.values(): + for _pid, sock in port_data: + yield sock def run_wsgi(conf_path, app_section, *args, **kwargs): @@ -1140,6 +1061,15 @@ def run_wsgi(conf_path, app_section, *args, **kwargs): conf, logger, servers_per_port=servers_per_port) else: strategy = WorkersStrategy(conf, logger) + try: + # Quick sanity check + int(conf['bind_port']) + except (ValueError, KeyError, TypeError): + error_msg = 'bind_port wasn\'t properly set in the config file. ' \ + 'It must be explicitly set to a valid port number.' + logger.error(error_msg) + print(error_msg) + return 1 # patch event before loadapp utils.eventlet_monkey_patch() @@ -1154,35 +1084,14 @@ def run_wsgi(conf_path, app_section, *args, **kwargs): utils.FALLOCATE_RESERVE, utils.FALLOCATE_IS_PERCENT = \ utils.config_fallocate_value(conf.get('fallocate_reserve', '1%')) - # Start listening on bind_addr/port - error_msg = strategy.do_bind_ports() - if error_msg: - logger.error(error_msg) - print(error_msg) - return 1 - # Do some daemonization process hygene before we fork any children or run a # server without forking. clean_up_daemon_hygiene() - # Redirect errors to logger and close stdio. Do this *after* binding ports; - # we use this to signal that the service is ready to accept connections. - capture_stdio(logger) - - # If necessary, signal an old copy of us that it's okay to shutdown its - # listen sockets now because ours are up and ready to receive connections. 
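For context, the other end of the handshake that signal_ready() now owns lives in the process manager. A rough sketch of that side, with the environment key name and the exec'd path assumed for illustration (the real NOTIFY_FD_ENV_KEY value is defined elsewhere in Swift):

    import os

    NOTIFY_FD_ENV_KEY = '__SWIFT_NOTIFY_FD'   # name assumed for this sketch

    def reexec_and_wait():
        # Old server: pass the new copy a pipe fd via the environment,
        # then keep serving until the new copy writes its PID back from
        # signal_ready().
        read_fd, write_fd = os.pipe()
        os.set_inheritable(write_fd, True)    # must survive the exec
        env = dict(os.environ, **{NOTIFY_FD_ENV_KEY: str(write_fd)})
        if os.fork() == 0:
            # placeholder path; the real reload re-execs the same server
            os.execve('/usr/local/bin/swift-object-server',
                      ['swift-object-server'], env)
        os.close(write_fd)
        new_pid = os.read(read_fd, 30)        # blocks until the child is up
        os.close(read_fd)
        return int(new_pid.decode('utf8'))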
- reexec_signal_fd = os.getenv(NOTIFY_FD_ENV_KEY) - if reexec_signal_fd: - reexec_signal_fd = int(reexec_signal_fd) - os.write(reexec_signal_fd, str(os.getpid()).encode('utf8')) - os.close(reexec_signal_fd) - - # Finally, signal systemd (if appropriate) that process started properly. - systemd_notify(logger=logger) - no_fork_sock = strategy.no_fork_sock() if no_fork_sock: - run_server(conf, logger, no_fork_sock, global_conf=global_conf) + run_server(conf, logger, no_fork_sock, global_conf=global_conf, + ready_callback=strategy.signal_ready) return 0 def stop_with_signal(signum, *args): @@ -1198,17 +1107,38 @@ def run_wsgi(conf_path, app_section, *args, **kwargs): while running_context[0]: for sock, sock_info in strategy.new_worker_socks(): + read_fd, write_fd = os.pipe() pid = os.fork() if pid == 0: + os.close(read_fd) signal.signal(signal.SIGHUP, signal.SIG_DFL) signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGUSR1, signal.SIG_DFL) strategy.post_fork_hook() - run_server(conf, logger, sock) + + def notify(): + os.write(write_fd, b'ready') + os.close(write_fd) + + run_server(conf, logger, sock, ready_callback=notify) strategy.log_sock_exit(sock, sock_info) return 0 else: - strategy.register_worker_start(sock, sock_info, pid) + os.close(write_fd) + worker_status = os.read(read_fd, 30) + os.close(read_fd) + # TODO: delay this status checking until after we've tried + # to start all workers. But, we currently use the register + # event to know when we've got enough workers :-/ + if worker_status == b'ready': + strategy.register_worker_start(sock, sock_info, pid) + else: + raise Exception( + 'worker did not start normally: %r' % worker_status) + + # TODO: signal_ready() as soon as we have at least one new worker for + # each port, instead of waiting for all of them + strategy.signal_ready() # The strategy may need to pay attention to something in addition to # child process exits (like new ports showing up in a ring). diff --git a/test/unit/common/test_wsgi.py b/test/unit/common/test_wsgi.py index da04d21f3a..cb489dd641 100644 --- a/test/unit/common/test_wsgi.py +++ b/test/unit/common/test_wsgi.py @@ -778,7 +778,7 @@ class TestWSGI(unittest.TestCase): def _initrp(conf_file, app_section, *args, **kwargs): return ( - {'__file__': 'test', 'workers': 0}, + {'__file__': 'test', 'workers': 0, 'bind_port': 12345}, 'logger', 'log_name') @@ -788,7 +788,8 @@ class TestWSGI(unittest.TestCase): def _global_conf_callback(preloaded_app_conf, global_conf): calls['_global_conf_callback'] += 1 self.assertEqual( - preloaded_app_conf, {'__file__': 'test', 'workers': 0}) + preloaded_app_conf, + {'__file__': 'test', 'workers': 0, 'bind_port': 12345}) self.assertEqual(global_conf, {'log_name': 'log_name'}) global_conf['test1'] = to_inject @@ -827,7 +828,7 @@ class TestWSGI(unittest.TestCase): def _initrp(conf_file, app_section, *args, **kwargs): calls['_initrp'] += 1 return ( - {'__file__': 'test', 'workers': 0}, + {'__file__': 'test', 'workers': 0, 'bind_port': 12345}, 'logger', 'log_name') @@ -862,11 +863,17 @@ class TestWSGI(unittest.TestCase): mock_run_server): # Make sure the right strategy gets used in a number of different # config cases. 
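The fork-and-pipe readiness check added to run_wsgi() above reduces to this pattern; bind_and_serve is a stand-in for run_server() with its new ready_callback argument:

    import os

    def start_worker(bind_and_serve):
        read_fd, write_fd = os.pipe()
        pid = os.fork()
        if pid == 0:                      # child
            os.close(read_fd)

            def notify():
                os.write(write_fd, b'ready')
                os.close(write_fd)

            bind_and_serve(ready_callback=notify)
            os._exit(0)
        os.close(write_fd)                # parent
        status = os.read(read_fd, 30)     # b'ready', or b'' if child died
        os.close(read_fd)
        if status != b'ready':
            raise Exception('worker did not start normally: %r' % status)
        return pid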
- mock_per_port().do_bind_ports.return_value = 'stop early' - mock_workers().do_bind_ports.return_value = 'stop early' + + class StopAtCreatingSockets(Exception): + '''Dummy exception to make sure we don't actually bind ports''' + + mock_per_port().no_fork_sock.return_value = None + mock_per_port().new_worker_socks.side_effect = StopAtCreatingSockets + mock_workers().no_fork_sock.return_value = None + mock_workers().new_worker_socks.side_effect = StopAtCreatingSockets logger = FakeLogger() stub__initrp = [ - {'__file__': 'test', 'workers': 2}, # conf + {'__file__': 'test', 'workers': 2, 'bind_port': 12345}, # conf logger, 'log_name', ] @@ -878,14 +885,13 @@ class TestWSGI(unittest.TestCase): mock_per_port.reset_mock() mock_workers.reset_mock() logger._clear() - self.assertEqual(1, wsgi.run_wsgi('conf_file', server_type)) - self.assertEqual([ - 'stop early', - ], logger.get_lines_for_level('error')) + with self.assertRaises(StopAtCreatingSockets): + wsgi.run_wsgi('conf_file', server_type) self.assertEqual([], mock_per_port.mock_calls) self.assertEqual([ mock.call(stub__initrp[0], logger), - mock.call().do_bind_ports(), + mock.call().no_fork_sock(), + mock.call().new_worker_socks(), ], mock_workers.mock_calls) stub__initrp[0]['servers_per_port'] = 3 @@ -893,26 +899,24 @@ class TestWSGI(unittest.TestCase): mock_per_port.reset_mock() mock_workers.reset_mock() logger._clear() - self.assertEqual(1, wsgi.run_wsgi('conf_file', server_type)) - self.assertEqual([ - 'stop early', - ], logger.get_lines_for_level('error')) + with self.assertRaises(StopAtCreatingSockets): + wsgi.run_wsgi('conf_file', server_type) self.assertEqual([], mock_per_port.mock_calls) self.assertEqual([ mock.call(stub__initrp[0], logger), - mock.call().do_bind_ports(), + mock.call().no_fork_sock(), + mock.call().new_worker_socks(), ], mock_workers.mock_calls) mock_per_port.reset_mock() mock_workers.reset_mock() logger._clear() - self.assertEqual(1, wsgi.run_wsgi('conf_file', 'object-server')) - self.assertEqual([ - 'stop early', - ], logger.get_lines_for_level('error')) + with self.assertRaises(StopAtCreatingSockets): + wsgi.run_wsgi('conf_file', 'object-server') self.assertEqual([ mock.call(stub__initrp[0], logger, servers_per_port=3), - mock.call().do_bind_ports(), + mock.call().no_fork_sock(), + mock.call().new_worker_socks(), ], mock_per_port.mock_calls) self.assertEqual([], mock_workers.mock_calls) @@ -1331,12 +1335,16 @@ class TestProxyProtocol(ProtocolTest): class CommonTestMixin(object): - def test_post_fork_hook(self): + @mock.patch('swift.common.wsgi.capture_stdio') + def test_post_fork_hook(self, mock_capture): self.strategy.post_fork_hook() self.assertEqual([ mock.call('bob'), ], self.mock_drop_privileges.mock_calls) + self.assertEqual([ + mock.call(self.logger), + ], mock_capture.mock_calls) class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): @@ -1350,9 +1358,9 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): 'bind_ip': '2.3.4.5', } self.servers_per_port = 3 - self.s1, self.s2 = mock.MagicMock(), mock.MagicMock() + self.sockets = [mock.MagicMock() for _ in range(6)] patcher = mock.patch('swift.common.wsgi.get_socket', - side_effect=[self.s1, self.s2]) + side_effect=self.sockets) self.mock_get_socket = patcher.start() self.addCleanup(patcher.stop) patcher = mock.patch('swift.common.wsgi.drop_privileges') @@ -1391,39 +1399,10 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): self.assertEqual(15, self.strategy.loop_timeout()) - def test_bind_ports(self): 
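The StopAtCreatingSockets trick generalizes nicely: raising from a mock's side_effect halts the code under test at a known call site without ever binding a port. In isolation (using the standalone mock package the tests already import; unittest.mock behaves the same):

    import mock

    class StopHere(Exception):
        '''Raised to stop the code under test at a known point.'''

    strategy = mock.MagicMock()
    strategy.no_fork_sock.return_value = None
    strategy.new_worker_socks.side_effect = StopHere

    try:
        if not strategy.no_fork_sock():
            list(strategy.new_worker_socks())
    except StopHere:
        pass               # got exactly as far as run_wsgi() would
    assert strategy.new_worker_socks.called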
- self.strategy.do_bind_ports() - - self.assertEqual(set((6006, 6007)), self.strategy.bind_ports) - self.assertEqual([ - mock.call({'workers': 100, # ignored - 'user': 'bob', - 'swift_dir': '/jim/cricket', - 'ring_check_interval': '76', - 'bind_ip': '2.3.4.5', - 'bind_port': 6006}), - mock.call({'workers': 100, # ignored - 'user': 'bob', - 'swift_dir': '/jim/cricket', - 'ring_check_interval': '76', - 'bind_ip': '2.3.4.5', - 'bind_port': 6007}), - ], self.mock_get_socket.mock_calls) - self.assertEqual( - 6006, self.strategy.port_pid_state.port_for_sock(self.s1)) - self.assertEqual( - 6007, self.strategy.port_pid_state.port_for_sock(self.s2)) - # strategy binding no longer does clean_up_deemon_hygene() actions, the - # user of the strategy does. - self.assertEqual([], self.mock_setsid.mock_calls) - self.assertEqual([], self.mock_chdir.mock_calls) - self.assertEqual([], self.mock_umask.mock_calls) - def test_no_fork_sock(self): self.assertIsNone(self.strategy.no_fork_sock()) def test_new_worker_socks(self): - self.strategy.do_bind_ports() self.all_bind_ports_for_node.reset_mock() pid = 88 @@ -1434,8 +1413,12 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): pid += 1 self.assertEqual([ - (self.s1, 0), (self.s1, 1), (self.s1, 2), - (self.s2, 0), (self.s2, 1), (self.s2, 2), + (self.sockets[0], (6006, 0)), + (self.sockets[1], (6006, 1)), + (self.sockets[2], (6006, 2)), + (self.sockets[3], (6007, 0)), + (self.sockets[4], (6007, 1)), + (self.sockets[5], (6007, 2)), ], got_si) self.assertEqual([ 'Started child %d (PID %d) for port %d' % (0, 88, 6006), @@ -1454,8 +1437,8 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): # Get rid of servers for ports which disappear from the ring self.ports = (6007,) self.all_bind_ports_for_node.return_value = set(self.ports) - self.s1.reset_mock() - self.s2.reset_mock() + for s in self.sockets: + s.reset_mock() with mock.patch('swift.common.wsgi.greenio') as mock_greenio: self.assertEqual([], list(self.strategy.new_worker_socks())) @@ -1464,23 +1447,28 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): mock.call(), # ring_check_interval has passed... ], self.all_bind_ports_for_node.mock_calls) self.assertEqual([ - mock.call.shutdown_safe(self.s1), - ], mock_greenio.mock_calls) + [mock.call.close()] + for _ in range(3) + ], [s.mock_calls for s in self.sockets[:3]]) + self.assertEqual({ + ('shutdown_safe', (self.sockets[0],)), + ('shutdown_safe', (self.sockets[1],)), + ('shutdown_safe', (self.sockets[2],)), + }, {call[:2] for call in mock_greenio.mock_calls}) self.assertEqual([ - mock.call.close(), - ], self.s1.mock_calls) - self.assertEqual([], self.s2.mock_calls) # not closed - self.assertEqual([ - 'Closing unnecessary sock for port %d' % 6006, - ], self.logger.get_lines_for_level('notice')) + [] for _ in range(3) + ], [s.mock_calls for s in self.sockets[3:]]) # not closed + self.assertEqual({ + 'Closing unnecessary sock for port %d (child pid %d)' % (6006, p) + for p in range(88, 91) + }, set(self.logger.get_lines_for_level('notice'))) self.logger._clear() # Create new socket & workers for new ports that appear in ring self.ports = (6007, 6009) self.all_bind_ports_for_node.return_value = set(self.ports) - self.s1.reset_mock() - self.s2.reset_mock() - s3 = mock.MagicMock() + for s in self.sockets: + s.reset_mock() self.mock_get_socket.side_effect = Exception('ack') # But first make sure we handle failure to bind to the requested port! 
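The set comparison over call[:2] used above works because a mock call record is a (name, args, kwargs) tuple, so slicing off the kwargs leaves hashable pairs that can be compared without caring about call order. In miniature:

    import mock

    greenio = mock.MagicMock()
    greenio.shutdown_safe('sock-a')
    greenio.shutdown_safe('sock-b')

    assert {c[:2] for c in greenio.mock_calls} == {
        ('shutdown_safe', ('sock-a',)),
        ('shutdown_safe', ('sock-b',)),
    }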
@@ -1499,7 +1487,8 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): self.logger._clear() # Will keep trying, so let it succeed again - self.mock_get_socket.side_effect = [s3] + new_sockets = self.mock_get_socket.side_effect = [ + mock.MagicMock() for _ in range(3)] got_si = [] for s, i in self.strategy.new_worker_socks(): @@ -1508,7 +1497,7 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): pid += 1 self.assertEqual([ - (s3, 0), (s3, 1), (s3, 2), + (s, (6009, i)) for i, s in enumerate(new_sockets) ], got_si) self.assertEqual([ 'Started child %d (PID %d) for port %d' % (0, 94, 6009), @@ -1524,6 +1513,11 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): # Restart a guy who died on us self.strategy.register_worker_exit(95) # server_idx == 1 + # TODO: check that the socket got cleaned up + + new_socket = mock.MagicMock() + self.mock_get_socket.side_effect = [new_socket] + got_si = [] for s, i in self.strategy.new_worker_socks(): got_si.append((s, i)) @@ -1531,7 +1525,7 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): pid += 1 self.assertEqual([ - (s3, 1), + (new_socket, (6009, 1)), ], got_si) self.assertEqual([ 'Started child %d (PID %d) for port %d' % (1, 97, 6009), @@ -1539,7 +1533,7 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): self.logger._clear() # Check log_sock_exit - self.strategy.log_sock_exit(self.s2, 2) + self.strategy.log_sock_exit(self.sockets[5], (6007, 2)) self.assertEqual([ 'Child %d (PID %d, port %d) exiting normally' % ( 2, os.getpid(), 6007), @@ -1551,21 +1545,22 @@ class TestServersPerPortStrategy(unittest.TestCase, CommonTestMixin): self.assertIsNone(self.strategy.register_worker_exit(89)) def test_shutdown_sockets(self): - self.strategy.do_bind_ports() + pid = 88 + for s, i in self.strategy.new_worker_socks(): + self.strategy.register_worker_start(s, i, pid) + pid += 1 with mock.patch('swift.common.wsgi.greenio') as mock_greenio: self.strategy.shutdown_sockets() self.assertEqual([ - mock.call.shutdown_safe(self.s1), - mock.call.shutdown_safe(self.s2), + mock.call.shutdown_safe(s) + for s in self.sockets ], mock_greenio.mock_calls) self.assertEqual([ - mock.call.close(), - ], self.s1.mock_calls) - self.assertEqual([ - mock.call.close(), - ], self.s2.mock_calls) + [mock.call.close()] + for _ in range(3) + ], [s.mock_calls for s in self.sockets[:3]]) class TestWorkersStrategy(unittest.TestCase, CommonTestMixin): @@ -1576,8 +1571,9 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin): 'user': 'bob', } self.strategy = wsgi.WorkersStrategy(self.conf, self.logger) + self.mock_socket = mock.Mock() patcher = mock.patch('swift.common.wsgi.get_socket', - return_value='abc') + return_value=self.mock_socket) self.mock_get_socket = patcher.start() self.addCleanup(patcher.stop) patcher = mock.patch('swift.common.wsgi.drop_privileges') @@ -1593,41 +1589,19 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin): # gets checked). self.assertEqual(0.5, self.strategy.loop_timeout()) - def test_binding(self): - self.assertIsNone(self.strategy.do_bind_ports()) - - self.assertEqual('abc', self.strategy.sock) - self.assertEqual([ - mock.call(self.conf), - ], self.mock_get_socket.mock_calls) - # strategy binding no longer drops privileges nor does - # clean_up_deemon_hygene() actions. 
- self.assertEqual([], self.mock_drop_privileges.mock_calls) - self.assertEqual([], self.mock_clean_up_daemon_hygene.mock_calls) - - self.mock_get_socket.side_effect = wsgi.ConfigFilePortError() - - self.assertEqual( - 'bind_port wasn\'t properly set in the config file. ' - 'It must be explicitly set to a valid port number.', - self.strategy.do_bind_ports()) - def test_no_fork_sock(self): - self.strategy.do_bind_ports() self.assertIsNone(self.strategy.no_fork_sock()) self.conf['workers'] = 0 self.strategy = wsgi.WorkersStrategy(self.conf, self.logger) - self.strategy.do_bind_ports() - self.assertEqual('abc', self.strategy.no_fork_sock()) + self.assertIs(self.mock_socket, self.strategy.no_fork_sock()) def test_new_worker_socks(self): - self.strategy.do_bind_ports() pid = 88 sock_count = 0 for s, i in self.strategy.new_worker_socks(): - self.assertEqual('abc', s) + self.assertEqual(self.mock_socket, s) self.assertIsNone(i) # unused for this strategy self.strategy.register_worker_start(s, 'unused', pid) pid += 1 @@ -1650,7 +1624,7 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin): ], self.logger.get_lines_for_level('error')) for s, i in self.strategy.new_worker_socks(): - self.assertEqual('abc', s) + self.assertEqual(self.mock_socket, s) self.assertIsNone(i) # unused for this strategy self.strategy.register_worker_start(s, 'unused', pid) pid += 1 @@ -1664,23 +1638,23 @@ class TestWorkersStrategy(unittest.TestCase, CommonTestMixin): ], self.logger.get_lines_for_level('notice')) def test_shutdown_sockets(self): - self.mock_get_socket.return_value = mock.MagicMock() - self.strategy.do_bind_ports() + self.mock_get_socket.side_effect = sockets = [ + mock.MagicMock(), mock.MagicMock()] + + pid = 88 + for s, i in self.strategy.new_worker_socks(): + self.strategy.register_worker_start(s, 'unused', pid) + pid += 1 + with mock.patch('swift.common.wsgi.greenio') as mock_greenio: self.strategy.shutdown_sockets() self.assertEqual([ - mock.call.shutdown_safe(self.mock_get_socket.return_value), + mock.call.shutdown_safe(s) + for s in sockets ], mock_greenio.mock_calls) - if six.PY2: - self.assertEqual([ - mock.call.__nonzero__(), - mock.call.close(), - ], self.mock_get_socket.return_value.mock_calls) - else: - self.assertEqual([ - mock.call.__bool__(), - mock.call.close(), - ], self.mock_get_socket.return_value.mock_calls) + self.assertEqual([ + [mock.call.close()] for _ in range(2) + ], [s.mock_calls for s in sockets]) def test_log_sock_exit(self): self.strategy.log_sock_exit('blahblah', 'blahblah')
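One closing note on the test refactor: since get_socket() is now called once per worker instead of once per strategy, the tests switch from a single return_value to a side_effect list so that every call hands back a distinct mock socket. The pattern in isolation:

    import mock

    sockets = [mock.MagicMock() for _ in range(3)]
    mock_get_socket = mock.Mock(side_effect=sockets)

    # Each call yields the next socket in the list, mirroring one fresh
    # SO_REUSEPORT socket per forked worker.
    assert mock_get_socket({'bind_port': 6006}) is sockets[0]
    assert mock_get_socket({'bind_port': 6006}) is sockets[1]
    assert mock_get_socket.call_count == 2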