From c007d0296e49eef70542526a0b49429287aba8f2 Mon Sep 17 00:00:00 2001 From: Clay Gerrard Date: Fri, 19 Nov 2010 12:15:41 -0600 Subject: [PATCH] removed unneeded daemonize function from utils, pulled get_socket out of run_wsgi, reworked test_utils and test_wsgi --- swift/common/daemon.py | 3 +- swift/common/utils.py | 12 ---- swift/common/wsgi.py | 59 +++++++++++------- test/unit/common/test_daemon.py | 3 +- test/unit/common/test_utils.py | 66 +++++++++++--------- test/unit/common/test_wsgi.py | 104 +++++++++++++++++++++++++++++++- 6 files changed, 181 insertions(+), 66 deletions(-) diff --git a/swift/common/daemon.py b/swift/common/daemon.py index faa4112a1d..d305c247f6 100644 --- a/swift/common/daemon.py +++ b/swift/common/daemon.py @@ -38,7 +38,8 @@ class Daemon(object): def run(self, once=False, **kwargs): """Run the daemon""" utils.validate_configuration() - utils.daemonize(self.conf, self.logger, **kwargs) + utils.capture_stdio(self.logger, **kwargs) + utils.drop_privileges(self.conf.get('user', 'swift')) def kill_children(*args): signal.signal(signal.SIGTERM, signal.SIG_IGN) diff --git a/swift/common/utils.py b/swift/common/utils.py index 9311c28fce..1c48c61339 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -419,18 +419,6 @@ def capture_stdio(logger, **kwargs): sys.stderr = LoggerFileObject(logger) -def daemonize(conf, logger, **kwargs): - """ - Perform standard python/linux daemonization operations. - - :param conf: Configuration dict to read settings from (i.e. user) - :param logger: Logger object to handle stdio redirect and uncaught exc - """ - - capture_stdio(logger, **kwargs) - drop_privileges(conf.get('user', 'swift')) - - def parse_options(usage="%prog CONFIG [options]", once=False, test_args=None): """ Parse standard swift server/daemon options with optparse.OptionParser. diff --git a/swift/common/wsgi.py b/swift/common/wsgi.py index 790de69c37..a93c21aa8a 100644 --- a/swift/common/wsgi.py +++ b/swift/common/wsgi.py @@ -56,6 +56,38 @@ def monkey_patch_mimetools(): mimetools.Message.parsetype = parsetype +def get_socket(conf, default_port=8080): + """Bind socket to bind ip:port in conf + + :param conf: Configuration dict to read settings from + :param default_port: port to use if not specified in conf + + :returns : a socket object as returned from socket.listen or ssl.wrap_socket + if conf specifies cert_file + """ + bind_addr = (conf.get('bind_ip', '0.0.0.0'), + int(conf.get('bind_port', default_port))) + sock = None + retry_until = time.time() + 30 + while not sock and time.time() < retry_until: + try: + sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096))) + if 'cert_file' in conf: + sock = ssl.wrap_socket(sock, certfile=conf['cert_file'], + keyfile=conf['key_file']) + except socket.error, err: + if err.args[0] != errno.EADDRINUSE: + raise + sleep(0.1) + if not sock: + raise Exception('Could not bind to %s:%s after trying for 30 seconds' % + bind_addr) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # in my experience, sockets can hang around forever without keepalive + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600) + return sock + # TODO: pull pieces of this out to test def run_wsgi(conf_file, app_section, *args, **kwargs): @@ -84,29 +116,9 @@ def run_wsgi(conf_file, app_section, *args, **kwargs): # redirect errors to logger and close stdio capture_stdio(logger) - - bind_addr = (conf.get('bind_ip', '0.0.0.0'), - int(conf.get('bind_port', kwargs.get('default_port', 8080)))) - sock = None - retry_until = time.time() + 30 - while not sock and time.time() < retry_until: - try: - sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096))) - if 'cert_file' in conf: - sock = ssl.wrap_socket(sock, certfile=conf['cert_file'], - keyfile=conf['key_file']) - except socket.error, err: - if err.args[0] != errno.EADDRINUSE: - raise - sleep(0.1) - if not sock: - raise Exception('Could not bind to %s:%s after trying for 30 seconds' % - bind_addr) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # in my experience, sockets can hang around forever without keepalive - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600) - worker_count = int(conf.get('workers', '1')) + # bind to address and port + sock = get_socket(conf, default_port=kwargs.get('default_port', 8080)) + # remaining tasks should not require elevated privileges drop_privileges(conf.get('user', 'swift')) # finally after binding to ports and privilege drop, run app __init__ code @@ -125,6 +137,7 @@ def run_wsgi(conf_file, app_section, *args, **kwargs): raise pool.waitall() + worker_count = int(conf.get('workers', '1')) # Useful for profiling [no forks]. if worker_count == 0: run_server() diff --git a/test/unit/common/test_daemon.py b/test/unit/common/test_daemon.py index 83bb971907..c5b95a3013 100644 --- a/test/unit/common/test_daemon.py +++ b/test/unit/common/test_daemon.py @@ -62,7 +62,8 @@ class TestRunDaemon(unittest.TestCase): def setUp(self): utils.HASH_PATH_SUFFIX = 'endcap' - utils.daemonize = lambda *args: None + utils.drop_privileges = lambda *args: None + utils.capture_stdio = lambda *args: None def tearDown(self): reload(utils) diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index 5427ba65c5..2e9d4b0005 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -40,10 +40,11 @@ class MockOs(): setattr(self, func, self.pass_func) self.called_funcs = {} for func in called_funcs: - c_func = partial(self.called_func, name) + c_func = partial(self.called_func, func) setattr(self, func, c_func) for func in raise_funcs: - setattr(self, func, self.raise_func) + r_func = partial(self.raise_func, func) + setattr(self, func, r_func) def pass_func(self, *args, **kwargs): pass @@ -53,7 +54,8 @@ class MockOs(): def called_func(self, name, *args, **kwargs): self.called_funcs[name] = True - def raise_func(self, *args, **kwargs): + def raise_func(self, name, *args, **kwargs): + self.called_funcs[name] = True raise OSError() def dup2(self, source, target): @@ -378,51 +380,59 @@ log_name = yarr''' self.assertEquals(result, expected) os.unlink('/tmp/test') - def test_daemonize(self): - # default args - conf = {'user': getuser()} - logger = utils.get_logger(None, 'daemon') - - # over-ride utils system modules with mocks - utils.os = MockOs() - utils.sys = MockSys() - - utils.daemonize(conf, logger) - self.assert_(utils.sys.excepthook is not None) - self.assertEquals(utils.os.closed_fds, [0, 1, 2]) - self.assert_(utils.sys.stdout is not None) - self.assert_(utils.sys.stderr is not None) + def test_drop_privileges(self): + user = getuser() + # over-ride os with mock + required_func_calls = ('setgid', 'setuid', 'setsid', 'chdir', 'umask') + utils.os = MockOs(called_funcs=required_func_calls) + # exercise the code + utils.drop_privileges(user) + for func in required_func_calls: + self.assert_(utils.os.called_funcs[func]) # reset; test same args, OSError trying to get session leader - utils.os = MockOs(raise_funcs=('setsid',)) - utils.sys = MockSys() + utils.os = MockOs(called_funcs=required_func_calls, + raise_funcs=('setsid',)) + for func in required_func_calls: + self.assertFalse(utils.os.called_funcs.get(func, False)) + utils.drop_privileges(user) + for func in required_func_calls: + self.assert_(utils.os.called_funcs[func]) - utils.daemonize(conf, logger) + def test_capture_stdio(self): + # stubs + logger = utils.get_logger(None, 'dummy') + + # mock utils system modules + utils.sys = MockSys() + utils.os = MockOs() + + # basic test + utils.capture_stdio(logger) self.assert_(utils.sys.excepthook is not None) self.assertEquals(utils.os.closed_fds, [0, 1, 2]) self.assert_(utils.sys.stdout is not None) self.assert_(utils.sys.stderr is not None) - # reset; test same args, exc when trying to close stdio + # reset; test same args, but exc when trying to close stdio utils.os = MockOs(raise_funcs=('dup2',)) utils.sys = MockSys() - utils.daemonize(conf, logger) + # test unable to close stdio + utils.capture_stdio(logger) self.assert_(utils.sys.excepthook is not None) - # unable to close stdio self.assertEquals(utils.os.closed_fds, []) self.assert_(utils.sys.stdout is not None) self.assert_(utils.sys.stderr is not None) # reset; test some other args + logger = utils.get_logger(None, log_to_console=True) utils.os = MockOs() utils.sys = MockSys() - conf = {'user': getuser()} - logger = utils.get_logger(None, log_to_console=True) - logger = logging.getLogger() - utils.daemonize(conf, logger, capture_stdout=False, - capture_stderr=False) + # test console log + utils.capture_stdio(logger, capture_stdout=False, + capture_stderr=False) self.assert_(utils.sys.excepthook is not None) # when logging to console, stderr remains open self.assertEquals(utils.os.closed_fds, [0, 1]) diff --git a/test/unit/common/test_wsgi.py b/test/unit/common/test_wsgi.py index 1f81962ff3..2df6936a83 100644 --- a/test/unit/common/test_wsgi.py +++ b/test/unit/common/test_wsgi.py @@ -25,12 +25,12 @@ import unittest from getpass import getuser from shutil import rmtree from StringIO import StringIO +from collections import defaultdict from eventlet import sleep from swift.common import wsgi - class TestWSGI(unittest.TestCase): """ Tests for swift.common.wsgi """ @@ -72,5 +72,107 @@ class TestWSGI(unittest.TestCase): sio = StringIO('Content-Type: text/html; charset=ISO-8859-4') self.assertEquals(mimetools.Message(sio).subtype, 'html') + def test_get_socket(self): + # stubs + conf = {} + ssl_conf = { + 'cert_file': '', + 'key_file': '', + } + + # mocks + class MockSocket(): + def __init__(self): + self.opts = defaultdict(dict) + + def setsockopt(self, level, optname, value): + self.opts[level][optname] = value + + def mock_listen(*args, **kwargs): + return MockSocket() + + class MockSsl(): + def __init__(self): + self.wrap_socket_called = [] + + def wrap_socket(self, sock, **kwargs): + self.wrap_socket_called.append(kwargs) + return sock + + # patch + old_listen = wsgi.listen + old_ssl = wsgi.ssl + try: + wsgi.listen = mock_listen + wsgi.ssl = MockSsl() + # test + sock = wsgi.get_socket(conf) + # assert + self.assert_(isinstance(sock, MockSocket)) + expected_socket_opts = { + socket.SOL_SOCKET: { + socket.SO_REUSEADDR: 1, + socket.SO_KEEPALIVE: 1, + }, + socket.IPPROTO_TCP: { + socket.TCP_KEEPIDLE: 600, + }, + } + self.assertEquals(sock.opts, expected_socket_opts) + # test ssl + sock = wsgi.get_socket(ssl_conf) + expected_kwargs = { + 'certfile': '', + 'keyfile': '', + } + self.assertEquals(wsgi.ssl.wrap_socket_called, [expected_kwargs]) + finally: + wsgi.listen = old_listen + wsgi.ssl = old_ssl + + def test_address_in_use(self): + # stubs + conf = {} + + # mocks + def mock_listen(*args, **kwargs): + raise socket.error(errno.EADDRINUSE) + + def value_error_listen(*args, **kwargs): + raise ValueError('fake') + + def mock_sleep(*args): + pass + + class MockTime(): + """Fast clock advances 10 seconds after every call to time + """ + def __init__(self): + self.current_time = old_time.time() + + def time(self, *args, **kwargs): + rv = self.current_time + # advance for next call + self.current_time += 10 + return rv + + old_listen = wsgi.listen + old_sleep = wsgi.sleep + old_time = wsgi.time + try: + wsgi.listen = mock_listen + wsgi.sleep = mock_sleep + wsgi.time = MockTime() + # test error + self.assertRaises(Exception, wsgi.get_socket, conf) + # different error + wsgi.listen = value_error_listen + self.assertRaises(ValueError, wsgi.get_socket, conf) + finally: + wsgi.listen = old_listen + wsgi.sleep = old_sleep + wsgi.time = old_time + + if __name__ == '__main__': unittest.main()