removed unneeded daemonize function from utils, pulled get_socket out of run_wsgi, reworked test_utils and test_wsgi
This commit is contained in:
parent
d583fd9bdb
commit
c007d0296e
@ -38,7 +38,8 @@ class Daemon(object):
|
||||
def run(self, once=False, **kwargs):
|
||||
"""Run the daemon"""
|
||||
utils.validate_configuration()
|
||||
utils.daemonize(self.conf, self.logger, **kwargs)
|
||||
utils.capture_stdio(self.logger, **kwargs)
|
||||
utils.drop_privileges(self.conf.get('user', 'swift'))
|
||||
|
||||
def kill_children(*args):
|
||||
signal.signal(signal.SIGTERM, signal.SIG_IGN)
|
||||
|
@ -419,18 +419,6 @@ def capture_stdio(logger, **kwargs):
|
||||
sys.stderr = LoggerFileObject(logger)
|
||||
|
||||
|
||||
def daemonize(conf, logger, **kwargs):
|
||||
"""
|
||||
Perform standard python/linux daemonization operations.
|
||||
|
||||
:param conf: Configuration dict to read settings from (i.e. user)
|
||||
:param logger: Logger object to handle stdio redirect and uncaught exc
|
||||
"""
|
||||
|
||||
capture_stdio(logger, **kwargs)
|
||||
drop_privileges(conf.get('user', 'swift'))
|
||||
|
||||
|
||||
def parse_options(usage="%prog CONFIG [options]", once=False, test_args=None):
|
||||
"""
|
||||
Parse standard swift server/daemon options with optparse.OptionParser.
|
||||
|
@ -56,6 +56,38 @@ def monkey_patch_mimetools():
|
||||
|
||||
mimetools.Message.parsetype = parsetype
|
||||
|
||||
def get_socket(conf, default_port=8080):
|
||||
"""Bind socket to bind ip:port in conf
|
||||
|
||||
:param conf: Configuration dict to read settings from
|
||||
:param default_port: port to use if not specified in conf
|
||||
|
||||
:returns : a socket object as returned from socket.listen or ssl.wrap_socket
|
||||
if conf specifies cert_file
|
||||
"""
|
||||
bind_addr = (conf.get('bind_ip', '0.0.0.0'),
|
||||
int(conf.get('bind_port', default_port)))
|
||||
sock = None
|
||||
retry_until = time.time() + 30
|
||||
while not sock and time.time() < retry_until:
|
||||
try:
|
||||
sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096)))
|
||||
if 'cert_file' in conf:
|
||||
sock = ssl.wrap_socket(sock, certfile=conf['cert_file'],
|
||||
keyfile=conf['key_file'])
|
||||
except socket.error, err:
|
||||
if err.args[0] != errno.EADDRINUSE:
|
||||
raise
|
||||
sleep(0.1)
|
||||
if not sock:
|
||||
raise Exception('Could not bind to %s:%s after trying for 30 seconds' %
|
||||
bind_addr)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
# in my experience, sockets can hang around forever without keepalive
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600)
|
||||
return sock
|
||||
|
||||
|
||||
# TODO: pull pieces of this out to test
|
||||
def run_wsgi(conf_file, app_section, *args, **kwargs):
|
||||
@ -84,29 +116,9 @@ def run_wsgi(conf_file, app_section, *args, **kwargs):
|
||||
|
||||
# redirect errors to logger and close stdio
|
||||
capture_stdio(logger)
|
||||
|
||||
bind_addr = (conf.get('bind_ip', '0.0.0.0'),
|
||||
int(conf.get('bind_port', kwargs.get('default_port', 8080))))
|
||||
sock = None
|
||||
retry_until = time.time() + 30
|
||||
while not sock and time.time() < retry_until:
|
||||
try:
|
||||
sock = listen(bind_addr, backlog=int(conf.get('backlog', 4096)))
|
||||
if 'cert_file' in conf:
|
||||
sock = ssl.wrap_socket(sock, certfile=conf['cert_file'],
|
||||
keyfile=conf['key_file'])
|
||||
except socket.error, err:
|
||||
if err.args[0] != errno.EADDRINUSE:
|
||||
raise
|
||||
sleep(0.1)
|
||||
if not sock:
|
||||
raise Exception('Could not bind to %s:%s after trying for 30 seconds' %
|
||||
bind_addr)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
# in my experience, sockets can hang around forever without keepalive
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 600)
|
||||
worker_count = int(conf.get('workers', '1'))
|
||||
# bind to address and port
|
||||
sock = get_socket(conf, default_port=kwargs.get('default_port', 8080))
|
||||
# remaining tasks should not require elevated privileges
|
||||
drop_privileges(conf.get('user', 'swift'))
|
||||
|
||||
# finally after binding to ports and privilege drop, run app __init__ code
|
||||
@ -125,6 +137,7 @@ def run_wsgi(conf_file, app_section, *args, **kwargs):
|
||||
raise
|
||||
pool.waitall()
|
||||
|
||||
worker_count = int(conf.get('workers', '1'))
|
||||
# Useful for profiling [no forks].
|
||||
if worker_count == 0:
|
||||
run_server()
|
||||
|
@ -62,7 +62,8 @@ class TestRunDaemon(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
utils.HASH_PATH_SUFFIX = 'endcap'
|
||||
utils.daemonize = lambda *args: None
|
||||
utils.drop_privileges = lambda *args: None
|
||||
utils.capture_stdio = lambda *args: None
|
||||
|
||||
def tearDown(self):
|
||||
reload(utils)
|
||||
|
@ -40,10 +40,11 @@ class MockOs():
|
||||
setattr(self, func, self.pass_func)
|
||||
self.called_funcs = {}
|
||||
for func in called_funcs:
|
||||
c_func = partial(self.called_func, name)
|
||||
c_func = partial(self.called_func, func)
|
||||
setattr(self, func, c_func)
|
||||
for func in raise_funcs:
|
||||
setattr(self, func, self.raise_func)
|
||||
r_func = partial(self.raise_func, func)
|
||||
setattr(self, func, r_func)
|
||||
|
||||
def pass_func(self, *args, **kwargs):
|
||||
pass
|
||||
@ -53,7 +54,8 @@ class MockOs():
|
||||
def called_func(self, name, *args, **kwargs):
|
||||
self.called_funcs[name] = True
|
||||
|
||||
def raise_func(self, *args, **kwargs):
|
||||
def raise_func(self, name, *args, **kwargs):
|
||||
self.called_funcs[name] = True
|
||||
raise OSError()
|
||||
|
||||
def dup2(self, source, target):
|
||||
@ -378,50 +380,58 @@ log_name = yarr'''
|
||||
self.assertEquals(result, expected)
|
||||
os.unlink('/tmp/test')
|
||||
|
||||
def test_daemonize(self):
|
||||
# default args
|
||||
conf = {'user': getuser()}
|
||||
logger = utils.get_logger(None, 'daemon')
|
||||
|
||||
# over-ride utils system modules with mocks
|
||||
utils.os = MockOs()
|
||||
utils.sys = MockSys()
|
||||
|
||||
utils.daemonize(conf, logger)
|
||||
self.assert_(utils.sys.excepthook is not None)
|
||||
self.assertEquals(utils.os.closed_fds, [0, 1, 2])
|
||||
self.assert_(utils.sys.stdout is not None)
|
||||
self.assert_(utils.sys.stderr is not None)
|
||||
def test_drop_privileges(self):
|
||||
user = getuser()
|
||||
# over-ride os with mock
|
||||
required_func_calls = ('setgid', 'setuid', 'setsid', 'chdir', 'umask')
|
||||
utils.os = MockOs(called_funcs=required_func_calls)
|
||||
# exercise the code
|
||||
utils.drop_privileges(user)
|
||||
for func in required_func_calls:
|
||||
self.assert_(utils.os.called_funcs[func])
|
||||
|
||||
# reset; test same args, OSError trying to get session leader
|
||||
utils.os = MockOs(raise_funcs=('setsid',))
|
||||
utils.sys = MockSys()
|
||||
utils.os = MockOs(called_funcs=required_func_calls,
|
||||
raise_funcs=('setsid',))
|
||||
for func in required_func_calls:
|
||||
self.assertFalse(utils.os.called_funcs.get(func, False))
|
||||
utils.drop_privileges(user)
|
||||
for func in required_func_calls:
|
||||
self.assert_(utils.os.called_funcs[func])
|
||||
|
||||
utils.daemonize(conf, logger)
|
||||
def test_capture_stdio(self):
|
||||
# stubs
|
||||
logger = utils.get_logger(None, 'dummy')
|
||||
|
||||
# mock utils system modules
|
||||
utils.sys = MockSys()
|
||||
utils.os = MockOs()
|
||||
|
||||
# basic test
|
||||
utils.capture_stdio(logger)
|
||||
self.assert_(utils.sys.excepthook is not None)
|
||||
self.assertEquals(utils.os.closed_fds, [0, 1, 2])
|
||||
self.assert_(utils.sys.stdout is not None)
|
||||
self.assert_(utils.sys.stderr is not None)
|
||||
|
||||
# reset; test same args, exc when trying to close stdio
|
||||
# reset; test same args, but exc when trying to close stdio
|
||||
utils.os = MockOs(raise_funcs=('dup2',))
|
||||
utils.sys = MockSys()
|
||||
|
||||
utils.daemonize(conf, logger)
|
||||
# test unable to close stdio
|
||||
utils.capture_stdio(logger)
|
||||
self.assert_(utils.sys.excepthook is not None)
|
||||
# unable to close stdio
|
||||
self.assertEquals(utils.os.closed_fds, [])
|
||||
self.assert_(utils.sys.stdout is not None)
|
||||
self.assert_(utils.sys.stderr is not None)
|
||||
|
||||
# reset; test some other args
|
||||
logger = utils.get_logger(None, log_to_console=True)
|
||||
utils.os = MockOs()
|
||||
utils.sys = MockSys()
|
||||
|
||||
conf = {'user': getuser()}
|
||||
logger = utils.get_logger(None, log_to_console=True)
|
||||
logger = logging.getLogger()
|
||||
utils.daemonize(conf, logger, capture_stdout=False,
|
||||
# test console log
|
||||
utils.capture_stdio(logger, capture_stdout=False,
|
||||
capture_stderr=False)
|
||||
self.assert_(utils.sys.excepthook is not None)
|
||||
# when logging to console, stderr remains open
|
||||
|
@ -25,12 +25,12 @@ import unittest
|
||||
from getpass import getuser
|
||||
from shutil import rmtree
|
||||
from StringIO import StringIO
|
||||
from collections import defaultdict
|
||||
|
||||
from eventlet import sleep
|
||||
|
||||
from swift.common import wsgi
|
||||
|
||||
|
||||
class TestWSGI(unittest.TestCase):
|
||||
""" Tests for swift.common.wsgi """
|
||||
|
||||
@ -72,5 +72,107 @@ class TestWSGI(unittest.TestCase):
|
||||
sio = StringIO('Content-Type: text/html; charset=ISO-8859-4')
|
||||
self.assertEquals(mimetools.Message(sio).subtype, 'html')
|
||||
|
||||
def test_get_socket(self):
|
||||
# stubs
|
||||
conf = {}
|
||||
ssl_conf = {
|
||||
'cert_file': '',
|
||||
'key_file': '',
|
||||
}
|
||||
|
||||
# mocks
|
||||
class MockSocket():
|
||||
def __init__(self):
|
||||
self.opts = defaultdict(dict)
|
||||
|
||||
def setsockopt(self, level, optname, value):
|
||||
self.opts[level][optname] = value
|
||||
|
||||
def mock_listen(*args, **kwargs):
|
||||
return MockSocket()
|
||||
|
||||
class MockSsl():
|
||||
def __init__(self):
|
||||
self.wrap_socket_called = []
|
||||
|
||||
def wrap_socket(self, sock, **kwargs):
|
||||
self.wrap_socket_called.append(kwargs)
|
||||
return sock
|
||||
|
||||
# patch
|
||||
old_listen = wsgi.listen
|
||||
old_ssl = wsgi.ssl
|
||||
try:
|
||||
wsgi.listen = mock_listen
|
||||
wsgi.ssl = MockSsl()
|
||||
# test
|
||||
sock = wsgi.get_socket(conf)
|
||||
# assert
|
||||
self.assert_(isinstance(sock, MockSocket))
|
||||
expected_socket_opts = {
|
||||
socket.SOL_SOCKET: {
|
||||
socket.SO_REUSEADDR: 1,
|
||||
socket.SO_KEEPALIVE: 1,
|
||||
},
|
||||
socket.IPPROTO_TCP: {
|
||||
socket.TCP_KEEPIDLE: 600,
|
||||
},
|
||||
}
|
||||
self.assertEquals(sock.opts, expected_socket_opts)
|
||||
# test ssl
|
||||
sock = wsgi.get_socket(ssl_conf)
|
||||
expected_kwargs = {
|
||||
'certfile': '',
|
||||
'keyfile': '',
|
||||
}
|
||||
self.assertEquals(wsgi.ssl.wrap_socket_called, [expected_kwargs])
|
||||
finally:
|
||||
wsgi.listen = old_listen
|
||||
wsgi.ssl = old_ssl
|
||||
|
||||
def test_address_in_use(self):
|
||||
# stubs
|
||||
conf = {}
|
||||
|
||||
# mocks
|
||||
def mock_listen(*args, **kwargs):
|
||||
raise socket.error(errno.EADDRINUSE)
|
||||
|
||||
def value_error_listen(*args, **kwargs):
|
||||
raise ValueError('fake')
|
||||
|
||||
def mock_sleep(*args):
|
||||
pass
|
||||
|
||||
class MockTime():
|
||||
"""Fast clock advances 10 seconds after every call to time
|
||||
"""
|
||||
def __init__(self):
|
||||
self.current_time = old_time.time()
|
||||
|
||||
def time(self, *args, **kwargs):
|
||||
rv = self.current_time
|
||||
# advance for next call
|
||||
self.current_time += 10
|
||||
return rv
|
||||
|
||||
old_listen = wsgi.listen
|
||||
old_sleep = wsgi.sleep
|
||||
old_time = wsgi.time
|
||||
try:
|
||||
wsgi.listen = mock_listen
|
||||
wsgi.sleep = mock_sleep
|
||||
wsgi.time = MockTime()
|
||||
# test error
|
||||
self.assertRaises(Exception, wsgi.get_socket, conf)
|
||||
# different error
|
||||
wsgi.listen = value_error_listen
|
||||
self.assertRaises(ValueError, wsgi.get_socket, conf)
|
||||
finally:
|
||||
wsgi.listen = old_listen
|
||||
wsgi.sleep = old_sleep
|
||||
wsgi.time = old_time
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user