removed unneeded daemonize function from utils, pulled get_socket out of run_wsgi, reworked test_utils and test_wsgi

This commit is contained in:
Clay Gerrard 2010-11-19 12:15:41 -06:00
parent d583fd9bdb
commit c007d0296e
6 changed files with 181 additions and 66 deletions

View File

@ -38,7 +38,8 @@ class Daemon(object):
def run(self, once=False, **kwargs): def run(self, once=False, **kwargs):
"""Run the daemon""" """Run the daemon"""
utils.validate_configuration() utils.validate_configuration()
utils.daemonize(self.conf, self.logger, **kwargs) utils.capture_stdio(self.logger, **kwargs)
utils.drop_privileges(self.conf.get('user', 'swift'))
def kill_children(*args): def kill_children(*args):
signal.signal(signal.SIGTERM, signal.SIG_IGN) signal.signal(signal.SIGTERM, signal.SIG_IGN)

View File

@ -419,18 +419,6 @@ def capture_stdio(logger, **kwargs):
sys.stderr = LoggerFileObject(logger) sys.stderr = LoggerFileObject(logger)
def daemonize(conf, logger, **kwargs):
"""
Perform standard python/linux daemonization operations.
:param conf: Configuration dict to read settings from (i.e. user)
:param logger: Logger object to handle stdio redirect and uncaught exc
"""
capture_stdio(logger, **kwargs)
drop_privileges(conf.get('user', 'swift'))
def parse_options(usage="%prog CONFIG [options]", once=False, test_args=None): def parse_options(usage="%prog CONFIG [options]", once=False, test_args=None):
""" """
Parse standard swift server/daemon options with optparse.OptionParser. Parse standard swift server/daemon options with optparse.OptionParser.

View File

@ -56,6 +56,38 @@ def monkey_patch_mimetools():
mimetools.Message.parsetype = parsetype mimetools.Message.parsetype = parsetype
def get_socket(conf, default_port=8080):
"""Bind socket to bind ip:port in conf
:param conf: Configuration dict to read settings from
:param default_port: port to use if not specified in conf
:returns : a socket object as returned from socket.listen or ssl.wrap_socket
if conf specifies cert_file
"""
bind_addr = (conf.get('bind_ip', '0.0.0.0'),
int(conf.get('bind_port', default_port)))
sock = None
retry_until = time.time() + 30
while not sock and time.time() < retry_until:
try:
sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096)))
if 'cert_file' in conf:
sock = ssl.wrap_socket(sock, certfile=conf['cert_file'],
keyfile=conf['key_file'])
except socket.error, err:
if err.args[0] != errno.EADDRINUSE:
raise
sleep(0.1)
if not sock:
raise Exception('Could not bind to %s:%s after trying for 30 seconds' %
bind_addr)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# in my experience, sockets can hang around forever without keepalive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600)
return sock
# TODO: pull pieces of this out to test # TODO: pull pieces of this out to test
def run_wsgi(conf_file, app_section, *args, **kwargs): def run_wsgi(conf_file, app_section, *args, **kwargs):
@ -84,29 +116,9 @@ def run_wsgi(conf_file, app_section, *args, **kwargs):
# redirect errors to logger and close stdio # redirect errors to logger and close stdio
capture_stdio(logger) capture_stdio(logger)
# bind to address and port
bind_addr = (conf.get('bind_ip', '0.0.0.0'), sock = get_socket(conf, default_port=kwargs.get('default_port', 8080))
int(conf.get('bind_port', kwargs.get('default_port', 8080)))) # remaining tasks should not require elevated privileges
sock = None
retry_until = time.time() + 30
while not sock and time.time() < retry_until:
try:
sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096)))
if 'cert_file' in conf:
sock = ssl.wrap_socket(sock, certfile=conf['cert_file'],
keyfile=conf['key_file'])
except socket.error, err:
if err.args[0] != errno.EADDRINUSE:
raise
sleep(0.1)
if not sock:
raise Exception('Could not bind to %s:%s after trying for 30 seconds' %
bind_addr)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# in my experience, sockets can hang around forever without keepalive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600)
worker_count = int(conf.get('workers', '1'))
drop_privileges(conf.get('user', 'swift')) drop_privileges(conf.get('user', 'swift'))
# finally after binding to ports and privilege drop, run app __init__ code # finally after binding to ports and privilege drop, run app __init__ code
@ -125,6 +137,7 @@ def run_wsgi(conf_file, app_section, *args, **kwargs):
raise raise
pool.waitall() pool.waitall()
worker_count = int(conf.get('workers', '1'))
# Useful for profiling [no forks]. # Useful for profiling [no forks].
if worker_count == 0: if worker_count == 0:
run_server() run_server()

View File

@ -62,7 +62,8 @@ class TestRunDaemon(unittest.TestCase):
def setUp(self): def setUp(self):
utils.HASH_PATH_SUFFIX = 'endcap' utils.HASH_PATH_SUFFIX = 'endcap'
utils.daemonize = lambda *args: None utils.drop_privileges = lambda *args: None
utils.capture_stdio = lambda *args: None
def tearDown(self): def tearDown(self):
reload(utils) reload(utils)

View File

@ -40,10 +40,11 @@ class MockOs():
setattr(self, func, self.pass_func) setattr(self, func, self.pass_func)
self.called_funcs = {} self.called_funcs = {}
for func in called_funcs: for func in called_funcs:
c_func = partial(self.called_func, name) c_func = partial(self.called_func, func)
setattr(self, func, c_func) setattr(self, func, c_func)
for func in raise_funcs: for func in raise_funcs:
setattr(self, func, self.raise_func) r_func = partial(self.raise_func, func)
setattr(self, func, r_func)
def pass_func(self, *args, **kwargs): def pass_func(self, *args, **kwargs):
pass pass
@ -53,7 +54,8 @@ class MockOs():
def called_func(self, name, *args, **kwargs): def called_func(self, name, *args, **kwargs):
self.called_funcs[name] = True self.called_funcs[name] = True
def raise_func(self, *args, **kwargs): def raise_func(self, name, *args, **kwargs):
self.called_funcs[name] = True
raise OSError() raise OSError()
def dup2(self, source, target): def dup2(self, source, target):
@ -378,50 +380,58 @@ log_name = yarr'''
self.assertEquals(result, expected) self.assertEquals(result, expected)
os.unlink('/tmp/test') os.unlink('/tmp/test')
def test_daemonize(self): def test_drop_privileges(self):
# default args user = getuser()
conf = {'user': getuser()} # over-ride os with mock
logger = utils.get_logger(None, 'daemon') required_func_calls = ('setgid', 'setuid', 'setsid', 'chdir', 'umask')
utils.os = MockOs(called_funcs=required_func_calls)
# over-ride utils system modules with mocks # exercise the code
utils.os = MockOs() utils.drop_privileges(user)
utils.sys = MockSys() for func in required_func_calls:
self.assert_(utils.os.called_funcs[func])
utils.daemonize(conf, logger)
self.assert_(utils.sys.excepthook is not None)
self.assertEquals(utils.os.closed_fds, [0, 1, 2])
self.assert_(utils.sys.stdout is not None)
self.assert_(utils.sys.stderr is not None)
# reset; test same args, OSError trying to get session leader # reset; test same args, OSError trying to get session leader
utils.os = MockOs(raise_funcs=('setsid',)) utils.os = MockOs(called_funcs=required_func_calls,
utils.sys = MockSys() raise_funcs=('setsid',))
for func in required_func_calls:
self.assertFalse(utils.os.called_funcs.get(func, False))
utils.drop_privileges(user)
for func in required_func_calls:
self.assert_(utils.os.called_funcs[func])
utils.daemonize(conf, logger) def test_capture_stdio(self):
# stubs
logger = utils.get_logger(None, 'dummy')
# mock utils system modules
utils.sys = MockSys()
utils.os = MockOs()
# basic test
utils.capture_stdio(logger)
self.assert_(utils.sys.excepthook is not None) self.assert_(utils.sys.excepthook is not None)
self.assertEquals(utils.os.closed_fds, [0, 1, 2]) self.assertEquals(utils.os.closed_fds, [0, 1, 2])
self.assert_(utils.sys.stdout is not None) self.assert_(utils.sys.stdout is not None)
self.assert_(utils.sys.stderr is not None) self.assert_(utils.sys.stderr is not None)
# reset; test same args, exc when trying to close stdio # reset; test same args, but exc when trying to close stdio
utils.os = MockOs(raise_funcs=('dup2',)) utils.os = MockOs(raise_funcs=('dup2',))
utils.sys = MockSys() utils.sys = MockSys()
utils.daemonize(conf, logger) # test unable to close stdio
utils.capture_stdio(logger)
self.assert_(utils.sys.excepthook is not None) self.assert_(utils.sys.excepthook is not None)
# unable to close stdio
self.assertEquals(utils.os.closed_fds, []) self.assertEquals(utils.os.closed_fds, [])
self.assert_(utils.sys.stdout is not None) self.assert_(utils.sys.stdout is not None)
self.assert_(utils.sys.stderr is not None) self.assert_(utils.sys.stderr is not None)
# reset; test some other args # reset; test some other args
logger = utils.get_logger(None, log_to_console=True)
utils.os = MockOs() utils.os = MockOs()
utils.sys = MockSys() utils.sys = MockSys()
conf = {'user': getuser()} # test console log
logger = utils.get_logger(None, log_to_console=True) utils.capture_stdio(logger, capture_stdout=False,
logger = logging.getLogger()
utils.daemonize(conf, logger, capture_stdout=False,
capture_stderr=False) capture_stderr=False)
self.assert_(utils.sys.excepthook is not None) self.assert_(utils.sys.excepthook is not None)
# when logging to console, stderr remains open # when logging to console, stderr remains open

View File

@ -25,12 +25,12 @@ import unittest
from getpass import getuser from getpass import getuser
from shutil import rmtree from shutil import rmtree
from StringIO import StringIO from StringIO import StringIO
from collections import defaultdict
from eventlet import sleep from eventlet import sleep
from swift.common import wsgi from swift.common import wsgi
class TestWSGI(unittest.TestCase): class TestWSGI(unittest.TestCase):
""" Tests for swift.common.wsgi """ """ Tests for swift.common.wsgi """
@ -72,5 +72,107 @@ class TestWSGI(unittest.TestCase):
sio = StringIO('Content-Type: text/html; charset=ISO-8859-4') sio = StringIO('Content-Type: text/html; charset=ISO-8859-4')
self.assertEquals(mimetools.Message(sio).subtype, 'html') self.assertEquals(mimetools.Message(sio).subtype, 'html')
def test_get_socket(self):
# stubs
conf = {}
ssl_conf = {
'cert_file': '',
'key_file': '',
}
# mocks
class MockSocket():
def __init__(self):
self.opts = defaultdict(dict)
def setsockopt(self, level, optname, value):
self.opts[level][optname] = value
def mock_listen(*args, **kwargs):
return MockSocket()
class MockSsl():
def __init__(self):
self.wrap_socket_called = []
def wrap_socket(self, sock, **kwargs):
self.wrap_socket_called.append(kwargs)
return sock
# patch
old_listen = wsgi.listen
old_ssl = wsgi.ssl
try:
wsgi.listen = mock_listen
wsgi.ssl = MockSsl()
# test
sock = wsgi.get_socket(conf)
# assert
self.assert_(isinstance(sock, MockSocket))
expected_socket_opts = {
socket.SOL_SOCKET: {
socket.SO_REUSEADDR: 1,
socket.SO_KEEPALIVE: 1,
},
socket.IPPROTO_TCP: {
socket.TCP_KEEPIDLE: 600,
},
}
self.assertEquals(sock.opts, expected_socket_opts)
# test ssl
sock = wsgi.get_socket(ssl_conf)
expected_kwargs = {
'certfile': '',
'keyfile': '',
}
self.assertEquals(wsgi.ssl.wrap_socket_called, [expected_kwargs])
finally:
wsgi.listen = old_listen
wsgi.ssl = old_ssl
def test_address_in_use(self):
# stubs
conf = {}
# mocks
def mock_listen(*args, **kwargs):
raise socket.error(errno.EADDRINUSE)
def value_error_listen(*args, **kwargs):
raise ValueError('fake')
def mock_sleep(*args):
pass
class MockTime():
"""Fast clock advances 10 seconds after every call to time
"""
def __init__(self):
self.current_time = old_time.time()
def time(self, *args, **kwargs):
rv = self.current_time
# advance for next call
self.current_time += 10
return rv
old_listen = wsgi.listen
old_sleep = wsgi.sleep
old_time = wsgi.time
try:
wsgi.listen = mock_listen
wsgi.sleep = mock_sleep
wsgi.time = MockTime()
# test error
self.assertRaises(Exception, wsgi.get_socket, conf)
# different error
wsgi.listen = value_error_listen
self.assertRaises(ValueError, wsgi.get_socket, conf)
finally:
wsgi.listen = old_listen
wsgi.sleep = old_sleep
wsgi.time = old_time
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()