Reimplement SSH port forwarding
New implementation adds verification of created port forwarding. Also it produce detailed error reporting including verbose output from ssh process. Change-Id: I8841e89cf7e78fad0f99e05d1d83a73cc253d05b Closes-Bug: #1644835
This commit is contained in:
parent
65f5f65984
commit
3c8941ffbe
0
bareon_ironic/common/__init__.py
Normal file
0
bareon_ironic/common/__init__.py
Normal file
303
bareon_ironic/common/ssh_utils.py
Normal file
303
bareon_ironic/common/ssh_utils.py
Normal file
@ -0,0 +1,303 @@
|
||||
#
|
||||
# Copyright 2016 Cray Inc., All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import fcntl
|
||||
import os
|
||||
import select
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
|
||||
import six
|
||||
from ironic.common import exception
|
||||
|
||||
from bareon_ironic import exception as exc
|
||||
from bareon_ironic.common import subproc_utils
|
||||
|
||||
NetAddr = collections.namedtuple('NetAddr', ('host', 'port'))
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class _SSHPortForwardingAbstract(object):
|
||||
"""Abstract base class for SSH port forwarding handlers
|
||||
|
||||
Local and remote port forwarding management handlers are based on this
|
||||
abstract class. It implement context management protocol so it must be used
|
||||
inside "with" keyword. Port forwarding is activated into __enter__ method
|
||||
and deactivated into __exit__ method.
|
||||
"""
|
||||
|
||||
proc = None
|
||||
subprocess_output = None
|
||||
|
||||
def __init__(self, user, key_file, host, bind, forward,
|
||||
ssh_port=None, validate_timeout=30):
|
||||
"""Collect and store data required for port forwarding setup
|
||||
|
||||
:param user: username for SSH connections
|
||||
:type user: str
|
||||
:param key_file: path to private key for SSH authentication
|
||||
:type key_file: str
|
||||
:param host: address of remote hosts
|
||||
:type host: str
|
||||
:param bind: address/port pair to bind forwarded port
|
||||
:type bind: NetAddr
|
||||
:param forward: address/port pair defining destination of port
|
||||
forwarding
|
||||
:type forward: NetAddr
|
||||
:param ssh_port: SSH port on remote host
|
||||
:type ssh_port: int
|
||||
:param validate_timeout: amount of time we will wait for creation of
|
||||
port forwarding. If set to zero - skip
|
||||
validation.
|
||||
:type validate_timeout: int
|
||||
"""
|
||||
self.user = user
|
||||
self.key = key_file
|
||||
self.host = host
|
||||
self.bind = bind
|
||||
self.forward = forward
|
||||
self.ssh_port = ssh_port
|
||||
self.validate_timeout = max(validate_timeout, 0)
|
||||
|
||||
def __repr__(self):
|
||||
setup = self._ssh_command_add_forward_arguments([])
|
||||
setup = ' '.join(setup)
|
||||
return '<{} {}>'.format(type(self).__name__, setup)
|
||||
|
||||
def __enter__(self):
|
||||
self.subprocess_output = tempfile.TemporaryFile()
|
||||
|
||||
cmd = self._make_ssh_command()
|
||||
proc_args = {
|
||||
'close_fds': True,
|
||||
'stderr': self.subprocess_output}
|
||||
proc_args = self._subprocess_add_args(proc_args)
|
||||
try:
|
||||
self.proc = subprocess.Popen(cmd, **proc_args)
|
||||
if self.validate_timeout:
|
||||
self._check_port_forwarding()
|
||||
except OSError as e:
|
||||
raise exc.SubprocessError(command=cmd, error=e)
|
||||
except exc.SSHForwardedPortValidationError as e:
|
||||
self._kill()
|
||||
remote = self._ssh_command_add_destination([])
|
||||
remote = ' '.join(str(x) for x in remote)
|
||||
raise exc.SSHSetupForwardingError(
|
||||
forward=self, error=e.message, remote=remote,
|
||||
output=self._grab_subprocess_output())
|
||||
except Exception:
|
||||
self._kill()
|
||||
raise
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
del exc_info
|
||||
if self.proc.poll() is not None:
|
||||
raise exception.SSHConnectFailed(
|
||||
'Unexpected SSH process termination.\n'
|
||||
'{}'.format(self._grab_subprocess_output()))
|
||||
self._kill()
|
||||
|
||||
def _make_ssh_command(self):
|
||||
cmd = ['ssh']
|
||||
cmd = self._ssh_command_add_auth_arguments(cmd)
|
||||
cmd = self._ssh_command_add_forward_arguments(cmd)
|
||||
cmd = self._ssh_command_add_extra_arguments(cmd)
|
||||
cmd = self._ssh_command_add_destination(cmd)
|
||||
cmd = self._ssh_command_add_command(cmd)
|
||||
|
||||
return cmd
|
||||
|
||||
def _ssh_command_add_auth_arguments(self, cmd):
|
||||
return cmd + ['-i', self.key]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _ssh_command_add_forward_arguments(self, cmd):
|
||||
return cmd
|
||||
|
||||
def _ssh_command_add_extra_arguments(self, cmd):
|
||||
return cmd + [
|
||||
'-v',
|
||||
'-o', 'BatchMode=yes',
|
||||
'-o', 'RequestTTY=no',
|
||||
'-o', 'ExitOnForwardFailure=yes',
|
||||
'-o', 'StrictHostKeyChecking=no',
|
||||
'-o', 'UserKnownHostsFile=/dev/null']
|
||||
|
||||
def _ssh_command_add_destination(self, cmd):
|
||||
if self.ssh_port is not None:
|
||||
cmd += ['-p', str(self.ssh_port)]
|
||||
return cmd + ['@'.join((self.user, self.host))]
|
||||
|
||||
def _ssh_command_add_command(self, cmd):
|
||||
return cmd
|
||||
|
||||
def _subprocess_add_args(self, args):
|
||||
return args
|
||||
|
||||
@abc.abstractmethod
|
||||
def _check_port_forwarding(self):
|
||||
pass
|
||||
|
||||
def _kill(self):
|
||||
terminator = subproc_utils.ProcTerminator(self.proc)
|
||||
terminator()
|
||||
|
||||
def _grab_subprocess_output(self):
|
||||
output = 'There is no output from SSH process'
|
||||
if self.subprocess_output.tell():
|
||||
self.subprocess_output.seek(os.SEEK_SET)
|
||||
output = self.subprocess_output.read()
|
||||
return output
|
||||
|
||||
|
||||
class SSHLocalPortForwarding(_SSHPortForwardingAbstract):
|
||||
def _ssh_command_add_forward_arguments(self, cmd):
|
||||
forward = self.bind + self.forward
|
||||
forward = [str(x) for x in forward]
|
||||
return cmd + ['-L', ':'.join(forward)]
|
||||
|
||||
def _ssh_command_add_extra_arguments(self, cmd):
|
||||
parent = super(SSHLocalPortForwarding, self)
|
||||
return parent._ssh_command_add_extra_arguments(cmd) + ['-N']
|
||||
|
||||
def _check_port_forwarding(self):
|
||||
now = time.time()
|
||||
time_end = now + self.validate_timeout
|
||||
while now < time_end and self.proc.poll() is None:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
s.connect(self.bind)
|
||||
except socket.error:
|
||||
time.sleep(1)
|
||||
now = time.time()
|
||||
continue
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
break
|
||||
else:
|
||||
raise exc.SSHForwardedPortValidationError
|
||||
|
||||
|
||||
class SSHRemotePortForwarding(_SSHPortForwardingAbstract):
|
||||
def _ssh_command_add_forward_arguments(self, cmd):
|
||||
forward = self.bind + self.forward
|
||||
forward = [str(x) for x in forward]
|
||||
return cmd + ['-R', ':'.join(forward)]
|
||||
|
||||
def _ssh_command_add_extra_arguments(self, cmd):
|
||||
parent = super(SSHRemotePortForwarding, self)
|
||||
cmd = parent._ssh_command_add_extra_arguments(cmd)
|
||||
if not self.validate_timeout:
|
||||
cmd += ['-N']
|
||||
return cmd
|
||||
|
||||
def _ssh_command_add_command(self, cmd):
|
||||
parent = super(SSHRemotePortForwarding, self)
|
||||
cmd = parent._ssh_command_add_command(cmd)
|
||||
if self.validate_timeout:
|
||||
cmd += ['python']
|
||||
return cmd
|
||||
|
||||
def _subprocess_add_args(self, args):
|
||||
args = super(SSHRemotePortForwarding, self)._subprocess_add_args(args)
|
||||
if self.validate_timeout:
|
||||
args['stdin'] = subprocess.PIPE
|
||||
args['stdout'] = subprocess.PIPE
|
||||
return args
|
||||
|
||||
def _check_port_forwarding(self):
|
||||
fail_marker = 'CONNECT FAILED'
|
||||
success_marker = 'CONNECT APPROVED'
|
||||
validate_snippet = textwrap.dedent("""
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
timeout = max({timeout}, 1)
|
||||
addr = '{address.host}'
|
||||
port = {address.port}
|
||||
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
now = time.time()
|
||||
time_end = now + timeout
|
||||
|
||||
while now < time_end:
|
||||
try:
|
||||
s.connect((addr, port))
|
||||
except socket.error:
|
||||
time.sleep(1)
|
||||
now = time.time()
|
||||
continue
|
||||
break
|
||||
else:
|
||||
sys.stderr.write('{fail_marker}')
|
||||
sys.stderr.write(os.linesep)
|
||||
sys.exit(1)
|
||||
|
||||
sys.stdout.write('{success_marker}')
|
||||
sys.stdout.write(os.linesep)
|
||||
sys.stdout.close()
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
while True:
|
||||
time.sleep(128)
|
||||
""").lstrip().format(
|
||||
address=self.bind, timeout=self.validate_timeout,
|
||||
success_marker=success_marker, fail_marker=fail_marker)
|
||||
|
||||
self.proc.stdin.write(validate_snippet)
|
||||
self.proc.stdin.close()
|
||||
|
||||
stdout = self.proc.stdout.fileno()
|
||||
opts = fcntl.fcntl(stdout, fcntl.F_GETFL)
|
||||
fcntl.fcntl(stdout, fcntl.F_SETFL, opts | os.O_NONBLOCK)
|
||||
|
||||
now = time.time()
|
||||
time_end = now + self.validate_timeout
|
||||
output = []
|
||||
while now < time_end:
|
||||
ready = select.select([stdout], [], [], time_end - now)
|
||||
ready_read = ready[0]
|
||||
if ready_read:
|
||||
chunk = self.proc.stdout.read(512)
|
||||
if not chunk:
|
||||
break
|
||||
output.append(chunk)
|
||||
|
||||
if success_marker in ''.join(output):
|
||||
break
|
||||
|
||||
now = time.time()
|
||||
else:
|
||||
raise exc.SSHForwardedPortValidationError
|
||||
|
||||
|
||||
def forward_remote_port(
|
||||
port, user, key_file, target_host, ssh_port=None):
|
||||
return SSHRemotePortForwarding(
|
||||
user, key_file, target_host,
|
||||
NetAddr('127.0.0.1', port), NetAddr('127.0.0.1', port),
|
||||
ssh_port=ssh_port)
|
80
bareon_ironic/common/subproc_utils.py
Normal file
80
bareon_ironic/common/subproc_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2016 Cray Inc., All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import errno
|
||||
import functools
|
||||
|
||||
from oslo_service import loopingcall
|
||||
|
||||
from bareon_ironic import exception as exc
|
||||
|
||||
|
||||
class ProcTerminator(object):
|
||||
_poll_delay = 0.5
|
||||
is_terminated = False
|
||||
is_forced = False
|
||||
|
||||
def __init__(self, proc, timeout=10, force_timeout=2, force=True):
|
||||
self.proc = proc
|
||||
self._actions_queue = [
|
||||
(self._do_terminate, timeout)]
|
||||
if force:
|
||||
self._actions_queue.append(
|
||||
(functools.partial(
|
||||
self._do_terminate, force=True), force_timeout))
|
||||
|
||||
def __call__(self):
|
||||
action_iterator = loopingcall.DynamicLoopingCall(self._do_action)
|
||||
action = action_iterator.start()
|
||||
poll_iterator = loopingcall.FixedIntervalLoopingCall(
|
||||
self._do_poll, action)
|
||||
poll = poll_iterator.start(self._poll_delay)
|
||||
try:
|
||||
poll.wait()
|
||||
if self.is_terminated:
|
||||
action_iterator.stop()
|
||||
else:
|
||||
action_iterator.wait()
|
||||
except OSError as e:
|
||||
if e.errno != errno.ESRCH:
|
||||
raise
|
||||
|
||||
def _do_poll(self, action):
|
||||
if action.ready():
|
||||
raise loopingcall.LoopingCallDone()
|
||||
|
||||
rcode = self.proc.poll()
|
||||
if rcode is None:
|
||||
return
|
||||
|
||||
self.is_terminated = True
|
||||
raise loopingcall.LoopingCallDone()
|
||||
|
||||
def _do_action(self):
|
||||
try:
|
||||
action, timeout = self._actions_queue.pop(0)
|
||||
except IndexError:
|
||||
raise exc.SurvivedSubprocess(pid=self.proc.pid)
|
||||
action()
|
||||
|
||||
return timeout
|
||||
|
||||
def _do_terminate(self, force=False):
|
||||
if not force:
|
||||
self.proc.terminate()
|
||||
return
|
||||
|
||||
self.proc.kill()
|
||||
self.is_forced = True
|
36
bareon_ironic/exception.py
Normal file
36
bareon_ironic/exception.py
Normal file
@ -0,0 +1,36 @@
|
||||
#
|
||||
# Copyright 2016 Cray Inc., All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ironic_lib import exception
|
||||
|
||||
|
||||
class SurvivedSubprocess(exception.IronicException):
|
||||
message = (
|
||||
'Subprocess PID:%(pid)d survived despite our termination requests')
|
||||
|
||||
|
||||
class SubprocessError(exception.IronicException):
|
||||
message = 'Can\'t execute command: %(command)s\n%(error)s'
|
||||
|
||||
|
||||
class SSHSetupForwardingError(exception.IronicException):
|
||||
message = (
|
||||
'Unable to setup SSH port forwarding: %(error)s\n'
|
||||
'(forward=%(forward)r endpoint=%(remote)s)\n\n%(output)s')
|
||||
|
||||
|
||||
class SSHForwardedPortValidationError(exception.IronicException):
|
||||
message = (
|
||||
'Unable to make TCP connection via SSH forwarder port')
|
@ -19,6 +19,7 @@ Bareon Rsync deploy driver.
|
||||
|
||||
from oslo_config import cfg
|
||||
|
||||
from bareon_ironic.common import ssh_utils
|
||||
from bareon_ironic.modules import bareon_utils
|
||||
from bareon_ironic.modules import bareon_base
|
||||
from bareon_ironic.modules.resources import resources
|
||||
@ -60,8 +61,8 @@ class BareonRsyncVendor(bareon_base.BareonVendor):
|
||||
ssh_port = kwargs.get('bareon_ssh_port', 22)
|
||||
host = (kwargs.get('host') or
|
||||
bareon_utils.get_node_ip(kwargs.get('task')))
|
||||
with bareon_utils.ssh_tunnel(rsync.RSYNC_PORT, user,
|
||||
key_file, host, ssh_port):
|
||||
with ssh_utils.forward_remote_port(
|
||||
rsync.RSYNC_PORT, user, key_file, host, ssh_port):
|
||||
return super(
|
||||
BareonRsyncVendor, self
|
||||
)._execute_deploy_script(task, ssh, cmd, **kwargs)
|
||||
|
@ -13,24 +13,21 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import six
|
||||
from oslo_concurrency import processutils
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log as logging
|
||||
from oslo_utils import strutils
|
||||
|
||||
from ironic.common import dhcp_factory
|
||||
from ironic.common import exception
|
||||
from ironic.common import keystone
|
||||
from ironic.common import utils
|
||||
from ironic.common.i18n import _, _LW
|
||||
from oslo_concurrency import processutils
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log as logging
|
||||
from oslo_utils import strutils
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
CONF = cfg.CONF
|
||||
@ -95,34 +92,6 @@ def get_ssh_connection(task, **kwargs):
|
||||
return ssh
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ssh_tunnel(port, user, key_file, target_host, ssh_port=22):
|
||||
tunnel = _create_ssh_tunnel(port, port, user, key_file, target_host,
|
||||
local_forwarding=False)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tunnel.terminate()
|
||||
|
||||
|
||||
def _create_ssh_tunnel(remote_port, local_port, user, key_file, target_host,
|
||||
remote_ip='127.0.0.1', local_ip='127.0.0.1',
|
||||
local_forwarding=True,
|
||||
ssh_port=22):
|
||||
cmd = ['ssh', '-N', '-o', 'StrictHostKeyChecking=no', '-o',
|
||||
'UserKnownHostsFile=/dev/null', '-p', str(ssh_port), '-i', key_file]
|
||||
if local_forwarding:
|
||||
cmd += ['-L', '{}:{}:{}:{}'.format(local_ip, local_port, remote_ip,
|
||||
remote_port)]
|
||||
else:
|
||||
cmd += ['-R', '{}:{}:{}:{}'.format(remote_ip, remote_port, local_ip,
|
||||
local_port)]
|
||||
|
||||
cmd.append('@'.join((user, target_host)))
|
||||
# TODO(lobur): Make this sync, check status. (may use ssh control socket).
|
||||
return subprocess.Popen(cmd)
|
||||
|
||||
|
||||
def sftp_write_to(sftp, data, path):
|
||||
with tempfile.NamedTemporaryFile(dir=CONF.tempdir) as f:
|
||||
f.write(data)
|
||||
|
@ -21,12 +21,12 @@ import datetime
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from ironic.common import exception
|
||||
from oslo_concurrency import processutils
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log
|
||||
|
||||
from ironic.common import exception
|
||||
|
||||
from bareon_ironic.common import ssh_utils
|
||||
from bareon_ironic.modules import bareon_utils
|
||||
from bareon_ironic.modules.resources import resources
|
||||
from bareon_ironic.modules.resources import rsync
|
||||
@ -170,8 +170,9 @@ class ActionController(bareon_utils.RawToPropertyMixin):
|
||||
ssh_key_file = ssh_params.get('key_filename')
|
||||
ssh_host = ssh_params.get('host')
|
||||
ssh_port = ssh_params.get('port', 22)
|
||||
with bareon_utils.ssh_tunnel(rsync.RSYNC_PORT, ssh_user,
|
||||
ssh_key_file, ssh_host, ssh_port):
|
||||
with ssh_utils.forward_remote_port(
|
||||
rsync.RSYNC_PORT, ssh_user,
|
||||
ssh_key_file, ssh_host, ssh_port):
|
||||
self._execute(ssh, sftp)
|
||||
else:
|
||||
self._execute(ssh, sftp)
|
||||
|
@ -0,0 +1,19 @@
|
||||
#
|
||||
# Copyright 2016 Cray Inc., All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import eventlet
|
||||
|
||||
|
||||
eventlet.monkey_patch(os=False)
|
@ -13,9 +13,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
import testtools
|
||||
|
||||
|
||||
class TestDummy(unittest.TestCase):
|
||||
def test_dummy(self):
|
||||
pass
|
||||
class AbstractTestCase(testtools.TestCase):
|
||||
pass
|
0
bareon_ironic/tests/common/__init__.py
Normal file
0
bareon_ironic/tests/common/__init__.py
Normal file
137
bareon_ironic/tests/common/test_ssh_utils.py
Normal file
137
bareon_ironic/tests/common/test_ssh_utils.py
Normal file
@ -0,0 +1,137 @@
|
||||
#
|
||||
# Copyright 2016 Cray Inc., All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import itertools
|
||||
import subprocess
|
||||
|
||||
import mock
|
||||
|
||||
from bareon_ironic import exception as exc
|
||||
from bareon_ironic.common import ssh_utils
|
||||
from bareon_ironic.tests import base
|
||||
from bareon_ironic.tests.common import test_subproc_utils
|
||||
|
||||
|
||||
class SSHPortForwardingTestCase(base.AbstractTestCase):
|
||||
user = 'john-doe'
|
||||
key_file = '/path/to/john-doe/auth/keys/host-key.rsa'
|
||||
host = 'dummy.remote.local'
|
||||
bind = ssh_utils.NetAddr('127.0.1.2', 4080)
|
||||
forward = ssh_utils.NetAddr('127.3.4.5', 5080)
|
||||
|
||||
def setUp(self):
|
||||
super(SSHPortForwardingTestCase, self).setUp()
|
||||
|
||||
# must be created before mocking
|
||||
dummy_proc = test_subproc_utils.DummyPopen()
|
||||
self.ssh_proc = mock.Mock(wraps=dummy_proc)
|
||||
|
||||
self.mock = {}
|
||||
self.mock_patch = {
|
||||
'subprocess.Popen': mock.patch(
|
||||
'subprocess.Popen', return_value=self.ssh_proc),
|
||||
'tempfile.TemporaryFile': mock.patch(
|
||||
'tempfile.TemporaryFile', return_value=self.ssh_proc.stderr),
|
||||
'time.time': mock.patch(
|
||||
'time.time', side_effect=itertools.count()),
|
||||
'time.sleep': mock.patch('time.sleep'),
|
||||
'socket.socket': mock.patch('socket.socket')}
|
||||
|
||||
for name in self.mock_patch:
|
||||
patch = self.mock_patch[name]
|
||||
self.mock[name] = patch.start()
|
||||
self.addCleanup(patch.stop)
|
||||
|
||||
self.local_forwarding = ssh_utils.SSHLocalPortForwarding(
|
||||
self.user, self.key_file, self.host, self.bind, self.forward)
|
||||
self.remote_forwarding = ssh_utils.SSHRemotePortForwarding(
|
||||
self.user, self.key_file, self.host, self.bind, self.forward)
|
||||
|
||||
@mock.patch.object(
|
||||
ssh_utils.SSHRemotePortForwarding, '_check_port_forwarding')
|
||||
def test_remote_without_validation(self, validate_method):
|
||||
self.remote_forwarding.validate_timeout = 0
|
||||
|
||||
forward_argument = self.bind + self.forward
|
||||
forward_argument = [str(x) for x in forward_argument]
|
||||
forward_argument = ':'.join(forward_argument)
|
||||
user_host = '@'.join((self.user, self.host))
|
||||
|
||||
with self.remote_forwarding:
|
||||
self.assertEqual(1, self.mock['subprocess.Popen'].call_count)
|
||||
popen_args, popen_kwargs = self.mock['subprocess.Popen'].call_args
|
||||
cmd = popen_args[0]
|
||||
self.assertEqual(['ssh'], cmd[:1])
|
||||
try:
|
||||
actual_forward = cmd[cmd.index('-R') + 1]
|
||||
except IndexError:
|
||||
raise AssertionError(
|
||||
'Missing expected arguments -R <forward_spec> in SSH call')
|
||||
self.assertEqual(forward_argument, actual_forward)
|
||||
|
||||
self.assertIn('-N', cmd)
|
||||
self.assertEqual(user_host, cmd[-1])
|
||||
|
||||
self.assertIs(self.ssh_proc.stderr, popen_kwargs.get('stderr'))
|
||||
self.assertEqual(0, self.ssh_proc.terminate.call_count)
|
||||
|
||||
self.assertEqual(0, validate_method.call_count)
|
||||
self.assertEqual(1, self.ssh_proc.terminate.call_count)
|
||||
|
||||
@mock.patch('select.select')
|
||||
def test_remote_validation(self, select_mock):
|
||||
self.ssh_proc.stdout.write('CONNECT APPROVED\n')
|
||||
self.ssh_proc.stdout.seek(0)
|
||||
|
||||
select_mock.return_value = [[self.ssh_proc.stdout.fileno()], [], []]
|
||||
with self.remote_forwarding:
|
||||
popen_args, popen_kwargs = self.mock['subprocess.Popen'].call_args
|
||||
cmd = popen_args[0]
|
||||
|
||||
self.assertNotIn('-N', cmd)
|
||||
self.assertEqual(['python'], cmd[-1:])
|
||||
|
||||
self.assertIs(subprocess.PIPE, popen_kwargs['stdout'])
|
||||
self.assertIs(subprocess.PIPE, popen_kwargs['stdin'])
|
||||
|
||||
self.assertTrue(self.ssh_proc.stdin.closed)
|
||||
self.assertEqual(0, self.ssh_proc.terminate.call_count)
|
||||
self.assertEqual(1, self.ssh_proc.terminate.call_count)
|
||||
|
||||
@mock.patch('select.select')
|
||||
def test_remote_validation_fail(self, select_mock):
|
||||
ssh_output_indicator = 'SSH output grabbing indicator'
|
||||
|
||||
self.ssh_proc.stdout.write('output don\'t matching success marker\n')
|
||||
self.ssh_proc.stdout.seek(0)
|
||||
self.ssh_proc.stderr.write(ssh_output_indicator)
|
||||
|
||||
select_mock.side_effect = itertools.chain(
|
||||
([[self.ssh_proc.stdout.fileno()], [], []], ),
|
||||
itertools.repeat([[], [], []]))
|
||||
try:
|
||||
with self.remote_forwarding:
|
||||
pass
|
||||
except exc.SSHSetupForwardingError as e:
|
||||
self.assertIn(ssh_output_indicator, str(e))
|
||||
except Exception as e:
|
||||
raise AssertionError('Catch {!r} instead of {!r}'.format(
|
||||
e, exc.SSHSetupForwardingError))
|
||||
else:
|
||||
raise AssertionError(
|
||||
'There was no expected exception: {!r}'.format(
|
||||
exc.SSHSetupForwardingError))
|
||||
|
||||
self.assertEqual(1, self.ssh_proc.terminate.call_count)
|
115
bareon_ironic/tests/common/test_subproc_utils.py
Normal file
115
bareon_ironic/tests/common/test_subproc_utils.py
Normal file
@ -0,0 +1,115 @@
|
||||
#
|
||||
# Copyright 2016 Cray Inc., All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import errno
|
||||
import tempfile
|
||||
|
||||
import mock
|
||||
|
||||
from bareon_ironic import exception as exc
|
||||
from bareon_ironic.common import subproc_utils
|
||||
from bareon_ironic.tests import base
|
||||
|
||||
|
||||
class ProcTerminatorTestCase(base.AbstractTestCase):
|
||||
def test(self):
|
||||
proc = DummyPopen()
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(proc)
|
||||
terminator()
|
||||
self.assertEqual(proc.term_rcode, proc.rcode)
|
||||
|
||||
def test_killed_in_advance(self):
|
||||
proc = DummyPopen()
|
||||
proc.terminate()
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(proc)
|
||||
terminator()
|
||||
self.assertEqual(proc.term_rcode, proc.rcode)
|
||||
|
||||
def test_ignore_TERM(self):
|
||||
proc = DummyPopen()
|
||||
proc.terminate = mock.Mock(spec=proc.terminate)
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(
|
||||
proc, timeout=.01, force_timeout=.01)
|
||||
terminator._poll_delay = 0.001
|
||||
terminator()
|
||||
self.assertEqual(True, terminator.is_terminated)
|
||||
self.assertEqual(True, terminator.is_forced)
|
||||
|
||||
def test_ignore_TERM_without_force(self):
|
||||
proc = DummyPopen()
|
||||
proc.terminate = mock.Mock(spec=proc.terminate)
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(
|
||||
proc, timeout=.01, force=False)
|
||||
terminator._poll_delay = 0.1
|
||||
self.assertRaises(exc.SurvivedSubprocess, terminator)
|
||||
self.assertEqual(False, terminator.is_terminated)
|
||||
self.assertEqual(False, terminator.is_forced)
|
||||
|
||||
def test_ignore_KILL(self):
|
||||
"""Process dies too slow"""
|
||||
|
||||
proc = DummyPopen()
|
||||
proc.terminate = mock.Mock(spec=proc.terminate)
|
||||
proc.kill = mock.Mock(spec=proc.kill)
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(
|
||||
proc, timeout=.01, force_timeout=.01)
|
||||
terminator._poll_delay = 0.001
|
||||
self.assertRaises(exc.SurvivedSubprocess, terminator)
|
||||
self.assertEqual(False, terminator.is_terminated)
|
||||
self.assertEqual(True, terminator.is_forced)
|
||||
|
||||
def test_missing_proc(self):
|
||||
proc = DummyPopen()
|
||||
proc.terminate = mock.Mock(
|
||||
side_effect=OSError(errno.ESRCH, 'Fake error'))
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(proc)
|
||||
terminator()
|
||||
self.assertEqual(False, terminator.is_terminated)
|
||||
|
||||
def test_insufficient_access(self):
|
||||
proc = DummyPopen()
|
||||
proc.terminate = mock.Mock(
|
||||
side_effect=OSError(errno.EPERM, 'Fake error'))
|
||||
|
||||
terminator = subproc_utils.ProcTerminator(proc)
|
||||
self.assertRaises(OSError, terminator)
|
||||
|
||||
|
||||
class DummyPopen(object):
|
||||
def __init__(self, rcode=None, term_rcode=-15, kill_rcode=-9, pid=1234):
|
||||
self.rcode = rcode
|
||||
self.term_rcode = term_rcode
|
||||
self.kill_rcode = kill_rcode
|
||||
self.pid = pid
|
||||
|
||||
self.rcode = None
|
||||
self.stdin = tempfile.TemporaryFile()
|
||||
self.stdout = tempfile.TemporaryFile()
|
||||
self.stderr = tempfile.TemporaryFile()
|
||||
|
||||
def terminate(self):
|
||||
self.rcode = self.term_rcode
|
||||
|
||||
def kill(self):
|
||||
self.rcode = self.kill_rcode
|
||||
|
||||
def poll(self):
|
||||
return self.rcode
|
@ -1,8 +1,8 @@
|
||||
# The order of packages is significant, because pip processes them in the order
|
||||
# of appearance. Changing the order has an impact on the overall integration
|
||||
# process, which may cause wedges in the gate later.
|
||||
# The driver uses Ironic code base and it's requirements, no additional
|
||||
# requirements needed
|
||||
# Since Ironic is not published to pip, Ironic must be installed on the system
|
||||
# before test run
|
||||
#ironic>=4.3.0
|
||||
oslo.service>=1.10.0 # Apache-2.0
|
||||
ironic-lib>=2.2.0 # Apache-2.0
|
||||
|
||||
# Since Ironic is not published to pip and deny to use links in dependencies
|
||||
# ironic dependency is defined into tox.ini
|
||||
|
@ -5,5 +5,7 @@
|
||||
hacking<0.11,>=0.10.2 # Apache-2.0
|
||||
testrepository>=0.0.18 # Apache-2.0/BSD
|
||||
|
||||
mock>=2.0 # BSD
|
||||
|
||||
# this is required for the docs build jobs
|
||||
sphinx!=1.3b1,<1.4,>=1.2.1 # BSD
|
||||
|
1
tox.ini
1
tox.ini
@ -11,6 +11,7 @@ setenv = VIRTUAL_ENV={envdir}
|
||||
LANGUAGE=en_US
|
||||
deps = -r{toxinidir}/requirements.txt
|
||||
-r{toxinidir}/test-requirements.txt
|
||||
git+https://github.com/openstack/ironic.git@5.1.2#egg=ironic-5.1.2
|
||||
whitelist_externals = bash
|
||||
commands =
|
||||
bash -c "TESTS_DIR=./bareon_ironic/tests/ python setup.py testr --slowest --testr-args='{posargs}'"
|
||||
|
Loading…
Reference in New Issue
Block a user