From 6d5d3994ea25a2704c017e241058383c9aa7cb79 Mon Sep 17 00:00:00 2001 From: Claudiu Popa Date: Thu, 20 Nov 2014 01:58:44 +0200 Subject: [PATCH] Rewrite exec utilies to be generic This patch adds executil.BaseCommand, which contains some scaffolding for building generic and simple CLI commands. Also, executil.py contains a couple of predefined commands, such as Python, Shell, Powershell etc. fileexecutils, as well as userdatautils, were rewritten to use the new command framework instead, which greatly increases usability and simplifies cross platform issues. Change-Id: Ifcd91da95a272ef68d56491f90bc213826eb7679 --- cloudbaseinit/plugins/common/__init__.py | 0 cloudbaseinit/plugins/common/executil.py | 145 +++++++++++++++++ .../plugins/windows/fileexecutils.py | 58 +++---- .../plugins/windows/userdatautils.py | 85 ++++------ .../tests/plugins/common/__init__.py | 0 .../tests/plugins/common/test_executil.py | 110 +++++++++++++ .../plugins/windows/test_fileexecutils.py | 83 +++++----- .../plugins/windows/test_userdatautils.py | 151 ++++++------------ cloudbaseinit/tests/testutils.py | 59 +++++++ 9 files changed, 460 insertions(+), 231 deletions(-) create mode 100644 cloudbaseinit/plugins/common/__init__.py create mode 100644 cloudbaseinit/plugins/common/executil.py create mode 100644 cloudbaseinit/tests/plugins/common/__init__.py create mode 100644 cloudbaseinit/tests/plugins/common/test_executil.py create mode 100644 cloudbaseinit/tests/testutils.py diff --git a/cloudbaseinit/plugins/common/__init__.py b/cloudbaseinit/plugins/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cloudbaseinit/plugins/common/executil.py b/cloudbaseinit/plugins/common/executil.py new file mode 100644 index 00000000..c6660f49 --- /dev/null +++ b/cloudbaseinit/plugins/common/executil.py @@ -0,0 +1,145 @@ +# Copyright 2014 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import os +import tempfile +import uuid + +from cloudbaseinit.osutils import factory as osutils_factory + + +__all__ = ( + 'BaseCommand', + 'Shell', + 'Python', + 'Bash', + 'Powershell', + 'PowershellSysnative', +) + + +class BaseCommand(object): + """Implements logic for executing an user command. + + This is intended to be subclassed and each subclass should change the + attributes which controls the behaviour of the execution. + It must be instantiated with a file. + It can also execute string commands, by using the alternate + constructor :meth:`~from_data`. + + The following attributes can control the behaviour of the command: + + * shell: Run the command as a shell command. + * extension: + + A string, which will be appended to a generated script file. + This is important for certain commands, e.g. Powershell, + which can't execute something without the `.ps1` extension. + + * command: + + A program which will execute the underlying command, + e.g. `python`, `bash` etc. + + """ + shell = False + extension = None + command = None + + def __init__(self, target_path, cleanup=None): + """Instantiate the command. + + The parameter *target_path* represents the file which will be + executed. The optional parameter *cleanup* can be a callable, + which will be called after executing a command, no matter if the + execution was succesful or not. + """ + + self._target_path = target_path + self._cleanup = cleanup + self._osutils = osutils_factory.get_os_utils() + + @property + def args(self): + """Return a list of commands. + + The list will be passed to :meth:`~execute_process`. + """ + if not self.command: + # Then we can assume it's a shell command. + return [self._target_path] + else: + return [self.command, self._target_path] + + def get_execute_method(self): + """Return a callable, which will be called by :meth:`~execute`.""" + return functools.partial(self._osutils.execute_process, + self.args, shell=self.shell) + + def execute(self): + """Execute the underlying command.""" + try: + return self.get_execute_method()() + finally: + if self._cleanup: + self._cleanup() + + __call__ = execute + + @classmethod + def from_data(cls, command): + """Create a new command class from the given command data.""" + def safe_remove(target_path): + try: + os.remove(target_path) + except OSError: # pragma: no cover + pass + + tmp = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) + if cls.extension: + tmp += cls.extension + with open(tmp, 'wb') as stream: + stream.write(command) + return cls(tmp, cleanup=functools.partial(safe_remove, tmp)) + + +class Shell(BaseCommand): + shell = True + extension = '.cmd' + + +class Python(BaseCommand): + extension = '.py' + command = 'python' + + +class Bash(BaseCommand): + extension = '.sh' + command = 'bash' + + +class PowershellSysnative(BaseCommand): + extension = '.ps1' + sysnative = True + + def get_execute_method(self): + return functools.partial( + self._osutils.execute_powershell_script, + self._target_path, + self.sysnative) + + +class Powershell(PowershellSysnative): + sysnative = False diff --git a/cloudbaseinit/plugins/windows/fileexecutils.py b/cloudbaseinit/plugins/windows/fileexecutils.py index dba60777..270b950a 100644 --- a/cloudbaseinit/plugins/windows/fileexecutils.py +++ b/cloudbaseinit/plugins/windows/fileexecutils.py @@ -15,47 +15,37 @@ import os from cloudbaseinit.openstack.common import log as logging -from cloudbaseinit.osutils import factory as osutils_factory +from cloudbaseinit.plugins.common import executil LOG = logging.getLogger(__name__) +FORMATS = { + "cmd": executil.Shell, + "exe": executil.Shell, + "sh": executil.Bash, + "py": executil.Python, + "ps1": executil.PowershellSysnative, +} + def exec_file(file_path): - shell = False - powershell = False - + ret_val = 0 + out = err = None ext = os.path.splitext(file_path)[1][1:].lower() - - if ext == "cmd": - args = [file_path] - shell = True - elif ext == "exe": - args = [file_path] - elif ext == "sh": - args = ["bash.exe", file_path] - elif ext == "py": - args = ["python.exe", file_path] - elif ext == "ps1": - powershell = True - else: + command = FORMATS.get(ext) + if not command: # Unsupported - LOG.warning('Unsupported script file type: %s' % ext) - return 0 - - osutils = osutils_factory.get_os_utils() + LOG.warning('Unsupported script file type: %s', ext) + return ret_val try: - if powershell: - (out, err, - ret_val) = osutils.execute_powershell_script(file_path) - else: - (out, err, ret_val) = osutils.execute_process(args, shell) - - LOG.info('Script "%(file_path)s" ended with exit code: %(ret_val)d' % - {"file_path": file_path, "ret_val": ret_val}) - LOG.debug('User_data stdout:\n%s' % out) - LOG.debug('User_data stderr:\n%s' % err) - - return ret_val + out, err, ret_val = command(file_path).execute() except Exception as ex: - LOG.warning('An error occurred during file execution: \'%s\'' % ex) + LOG.warning('An error occurred during file execution: \'%s\'', ex) + else: + LOG.debug('User_data stdout:\n%s', out) + LOG.debug('User_data stderr:\n%s', err) + + LOG.info('Script "%(file_path)s" ended with exit code: %(ret_val)d', + {"file_path": file_path, "ret_val": ret_val}) + return ret_val diff --git a/cloudbaseinit/plugins/windows/userdatautils.py b/cloudbaseinit/plugins/windows/userdatautils.py index 2037d639..a646ebe1 100644 --- a/cloudbaseinit/plugins/windows/userdatautils.py +++ b/cloudbaseinit/plugins/windows/userdatautils.py @@ -12,66 +12,51 @@ # License for the specific language governing permissions and limitations # under the License. -import os +import functools import re -import tempfile -import uuid from cloudbaseinit.openstack.common import log as logging -from cloudbaseinit.osutils import factory as osutils_factory -from cloudbaseinit.utils import encoding +from cloudbaseinit.plugins.common import executil LOG = logging.getLogger(__name__) +# Avoid 80+ length by using a local variable, which +# is deleted afterwards. +_compile = functools.partial(re.compile, flags=re.I) +FORMATS = ( + (_compile(br'^rem cmd\s'), executil.Shell), + (_compile(br'^#!/usr/bin/env\spython\s'), executil.Python), + (_compile(br'^#!'), executil.Bash), + (_compile(br'^#(ps1|ps1_sysnative)\s'), executil.PowershellSysnative), + (_compile(br'^#ps1_x86\s'), executil.Powershell), +) +del _compile + + +def _get_command(data): + # Get the command which should process the given data. + for pattern, command_class in FORMATS: + if pattern.search(data): + return command_class.from_data(data) + def execute_user_data_script(user_data): - osutils = osutils_factory.get_os_utils() - - shell = False - powershell = False - sysnative = True - - target_path = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) - if re.search(r'^rem cmd\s', user_data, re.I): - target_path += '.cmd' - args = [target_path] - shell = True - elif re.search(r'^#!/usr/bin/env\spython\s', user_data, re.I): - target_path += '.py' - args = ['python.exe', target_path] - elif re.search(r'^#!', user_data, re.I): - target_path += '.sh' - args = ['bash.exe', target_path] - elif re.search(r'^#(ps1|ps1_sysnative)\s', user_data, re.I): - target_path += '.ps1' - powershell = True - elif re.search(r'^#ps1_x86\s', user_data, re.I): - target_path += '.ps1' - powershell = True - sysnative = False - else: + ret_val = 0 + out = err = None + command = _get_command(user_data) + if not command: # Unsupported LOG.warning('Unsupported user_data format') - return 0 + return ret_val try: - encoding.write_file(target_path, user_data) - - if powershell: - (out, err, - ret_val) = osutils.execute_powershell_script(target_path, - sysnative) - else: - (out, err, ret_val) = osutils.execute_process(args, shell) - - LOG.info('User_data script ended with return code: %d' % ret_val) - LOG.debug('User_data stdout:\n%s' % out) - LOG.debug('User_data stderr:\n%s' % err) - - return ret_val + out, err, ret_val = command() except Exception as ex: - LOG.warning('An error occurred during user_data execution: \'%s\'' - % ex) - finally: - if os.path.exists(target_path): - os.remove(target_path) + LOG.warning('An error occurred during user_data execution: \'%s\'', + ex) + else: + LOG.debug('User_data stdout:\n%s', out) + LOG.debug('User_data stderr:\n%s', err) + + LOG.info('User_data script ended with return code: %d', ret_val) + return ret_val diff --git a/cloudbaseinit/tests/plugins/common/__init__.py b/cloudbaseinit/tests/plugins/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cloudbaseinit/tests/plugins/common/test_executil.py b/cloudbaseinit/tests/plugins/common/test_executil.py new file mode 100644 index 00000000..3d2df40b --- /dev/null +++ b/cloudbaseinit/tests/plugins/common/test_executil.py @@ -0,0 +1,110 @@ +# Copyright 2014 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import unittest + +import mock + +from cloudbaseinit.plugins.common import executil +from cloudbaseinit.tests import testutils + + +def _remove_file(filepath): + try: + os.remove(filepath) + except OSError: + pass + + +@mock.patch('cloudbaseinit.osutils.factory.get_os_utils') +class ExecUtilTest(unittest.TestCase): + + def test_from_data(self, _): + command = executil.BaseCommand.from_data(b"test") + + self.assertIsInstance(command, executil.BaseCommand) + + # Not public API, though. + self.assertTrue(os.path.exists(command._target_path), + command._target_path) + self.addCleanup(_remove_file, command._target_path) + + with open(command._target_path) as stream: + data = stream.read() + + self.assertEqual("test", data) + command._cleanup() + self.assertFalse(os.path.exists(command._target_path), + command._target_path) + + def test_args(self, _): + class FakeCommand(executil.BaseCommand): + command = mock.sentinel.command + + with testutils.create_tempfile() as tmp: + fake_command = FakeCommand(tmp) + self.assertEqual([mock.sentinel.command, tmp], + fake_command.args) + + fake_command = executil.BaseCommand(tmp) + self.assertEqual([tmp], fake_command.args) + + def test_from_data_extension(self, _): + class FakeCommand(executil.BaseCommand): + command = mock.sentinel.command + extension = ".test" + + command = FakeCommand.from_data(b"test") + self.assertIsInstance(command, FakeCommand) + + self.addCleanup(os.remove, command._target_path) + self.assertTrue(command._target_path.endswith(".test")) + + def test_execute_normal_command(self, mock_get_os_utils): + mock_osutils = mock_get_os_utils() + + with testutils.create_tempfile() as tmp: + command = executil.BaseCommand(tmp) + command.execute() + + mock_osutils.execute_process.assert_called_once_with( + [command._target_path], + shell=command.shell) + + # test __call__ API. + mock_osutils.execute_process.reset_mock() + command() + + mock_osutils.execute_process.assert_called_once_with( + [command._target_path], + shell=command.shell) + + def test_execute_powershell_command(self, mock_get_os_utils): + mock_osutils = mock_get_os_utils() + + with testutils.create_tempfile() as tmp: + command = executil.Powershell(tmp) + command.execute() + + mock_osutils.execute_powershell_script.assert_called_once_with( + command._target_path, command.sysnative) + + def test_execute_cleanup(self, _): + with testutils.create_tempfile() as tmp: + cleanup = mock.Mock() + command = executil.BaseCommand(tmp, cleanup=cleanup) + command.execute() + + cleanup.assert_called_once_with() diff --git a/cloudbaseinit/tests/plugins/windows/test_fileexecutils.py b/cloudbaseinit/tests/plugins/windows/test_fileexecutils.py index 9a670d65..1da69743 100644 --- a/cloudbaseinit/tests/plugins/windows/test_fileexecutils.py +++ b/cloudbaseinit/tests/plugins/windows/test_fileexecutils.py @@ -12,60 +12,49 @@ # License for the specific language governing permissions and limitations # under the License. -import mock import unittest +import mock + +from cloudbaseinit.plugins.common import executil from cloudbaseinit.plugins.windows import fileexecutils +@mock.patch('cloudbaseinit.osutils.factory.get_os_utils') class TestFileExecutilsPlugin(unittest.TestCase): - @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') - def _test_exec_file(self, mock_get_os_utils, filename, exception=False): - mock_osutils = mock.MagicMock() - mock_part = mock.MagicMock() - mock_part.get_filename.return_value = filename - mock_get_os_utils.return_value = mock_osutils - if exception: - mock_osutils.execute_process.side_effect = [Exception] - with mock.patch("cloudbaseinit.plugins.windows.userdataplugins." - "shellscript.open", mock.mock_open(), create=True): - response = fileexecutils.exec_file(filename) - if filename.endswith(".cmd"): - mock_osutils.execute_process.assert_called_once_with( - [filename], True) - elif filename.endswith(".sh"): - mock_osutils.execute_process.assert_called_once_with( - ['bash.exe', filename], False) - elif filename.endswith(".py"): - mock_osutils.execute_process.assert_called_once_with( - ['python.exe', filename], False) - elif filename.endswith(".exe"): - mock_osutils.execute_process.assert_called_once_with( - [filename], False) - elif filename.endswith(".ps1"): - mock_osutils.execute_powershell_script.assert_called_once_with( - filename) - else: - self.assertEqual(0, response) + def test_exec_file_no_executor(self, _): + retval = fileexecutils.exec_file("fake.fake") + self.assertEqual(0, retval) - def test_process_cmd(self): - self._test_exec_file(filename='fake.cmd') + def test_executors_mapping(self, _): + self.assertEqual(fileexecutils.FORMATS["cmd"], + executil.Shell) + self.assertEqual(fileexecutils.FORMATS["exe"], + executil.Shell) + self.assertEqual(fileexecutils.FORMATS["sh"], + executil.Bash) + self.assertEqual(fileexecutils.FORMATS["py"], + executil.Python) + self.assertEqual(fileexecutils.FORMATS["ps1"], + executil.PowershellSysnative) - def test_process_sh(self): - self._test_exec_file(filename='fake.sh') + @mock.patch('cloudbaseinit.plugins.common.executil.' + 'BaseCommand.execute') + def test_exec_file_fails(self, mock_execute, _): + mock_execute.side_effect = ValueError + retval = fileexecutils.exec_file("fake.py") + mock_execute.assert_called_once_with() + self.assertEqual(0, retval) - def test_process_py(self): - self._test_exec_file(filename='fake.py') - - def test_process_ps1(self): - self._test_exec_file(filename='fake.ps1') - - def test_process_other(self): - self._test_exec_file(filename='fake.other') - - def test_process_exe(self): - self._test_exec_file(filename='fake.exe') - - def test_process_exception(self): - self._test_exec_file(filename='fake.exe', exception=True) + @mock.patch('cloudbaseinit.plugins.common.executil.' + 'BaseCommand.execute') + def test_exec_file_(self, mock_execute, _): + mock_execute.return_value = ( + mock.sentinel.out, + mock.sentinel.error, + 0, + ) + retval = fileexecutils.exec_file("fake.py") + mock_execute.assert_called_once_with() + self.assertEqual(0, retval) diff --git a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py index ffc2617c..6ac7fa1e 100644 --- a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py +++ b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py @@ -12,123 +12,74 @@ # License for the specific language governing permissions and limitations # under the License. -import mock import os import unittest -from oslo.config import cfg +import mock +from cloudbaseinit.plugins.common import executil from cloudbaseinit.plugins.windows import userdatautils -from cloudbaseinit.tests.metadata import fake_json_response - -CONF = cfg.CONF +def _safe_remove(filepath): + try: + os.remove(filepath) + except OSError: + pass + + +@mock.patch('cloudbaseinit.osutils.factory.get_os_utils') class UserDataUtilsTest(unittest.TestCase): - def setUp(self): - self.fake_data = fake_json_response.get_fake_metadata_json( - '2013-04-04') + def _get_command(self, data): + """Get a command from the given data. - @mock.patch('re.search') - @mock.patch('tempfile.gettempdir') - @mock.patch('os.remove') - @mock.patch('os.path.exists') - @mock.patch('os.path.expandvars') - @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') - @mock.patch('uuid.uuid4') - @mock.patch('cloudbaseinit.utils.encoding.write_file') - def _test_execute_user_data_script(self, mock_write_file, mock_uuid4, - mock_get_os_utils, mock_path_expandvars, - mock_path_exists, mock_os_remove, - mock_gettempdir, mock_re_search, - fake_user_data): - mock_osutils = mock.MagicMock() - mock_gettempdir.return_value = 'fake_temp' - mock_uuid4.return_value = 'randomID' - match_instance = mock.MagicMock() - path = os.path.join('fake_temp', 'randomID') - args = None - powershell = False - mock_get_os_utils.return_value = mock_osutils - mock_path_exists.return_value = True - extension = '' + If a command was obtained, then a cleanup will be added in order + to remove the underlying target path of the command. + """ + command = userdatautils._get_command(data) + if command: + self.addCleanup(_safe_remove, command._target_path) + return command - if fake_user_data == '^rem cmd\s': - side_effect = [match_instance] - number_of_calls = 1 - extension = '.cmd' - args = [path + extension] - shell = True - elif fake_user_data == '^#!/usr/bin/env\spython\s': - side_effect = [None, match_instance] - number_of_calls = 2 - extension = '.py' - args = ['python.exe', path + extension] - shell = False - elif fake_user_data == '#!': - side_effect = [None, None, match_instance] - number_of_calls = 3 - extension = '.sh' - args = ['bash.exe', path + extension] - shell = False - elif fake_user_data == '#ps1_sysnative\s': - side_effect = [None, None, None, match_instance] - number_of_calls = 4 - extension = '.ps1' - sysnative = True - powershell = True - elif fake_user_data == '#ps1_x86\s': - side_effect = [None, None, None, None, match_instance] - number_of_calls = 5 - extension = '.ps1' - shell = False - sysnative = False - powershell = True - else: - side_effect = [None, None, None, None, None] - number_of_calls = 5 + def test__get_command(self, _): + command = self._get_command(b'rem cmd test') + self.assertIsInstance(command, executil.Shell) - mock_re_search.side_effect = side_effect + command = self._get_command(b'#!/usr/bin/env python\ntest') + self.assertIsInstance(command, executil.Python) - response = userdatautils.execute_user_data_script(fake_user_data) + command = self._get_command(b'#!/bin/bash') + self.assertIsInstance(command, executil.Bash) - mock_gettempdir.assert_called_once_with() + command = self._get_command(b'#ps1_sysnative\n') + self.assertIsInstance(command, executil.PowershellSysnative) - self.assertEqual(number_of_calls, mock_re_search.call_count) - if args: - mock_write_file.assert_called_once_with(path + extension, - fake_user_data) - mock_osutils.execute_process.assert_called_with(args, shell) - mock_os_remove.assert_called_once_with(path + extension) - self.assertEqual(None, response) - elif powershell: - mock_osutils.execute_powershell_script.assert_called_with( - path + extension, sysnative) - mock_os_remove.assert_called_once_with(path + extension) - self.assertEqual(None, response) - else: - self.assertEqual(0, response) + command = self._get_command(b'#ps1_x86\n') + self.assertIsInstance(command, executil.Powershell) - def test_handle_batch(self): - fake_user_data = b'^rem cmd\s' - self._test_execute_user_data_script(fake_user_data=fake_user_data) + command = self._get_command(b'unknown') + self.assertIsNone(command) - def test_handle_python(self): - self._test_execute_user_data_script( - fake_user_data=b'^#!/usr/bin/env\spython\s') + def test_execute_user_data_script_no_commands(self, _): + retval = userdatautils.execute_user_data_script(b"unknown") + self.assertEqual(0, retval) - def test_handle_shell(self): - self._test_execute_user_data_script(fake_user_data=b'^#!') + @mock.patch('cloudbaseinit.plugins.windows.userdatautils.' + '_get_command') + def test_execute_user_data_script_fails(self, mock_get_command, _): + mock_get_command.return_value.side_effect = ValueError + retval = userdatautils.execute_user_data_script( + mock.sentinel.user_data) - def test_handle_powershell(self): - self._test_execute_user_data_script(fake_user_data=b'^#ps1\s') + self.assertEqual(0, retval) - def test_handle_powershell_sysnative(self): - self._test_execute_user_data_script(fake_user_data=b'#ps1_sysnative\s') - - def test_handle_powershell_sysnative_no_sysnative(self): - self._test_execute_user_data_script(fake_user_data=b'#ps1_x86\s') - - def test_handle_unsupported_format(self): - self._test_execute_user_data_script(fake_user_data=b'unsupported') + @mock.patch('cloudbaseinit.plugins.windows.userdatautils.' + '_get_command') + def test_execute_user_data_script(self, mock_get_command, _): + mock_get_command.return_value.return_value = ( + mock.sentinel.output, mock.sentinel.error, -1 + ) + retval = userdatautils.execute_user_data_script( + mock.sentinel.user_data) + self.assertEqual(-1, retval) diff --git a/cloudbaseinit/tests/testutils.py b/cloudbaseinit/tests/testutils.py new file mode 100644 index 00000000..dfd0cac5 --- /dev/null +++ b/cloudbaseinit/tests/testutils.py @@ -0,0 +1,59 @@ +# Copyright 2014 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib +import os +import shutil +import tempfile + + +__all__ = ( + 'create_tempfile', + 'create_tempdir', +) + + +@contextlib.contextmanager +def create_tempdir(): + """Create a temporary directory. + + This is a context manager, which creates a new temporary + directory and removes it when exiting from the context manager + block. + """ + tempdir = tempfile.mkdtemp(prefix="cloudbaseinit-tests") + try: + yield tempdir + finally: + shutil.rmtree(tempdir) + + +@contextlib.contextmanager +def create_tempfile(content=None): + """Create a temporary file. + + This is a context manager, which uses `create_tempdir` to obtain a + temporary directory, where the file will be placed. + + :param content: + Additionally, a string which will be written + in the new file. + """ + with create_tempdir() as temp: + fd, path = tempfile.mkstemp(dir=temp) + os.close(fd) + if content: + with open(path, 'w') as stream: + stream.write(content) + yield path