From f7f3349d6a4def52f810ab1728879521c12fe2d0 Mon Sep 17 00:00:00 2001 From: elajkat Date: Tue, 8 Jun 2021 18:09:31 +0200 Subject: [PATCH] Add timeout to PrivContext and entrypoint_with_timeout decorator entrypoint_with_timeout decorator can be used with a timeout parameter, if the timeout is reached PrivsepTimeout is raised. The PrivContext has timeout variable, which will be used for all functions decorated with entrypoint, and PrivsepTimeout is raised if timeout is reached. Co-authored-by: Rodolfo Alonso Change-Id: Ie3b1fc255c0c05fd5403b90ef49b954fe397fb77 Related-Bug: #1930401 --- doc/source/user/index.rst | 55 +++++++++++++++++++ oslo_privsep/comm.py | 54 +++++++++++------- oslo_privsep/daemon.py | 40 +++++--------- oslo_privsep/functional/test_daemon.py | 43 +++++++++++++++ oslo_privsep/priv_context.py | 25 ++++++++- oslo_privsep/tests/test_daemon.py | 6 +- oslo_privsep/tests/test_priv_context.py | 18 ++++++ ...th_timeout_decorator-9aab5a74153b3632.yaml | 11 ++++ 8 files changed, 202 insertions(+), 50 deletions(-) create mode 100644 releasenotes/notes/add_entrypoint_with_timeout_decorator-9aab5a74153b3632.yaml diff --git a/doc/source/user/index.rst b/doc/source/user/index.rst index 85bf8b7..a15c4df 100644 --- a/doc/source/user/index.rst +++ b/doc/source/user/index.rst @@ -30,6 +30,31 @@ defines a sys_admin_pctxt with ``CAP_CHOWN``, ``CAP_DAC_OVERRIDE``, capabilities.CAP_SYS_ADMIN], ) +Defining a context with timeout +------------------------------- + +It is possible to initialize PrivContext with timeout:: + + from oslo_privsep import capabilities + from oslo_privsep import priv_context + + dhcp_release_cmd = priv_context.PrivContext( + __name__, + cfg_section='privsep_dhcp_release', + pypath=__name__ + '.dhcp_release_cmd', + capabilities=[caps.CAP_SYS_ADMIN, + caps.CAP_NET_ADMIN], + timeout=5 + ) + +``PrivsepTimeout`` is raised if timeout is reached. + +.. warning:: + + The daemon (the root process) task won't stop when timeout + is reached. That means we'll have less available threads if the related + thread never finishes. + Defining a privileged function ============================== @@ -51,6 +76,36 @@ generic ``update_file(filename, content)`` was created, it could be used to overwrite any file in the filesystem, allowing easy escalation to root rights. That would defeat the whole purpose of oslo.privsep. +Defining a privileged function with timeout +------------------------------------------- + +It is possible to use ``entrypoint_with_timeout`` decorator:: + + from oslo_privsep import daemon + + from neutron import privileged + + @privileged.default.entrypoint_with_timeout(timeout=5) + def get_link_devices(namespace, **kwargs): + try: + with get_iproute(namespace) as ip: + return make_serializable(ip.get_links(**kwargs)) + except OSError as e: + if e.errno == errno.ENOENT: + raise NetworkNamespaceNotFound(netns_name=namespace) + raise + except daemon.FailedToDropPrivileges: + raise + except daemon.PrivsepTimeout: + raise + +``PrivsepTimeout`` is raised if timeout is reached. + +.. warning:: + + The daemon (the root process) task won't stop when timeout + is reached. That means we'll have less available threads if the related + thread never finishes. Using a privileged function =========================== diff --git a/oslo_privsep/comm.py b/oslo_privsep/comm.py index 32bd800..fb1872f 100644 --- a/oslo_privsep/comm.py +++ b/oslo_privsep/comm.py @@ -20,6 +20,8 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string converted to tuples during serialization/deserialization. """ +import datetime +import enum import logging import socket import threading @@ -28,22 +30,24 @@ import msgpack import six from oslo_privsep._i18n import _ - +from oslo_utils import uuidutils LOG = logging.getLogger(__name__) -try: - import greenlet +@enum.unique +class Message(enum.IntEnum): + """Types of messages sent across the communication channel""" + PING = 1 + PONG = 2 + CALL = 3 + RET = 4 + ERR = 5 + LOG = 6 - def _get_thread_ident(): - # This returns something sensible, even if the current thread - # isn't a greenthread - return id(greenlet.getcurrent()) -except ImportError: - def _get_thread_ident(): - return threading.current_thread().ident +class PrivsepTimeout(Exception): + pass class Serializer(object): @@ -89,10 +93,11 @@ class Deserializer(six.Iterator): class Future(object): """A very simple object to track the return of a function call""" - def __init__(self, lock): + def __init__(self, lock, timeout=None): self.condvar = threading.Condition(lock) self.error = None self.data = None + self.timeout = timeout def set_result(self, data): """Must already be holding lock used in constructor""" @@ -106,7 +111,16 @@ class Future(object): def result(self): """Must already be holding lock used in constructor""" - self.condvar.wait() + before = datetime.datetime.now() + if not self.condvar.wait(timeout=self.timeout): + now = datetime.datetime.now() + LOG.warning('Timeout while executing a command, timeout: %s, ' + 'time elapsed: %s', self.timeout, + (now - before).total_seconds()) + return (Message.ERR.value, + '%s.%s' % (PrivsepTimeout.__module__, + PrivsepTimeout.__name__), + '') if self.error is not None: raise self.error return self.data @@ -138,8 +152,9 @@ class ClientChannel(object): else: with self.lock: if msgid not in self.outstanding_msgs: - raise AssertionError("msgid should in " - "outstanding_msgs.") + LOG.warning("msgid should be in oustanding_msgs, it is" + "possible that timeout is reached!") + continue self.outstanding_msgs[msgid].set_result(data) # EOF. Perhaps the privileged process exited? @@ -158,13 +173,14 @@ class ClientChannel(object): """Received OOB message. Subclasses might want to override this.""" pass - def send_recv(self, msg): - myid = _get_thread_ident() - future = Future(self.lock) + def send_recv(self, msg, timeout=None): + myid = uuidutils.generate_uuid() + while myid in self.outstanding_msgs: + LOG.warning("myid shoudn't be in outstanding_msgs.") + myid = uuidutils.generate_uuid() + future = Future(self.lock, timeout) with self.lock: - if myid in self.outstanding_msgs: - raise AssertionError("myid shoudn't be in outstanding_msgs.") self.outstanding_msgs[myid] = future try: self.writer.send((myid, msg)) diff --git a/oslo_privsep/daemon.py b/oslo_privsep/daemon.py index cde8f52..8bb20d1 100644 --- a/oslo_privsep/daemon.py +++ b/oslo_privsep/daemon.py @@ -109,17 +109,6 @@ class StdioFd(enum.IntEnum): STDERR = 2 -@enum.unique -class Message(enum.IntEnum): - """Types of messages sent across the communication channel""" - PING = 1 - PONG = 2 - CALL = 3 - RET = 4 - ERR = 5 - LOG = 6 - - class FailedToDropPrivileges(Exception): pass @@ -187,7 +176,7 @@ class PrivsepLogHandler(pylogging.Handler): data['msg'] = record.getMessage() data['args'] = () - self.channel.send((None, (Message.LOG, data))) + self.channel.send((None, (comm.Message.LOG, data))) class _ClientChannel(comm.ClientChannel): @@ -201,8 +190,8 @@ class _ClientChannel(comm.ClientChannel): def exchange_ping(self): try: # exchange "ready" messages - reply = self.send_recv((Message.PING.value,)) - success = reply[0] == Message.PONG + reply = self.send_recv((comm.Message.PING.value,)) + success = reply[0] == comm.Message.PONG except Exception as e: self.log.exception('Error while sending initial PING to privsep: ' '%s', e) @@ -212,12 +201,13 @@ class _ClientChannel(comm.ClientChannel): self.log.critical(msg) raise FailedToDropPrivileges(msg) - def remote_call(self, name, args, kwargs): - result = self.send_recv((Message.CALL.value, name, args, kwargs)) - if result[0] == Message.RET: + def remote_call(self, name, args, kwargs, timeout): + result = self.send_recv((comm.Message.CALL.value, name, args, kwargs), + timeout) + if result[0] == comm.Message.RET: # (RET, return value) return result[1] - elif result[0] == Message.ERR: + elif result[0] == comm.Message.ERR: # (ERR, exc_type, args) # # TODO(gus): see what can be done to preserve traceback @@ -228,7 +218,7 @@ class _ClientChannel(comm.ClientChannel): raise ProtocolError(_('Unexpected response: %r') % result) def out_of_band(self, msg): - if msg[0] == Message.LOG: + if msg[0] == comm.Message.LOG: # (LOG, LogRecord __dict__) message = {encodeutils.safe_decode(k): v for k, v in msg[1].items()} @@ -470,11 +460,11 @@ class Daemon(object): :return: A tuple of the return status, optional call output, and optional error information. """ - if cmd == Message.PING: - return (Message.PONG.value,) + if cmd == comm.Message.PING: + return (comm.Message.PONG.value,) try: - if cmd != Message.CALL: + if cmd != comm.Message.CALL: raise ProtocolError(_('Unknown privsep cmd: %s') % cmd) # Extract the callable and arguments @@ -485,14 +475,14 @@ class Daemon(object): raise NameError(msg) ret = func(*f_args, **f_kwargs) - return (Message.RET.value, ret) + return (comm.Message.RET.value, ret) except Exception as e: LOG.debug( 'privsep: Exception during request[%(msgid)s]: ' '%(err)s', {'msgid': msgid, 'err': e}, exc_info=True) cls = e.__class__ cls_name = '%s.%s' % (cls.__module__, cls.__name__) - return (Message.ERR.value, cls_name, e.args) + return (comm.Message.ERR.value, cls_name, e.args) def _create_done_callback(self, msgid): """Creates a future callback to receive command execution results. @@ -520,7 +510,7 @@ class Daemon(object): '%(err)s', {'msgid': msgid, 'err': e}, exc_info=True) cls = e.__class__ cls_name = '%s.%s' % (cls.__module__, cls.__name__) - reply = (Message.ERR.value, cls_name, e.args) + reply = (comm.Message.ERR.value, cls_name, e.args) try: channel.send((msgid, reply)) except IOError: diff --git a/oslo_privsep/functional/test_daemon.py b/oslo_privsep/functional/test_daemon.py index 2c7a5ef..8cc572a 100644 --- a/oslo_privsep/functional/test_daemon.py +++ b/oslo_privsep/functional/test_daemon.py @@ -20,6 +20,7 @@ import unittest from oslo_config import fixture as config_fixture from oslotest import base +from oslo_privsep import comm from oslo_privsep import priv_context @@ -30,6 +31,14 @@ test_context = priv_context.PrivContext( capabilities=[], ) +test_context_with_timeout = priv_context.PrivContext( + __name__, + cfg_section='privsep', + pypath=__name__ + '.test_context_with_timeout', + capabilities=[], + timeout=0.03 +) + @test_context.entrypoint def sleep(): @@ -37,6 +46,18 @@ def sleep(): time.sleep(.001) +@test_context.entrypoint_with_timeout(0.03) +def sleep_with_timeout(long_timeout=0.04): + time.sleep(long_timeout) + return 42 + + +@test_context_with_timeout.entrypoint +def sleep_with_t_context(long_timeout=0.04): + time.sleep(long_timeout) + return 42 + + @test_context.entrypoint def one(): return 1 @@ -65,6 +86,28 @@ class TestDaemon(base.BaseTestCase): # Make sure the daemon is still working self.assertEqual(1, one()) + def test_entrypoint_with_timeout(self): + thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size + for _ in range(thread_pool_size + 1): + self.assertRaises(comm.PrivsepTimeout, sleep_with_timeout) + + def test_entrypoint_with_timeout_pass(self): + thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size + for _ in range(thread_pool_size + 1): + res = sleep_with_timeout(0.01) + self.assertEqual(42, res) + + def test_context_with_timeout(self): + thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size + for _ in range(thread_pool_size + 1): + self.assertRaises(comm.PrivsepTimeout, sleep_with_t_context) + + def test_context_with_timeout_pass(self): + thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size + for _ in range(thread_pool_size + 1): + res = sleep_with_t_context(0.01) + self.assertEqual(42, res) + def test_logging(self): logs() self.assertIn('foo', self.log_fixture.logger.output) diff --git a/oslo_privsep/priv_context.py b/oslo_privsep/priv_context.py index 58b4b5e..c0fd3d5 100644 --- a/oslo_privsep/priv_context.py +++ b/oslo_privsep/priv_context.py @@ -128,7 +128,8 @@ def init(root_helper=None): class PrivContext(object): def __init__(self, prefix, cfg_section='privsep', pypath=None, - capabilities=None, logger_name='oslo_privsep.daemon'): + capabilities=None, logger_name='oslo_privsep.daemon', + timeout=None): # Note that capabilities=[] means retaining no capabilities # and leaves even uid=0 with no powers except being able to @@ -156,6 +157,7 @@ class PrivContext(object): default=capabilities) cfg.CONF.set_default('logger_name', group=cfg_section, default=logger_name) + self.timeout = timeout @property def conf(self): @@ -221,7 +223,22 @@ class PrivContext(object): def entrypoint(self, func): """This is intended to be used as a decorator.""" + return self._entrypoint(func) + def entrypoint_with_timeout(self, timeout): + """This is intended to be used as a decorator with timeout.""" + + def wrap(func): + + @functools.wraps(func) + def inner(*args, **kwargs): + f = self._entrypoint(func) + return f(*args, _wrap_timeout=timeout, **kwargs) + setattr(inner, _ENTRYPOINT_ATTR, self) + return inner + return wrap + + def _entrypoint(self, func): if not func.__module__.startswith(self.prefix): raise AssertionError('%r entrypoints must be below "%s"' % (self, self.prefix)) @@ -242,7 +259,7 @@ class PrivContext(object): def is_entrypoint(self, func): return getattr(func, _ENTRYPOINT_ATTR, None) is self - def _wrap(self, func, *args, **kwargs): + def _wrap(self, func, *args, _wrap_timeout=None, **kwargs): if self.client_mode: name = '%s.%s' % (func.__module__, func.__name__) if self.channel is not None and not self.channel.running: @@ -250,7 +267,9 @@ class PrivContext(object): self.stop() if self.channel is None: self.start() - return self.channel.remote_call(name, args, kwargs) + r_call_timeout = _wrap_timeout or self.timeout + return self.channel.remote_call(name, args, kwargs, + r_call_timeout) else: return func(*args, **kwargs) diff --git a/oslo_privsep/tests/test_daemon.py b/oslo_privsep/tests/test_daemon.py index f391e97..cdde642 100644 --- a/oslo_privsep/tests/test_daemon.py +++ b/oslo_privsep/tests/test_daemon.py @@ -216,7 +216,7 @@ class ClientChannelTestCase(base.BaseTestCase): @mock.patch.object(daemon.LOG.logger, 'handle') def test_out_of_band_log_message(self, handle_mock): - message = [daemon.Message.LOG, self.DICT] + message = [comm.Message.LOG, self.DICT] self.assertEqual(self.client_channel.log, daemon.LOG) with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \ mock.patch.object(daemon.LOG, 'isEnabledFor', @@ -229,7 +229,7 @@ class ClientChannelTestCase(base.BaseTestCase): def test_out_of_band_not_log_message(self): with mock.patch.object(daemon.LOG, 'warning') as mock_warning: - self.client_channel.out_of_band([daemon.Message.PING]) + self.client_channel.out_of_band([comm.Message.PING]) mock_warning.assert_called_once() @mock.patch.object(daemon.logging, 'getLogger') @@ -245,7 +245,7 @@ class ClientChannelTestCase(base.BaseTestCase): get_logger_mock.assert_called_once_with(logger_name) self.assertEqual(get_logger_mock.return_value, channel.log) - message = [daemon.Message.LOG, self.DICT] + message = [comm.Message.LOG, self.DICT] channel.out_of_band(message) make_log_mock.assert_called_once_with(self.EXPECTED) diff --git a/oslo_privsep/tests/test_priv_context.py b/oslo_privsep/tests/test_priv_context.py index d6480db..98f73c9 100644 --- a/oslo_privsep/tests/test_priv_context.py +++ b/oslo_privsep/tests/test_priv_context.py @@ -19,10 +19,12 @@ import pipes import platform import sys import tempfile +import time from unittest import mock import testtools +from oslo_privsep import comm from oslo_privsep import daemon from oslo_privsep import priv_context from oslo_privsep.tests import testctx @@ -40,6 +42,12 @@ def add1(arg): return arg + 1 +@testctx.context.entrypoint_with_timeout(0.2) +def do_some_long(long_timeout=0.4): + time.sleep(long_timeout) + return 42 + + class CustomError(Exception): def __init__(self, code, msg): super(CustomError, self).__init__(code, msg) @@ -188,6 +196,16 @@ class RootwrapTest(testctx.TestContextTestCase): priv_pid = priv_getpid() self.assertNotMyPid(priv_pid) + def test_long_call_with_timeout(self): + self.assertRaises( + comm.PrivsepTimeout, + do_some_long + ) + + def test_long_call_within_timeout(self): + res = do_some_long(0.001) + self.assertEqual(42, res) + @testtools.skipIf(platform.system() != 'Linux', 'works only on Linux platform.') diff --git a/releasenotes/notes/add_entrypoint_with_timeout_decorator-9aab5a74153b3632.yaml b/releasenotes/notes/add_entrypoint_with_timeout_decorator-9aab5a74153b3632.yaml new file mode 100644 index 0000000..30148e1 --- /dev/null +++ b/releasenotes/notes/add_entrypoint_with_timeout_decorator-9aab5a74153b3632.yaml @@ -0,0 +1,11 @@ +--- +features: + - | + Add ``timeout`` as parameter to ``PrivContext`` and add + ``entrypoint_with_timeout`` decorator to cover the issues with + commands which take random time to finish. + ``PrivsepTimeout`` is raised if timeout is reached. + + ``Warning``: The daemon (the root process) task won't stop when timeout + is reached. That means we'll have less available threads if the related + thread never finishes.