removed unneeded daemonize function from utils, pulled get_socket out of run_wsgi, reworked test_utils and test_wsgi

This commit is contained in:
Clay Gerrard 2010-11-19 12:15:41 -06:00
parent d583fd9bdb
commit c007d0296e
6 changed files with 181 additions and 66 deletions

View File

@ -38,7 +38,8 @@ class Daemon(object):
def run(self, once=False, **kwargs):
"""Run the daemon"""
utils.validate_configuration()
utils.daemonize(self.conf, self.logger, **kwargs)
utils.capture_stdio(self.logger, **kwargs)
utils.drop_privileges(self.conf.get('user', 'swift'))
def kill_children(*args):
signal.signal(signal.SIGTERM, signal.SIG_IGN)

View File

@ -419,18 +419,6 @@ def capture_stdio(logger, **kwargs):
sys.stderr = LoggerFileObject(logger)
def daemonize(conf, logger, **kwargs):
"""
Perform standard python/linux daemonization operations.
:param conf: Configuration dict to read settings from (i.e. user)
:param logger: Logger object to handle stdio redirect and uncaught exc
"""
capture_stdio(logger, **kwargs)
drop_privileges(conf.get('user', 'swift'))
def parse_options(usage="%prog CONFIG [options]", once=False, test_args=None):
"""
Parse standard swift server/daemon options with optparse.OptionParser.

View File

@ -56,6 +56,38 @@ def monkey_patch_mimetools():
mimetools.Message.parsetype = parsetype
def get_socket(conf, default_port=8080):
"""Bind socket to bind ip:port in conf
:param conf: Configuration dict to read settings from
:param default_port: port to use if not specified in conf
:returns : a socket object as returned from socket.listen or ssl.wrap_socket
if conf specifies cert_file
"""
bind_addr = (conf.get('bind_ip', '0.0.0.0'),
int(conf.get('bind_port', default_port)))
sock = None
retry_until = time.time() + 30
while not sock and time.time() < retry_until:
try:
sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096)))
if 'cert_file' in conf:
sock = ssl.wrap_socket(sock, certfile=conf['cert_file'],
keyfile=conf['key_file'])
except socket.error, err:
if err.args[0] != errno.EADDRINUSE:
raise
sleep(0.1)
if not sock:
raise Exception('Could not bind to %s:%s after trying for 30 seconds' %
bind_addr)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# in my experience, sockets can hang around forever without keepalive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600)
return sock
# TODO: pull pieces of this out to test
def run_wsgi(conf_file, app_section, *args, **kwargs):
@ -84,29 +116,9 @@ def run_wsgi(conf_file, app_section, *args, **kwargs):
# redirect errors to logger and close stdio
capture_stdio(logger)
bind_addr = (conf.get('bind_ip', '0.0.0.0'),
int(conf.get('bind_port', kwargs.get('default_port', 8080))))
sock = None
retry_until = time.time() + 30
while not sock and time.time() < retry_until:
try:
sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096)))
if 'cert_file' in conf:
sock = ssl.wrap_socket(sock, certfile=conf['cert_file'],
keyfile=conf['key_file'])
except socket.error, err:
if err.args[0] != errno.EADDRINUSE:
raise
sleep(0.1)
if not sock:
raise Exception('Could not bind to %s:%s after trying for 30 seconds' %
bind_addr)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# in my experience, sockets can hang around forever without keepalive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600)
worker_count = int(conf.get('workers', '1'))
# bind to address and port
sock = get_socket(conf, default_port=kwargs.get('default_port', 8080))
# remaining tasks should not require elevated privileges
drop_privileges(conf.get('user', 'swift'))
# finally after binding to ports and privilege drop, run app __init__ code
@ -125,6 +137,7 @@ def run_wsgi(conf_file, app_section, *args, **kwargs):
raise
pool.waitall()
worker_count = int(conf.get('workers', '1'))
# Useful for profiling [no forks].
if worker_count == 0:
run_server()

View File

@ -62,7 +62,8 @@ class TestRunDaemon(unittest.TestCase):
def setUp(self):
utils.HASH_PATH_SUFFIX = 'endcap'
utils.daemonize = lambda *args: None
utils.drop_privileges = lambda *args: None
utils.capture_stdio = lambda *args: None
def tearDown(self):
reload(utils)

View File

@ -40,10 +40,11 @@ class MockOs():
setattr(self, func, self.pass_func)
self.called_funcs = {}
for func in called_funcs:
c_func = partial(self.called_func, name)
c_func = partial(self.called_func, func)
setattr(self, func, c_func)
for func in raise_funcs:
setattr(self, func, self.raise_func)
r_func = partial(self.raise_func, func)
setattr(self, func, r_func)
def pass_func(self, *args, **kwargs):
pass
@ -53,7 +54,8 @@ class MockOs():
def called_func(self, name, *args, **kwargs):
self.called_funcs[name] = True
def raise_func(self, *args, **kwargs):
def raise_func(self, name, *args, **kwargs):
self.called_funcs[name] = True
raise OSError()
def dup2(self, source, target):
@ -378,51 +380,59 @@ log_name = yarr'''
self.assertEquals(result, expected)
os.unlink('/tmp/test')
def test_daemonize(self):
# default args
conf = {'user': getuser()}
logger = utils.get_logger(None, 'daemon')
# over-ride utils system modules with mocks
utils.os = MockOs()
utils.sys = MockSys()
utils.daemonize(conf, logger)
self.assert_(utils.sys.excepthook is not None)
self.assertEquals(utils.os.closed_fds, [0, 1, 2])
self.assert_(utils.sys.stdout is not None)
self.assert_(utils.sys.stderr is not None)
def test_drop_privileges(self):
user = getuser()
# over-ride os with mock
required_func_calls = ('setgid', 'setuid', 'setsid', 'chdir', 'umask')
utils.os = MockOs(called_funcs=required_func_calls)
# exercise the code
utils.drop_privileges(user)
for func in required_func_calls:
self.assert_(utils.os.called_funcs[func])
# reset; test same args, OSError trying to get session leader
utils.os = MockOs(raise_funcs=('setsid',))
utils.sys = MockSys()
utils.os = MockOs(called_funcs=required_func_calls,
raise_funcs=('setsid',))
for func in required_func_calls:
self.assertFalse(utils.os.called_funcs.get(func, False))
utils.drop_privileges(user)
for func in required_func_calls:
self.assert_(utils.os.called_funcs[func])
utils.daemonize(conf, logger)
def test_capture_stdio(self):
# stubs
logger = utils.get_logger(None, 'dummy')
# mock utils system modules
utils.sys = MockSys()
utils.os = MockOs()
# basic test
utils.capture_stdio(logger)
self.assert_(utils.sys.excepthook is not None)
self.assertEquals(utils.os.closed_fds, [0, 1, 2])
self.assert_(utils.sys.stdout is not None)
self.assert_(utils.sys.stderr is not None)
# reset; test same args, exc when trying to close stdio
# reset; test same args, but exc when trying to close stdio
utils.os = MockOs(raise_funcs=('dup2',))
utils.sys = MockSys()
utils.daemonize(conf, logger)
# test unable to close stdio
utils.capture_stdio(logger)
self.assert_(utils.sys.excepthook is not None)
# unable to close stdio
self.assertEquals(utils.os.closed_fds, [])
self.assert_(utils.sys.stdout is not None)
self.assert_(utils.sys.stderr is not None)
# reset; test some other args
logger = utils.get_logger(None, log_to_console=True)
utils.os = MockOs()
utils.sys = MockSys()
conf = {'user': getuser()}
logger = utils.get_logger(None, log_to_console=True)
logger = logging.getLogger()
utils.daemonize(conf, logger, capture_stdout=False,
capture_stderr=False)
# test console log
utils.capture_stdio(logger, capture_stdout=False,
capture_stderr=False)
self.assert_(utils.sys.excepthook is not None)
# when logging to console, stderr remains open
self.assertEquals(utils.os.closed_fds, [0, 1])

View File

@ -25,12 +25,12 @@ import unittest
from getpass import getuser
from shutil import rmtree
from StringIO import StringIO
from collections import defaultdict
from eventlet import sleep
from swift.common import wsgi
class TestWSGI(unittest.TestCase):
""" Tests for swift.common.wsgi """
@ -72,5 +72,107 @@ class TestWSGI(unittest.TestCase):
sio = StringIO('Content-Type: text/html; charset=ISO-8859-4')
self.assertEquals(mimetools.Message(sio).subtype, 'html')
def test_get_socket(self):
# stubs
conf = {}
ssl_conf = {
'cert_file': '',
'key_file': '',
}
# mocks
class MockSocket():
def __init__(self):
self.opts = defaultdict(dict)
def setsockopt(self, level, optname, value):
self.opts[level][optname] = value
def mock_listen(*args, **kwargs):
return MockSocket()
class MockSsl():
def __init__(self):
self.wrap_socket_called = []
def wrap_socket(self, sock, **kwargs):
self.wrap_socket_called.append(kwargs)
return sock
# patch
old_listen = wsgi.listen
old_ssl = wsgi.ssl
try:
wsgi.listen = mock_listen
wsgi.ssl = MockSsl()
# test
sock = wsgi.get_socket(conf)
# assert
self.assert_(isinstance(sock, MockSocket))
expected_socket_opts = {
socket.SOL_SOCKET: {
socket.SO_REUSEADDR: 1,
socket.SO_KEEPALIVE: 1,
},
socket.IPPROTO_TCP: {
socket.TCP_KEEPIDLE: 600,
},
}
self.assertEquals(sock.opts, expected_socket_opts)
# test ssl
sock = wsgi.get_socket(ssl_conf)
expected_kwargs = {
'certfile': '',
'keyfile': '',
}
self.assertEquals(wsgi.ssl.wrap_socket_called, [expected_kwargs])
finally:
wsgi.listen = old_listen
wsgi.ssl = old_ssl
def test_address_in_use(self):
# stubs
conf = {}
# mocks
def mock_listen(*args, **kwargs):
raise socket.error(errno.EADDRINUSE)
def value_error_listen(*args, **kwargs):
raise ValueError('fake')
def mock_sleep(*args):
pass
class MockTime():
"""Fast clock advances 10 seconds after every call to time
"""
def __init__(self):
self.current_time = old_time.time()
def time(self, *args, **kwargs):
rv = self.current_time
# advance for next call
self.current_time += 10
return rv
old_listen = wsgi.listen
old_sleep = wsgi.sleep
old_time = wsgi.time
try:
wsgi.listen = mock_listen
wsgi.sleep = mock_sleep
wsgi.time = MockTime()
# test error
self.assertRaises(Exception, wsgi.get_socket, conf)
# different error
wsgi.listen = value_error_listen
self.assertRaises(ValueError, wsgi.get_socket, conf)
finally:
wsgi.listen = old_listen
wsgi.sleep = old_sleep
wsgi.time = old_time
if __name__ == '__main__':
unittest.main()