diff --git a/cloudbaseinit/conf/default.py b/cloudbaseinit/conf/default.py index a13ade9f..c01ddee0 100644 --- a/cloudbaseinit/conf/default.py +++ b/cloudbaseinit/conf/default.py @@ -215,6 +215,15 @@ class GlobalOptions(conf_base.Options): 'display_idle_timeout', default=0, help='The idle timeout, in seconds, before powering off ' 'the display. Set 0 to leave the display always on'), + cfg.ListOpt( + 'page_file_volume_labels', default=[], + help='Labels of volumes on which a Windows page file needs to ' + 'be created. E.g.: "Temporary Storage"'), + cfg.ListOpt( + 'page_file_volume_mount_points', default=[], + help='Volume mount points on which a Windows page file needs ' + 'to be created. E.g.: ' + '"\\\\?\\GLOBALROOT\\device\\Harddisk1\\Partition1\\"'), ] self._cli_options = [ diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index fc88e388..4a21daf7 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -242,6 +242,18 @@ kernel32.HeapFree.argtypes = [wintypes.HANDLE, wintypes.DWORD, wintypes.LPVOID] kernel32.HeapFree.restype = wintypes.BOOL +kernel32.GetVolumeNameForVolumeMountPointW.argtypes = [wintypes.LPCWSTR, + wintypes.LPWSTR, + wintypes.DWORD] +kernel32.GetVolumeNameForVolumeMountPointW.restype = wintypes.BOOL + +kernel32.GetVolumePathNamesForVolumeNameW.argtypes = [wintypes.LPCWSTR, + wintypes.LPWSTR, + wintypes.DWORD, + ctypes.POINTER( + wintypes.DWORD)] +kernel32.GetVolumePathNamesForVolumeNameW.restype = wintypes.BOOL + iphlpapi.GetIpForwardTable.argtypes = [ ctypes.POINTER(Win32_MIB_IPFORWARDTABLE), ctypes.POINTER(wintypes.ULONG), @@ -294,9 +306,12 @@ GUID_DEVINTERFACE_DISK = disk.GUID(0x53f56307, 0xb6bf, 0x11d0, 0x94, 0xf2, class WindowsUtils(base.BaseOSUtils): NERR_GroupNotFound = 2220 NERR_UserNotFound = 2221 + ERROR_PATH_NOT_FOUND = 3 ERROR_ACCESS_DENIED = 5 ERROR_INSUFFICIENT_BUFFER = 122 + ERROR_INVALID_NAME = 123 ERROR_NO_DATA = 232 + ERROR_MORE_DATA = 234 ERROR_NO_SUCH_MEMBER = 1387 ERROR_MEMBER_IN_ALIAS = 1378 ERROR_INVALID_MEMBER = 1388 @@ -1052,6 +1067,38 @@ class WindowsUtils(base.BaseOSUtils): if ret_val: return label.value + def get_volume_path_names_by_mount_point(self, mount_point): + max_volume_name_len = 50 + volume_name = ctypes.create_unicode_buffer(max_volume_name_len) + + if not kernel32.GetVolumeNameForVolumeMountPointW( + six.text_type(mount_point), volume_name, + max_volume_name_len): + if kernel32.GetLastError() in [self.ERROR_INVALID_NAME, + self.ERROR_PATH_NOT_FOUND]: + raise exception.ItemNotFoundException( + "Mount point not found: %s" % mount_point) + else: + raise exception.WindowsCloudbaseInitException( + "Failed to get volume name for mount point: %s. " + "Error: %%r" % mount_point) + + volume_path_names_len = wintypes.DWORD(100) + while True: + volume_path_names = ctypes.create_unicode_buffer( + volume_path_names_len.value) + if not kernel32.GetVolumePathNamesForVolumeNameW( + volume_name, volume_path_names, volume_path_names_len, + ctypes.byref(volume_path_names_len)): + if kernel32.GetLastError() == self.ERROR_MORE_DATA: + continue + else: + raise exception.WindowsCloudbaseInitException( + "Failed to get path names for volume name: %s." + "Error: %%r" % volume_name.value) + return [n for n in volume_path_names[ + :volume_path_names_len.value - 1].split('\0') if n] + def generate_random_password(self, length): while True: pwd = super(WindowsUtils, self).generate_random_password(length) @@ -1079,7 +1126,7 @@ class WindowsUtils(base.BaseOSUtils): return values - def _get_logical_drives(self): + def get_logical_drives(self): buf_size = self.MAX_PATH buf = ctypes.create_unicode_buffer(buf_size + 1) buf_len = kernel32.GetLogicalDriveStringsW(buf_size, buf) @@ -1090,7 +1137,7 @@ class WindowsUtils(base.BaseOSUtils): return self._split_str_buf_list(buf, buf_len) def get_cdrom_drives(self): - drives = self._get_logical_drives() + drives = self.get_logical_drives() return [d for d in drives if kernel32.GetDriveTypeW(d) == self.DRIVE_CDROM] @@ -1342,3 +1389,30 @@ class WindowsUtils(base.BaseOSUtils): 0, winreg.KEY_ALL_ACCESS) as key: winreg.SetValueEx(key, 'RealTimeIsUniversal', 0, winreg.REG_DWORD, 1 if utc else 0) + + def get_page_files(self): + with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, + 'SYSTEM\\CurrentControlSet\\Control\\' + 'Session Manager\\Memory Management') as key: + values = winreg.QueryValueEx(key, 'PagingFiles')[0] + + page_files = [] + for value in values: + v = value.split(" ") + path = v[0] + min_size_mb = int(v[1]) if len(v) > 1 else 0 + max_size_mb = int(v[2]) if len(v) > 2 else 0 + page_files.append((path, min_size_mb, max_size_mb)) + return page_files + + def set_page_files(self, page_files): + values = [] + for path, min_size_mb, max_size_mb in page_files: + values.append("%s %d %d" % (path, min_size_mb, max_size_mb)) + + with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, + 'SYSTEM\\CurrentControlSet\\Control\\' + 'Session Manager\\Memory Management', + 0, winreg.KEY_ALL_ACCESS) as key: + winreg.SetValueEx(key, 'PagingFiles', 0, + winreg.REG_MULTI_SZ, values) diff --git a/cloudbaseinit/plugins/windows/pagefiles.py b/cloudbaseinit/plugins/windows/pagefiles.py new file mode 100644 index 00000000..292bde97 --- /dev/null +++ b/cloudbaseinit/plugins/windows/pagefiles.py @@ -0,0 +1,79 @@ +# Copyright (c) 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os + +from oslo_log import log as oslo_logging + +from cloudbaseinit import conf as cloudbaseinit_conf +from cloudbaseinit import exception +from cloudbaseinit.osutils import factory as osutils_factory +from cloudbaseinit.plugins.common import base + +CONF = cloudbaseinit_conf.CONF +LOG = oslo_logging.getLogger(__name__) + + +class PageFilesPlugin(base.BasePlugin): + + def _get_page_file_volumes_by_mount_point(self, osutils): + page_file_volume_paths = [] + for mount_point in CONF.page_file_volume_mount_points: + try: + paths = osutils.get_volume_path_names_by_mount_point( + mount_point) + if paths: + page_file_volume_paths.append(paths[0]) + except exception.ItemNotFoundException: + LOG.info("Mount point not found: %s", mount_point) + return page_file_volume_paths + + def _get_page_file_volumes_by_label(self, osutils): + page_file_logical_drives = [] + logical_drives = osutils.get_logical_drives() + for logical_drive in logical_drives: + label = osutils.get_volume_label(logical_drive) + if not label: + continue + if label.upper() in [ + v.upper() for v in CONF.page_file_volume_labels]: + page_file_logical_drives.append(logical_drive) + return page_file_logical_drives + + def _get_page_file_volumes(self, osutils): + return list(set(self._get_page_file_volumes_by_mount_point(osutils)) | + set(self._get_page_file_volumes_by_label(osutils))) + + def execute(self, service, shared_data): + osutils = osutils_factory.get_os_utils() + page_file_volumes = sorted(self._get_page_file_volumes(osutils)) + reboot_required = False + + if not page_file_volumes: + LOG.info("No page file volume found, skipping configuration") + else: + page_files = [ + (os.path.join(v, "pagefile.sys"), 0, 0) + for v in page_file_volumes] + + current_page_files = osutils.get_page_files() + if sorted(current_page_files) != sorted(page_files): + osutils.set_page_files(page_files) + LOG.info("Page file configuration set: %s", page_files) + reboot_required = True + + return base.PLUGIN_EXECUTE_ON_NEXT_BOOT, reboot_required + + def get_os_requirements(self): + return 'win32', (5, 2) diff --git a/cloudbaseinit/tests/osutils/test_windows.py b/cloudbaseinit/tests/osutils/test_windows.py index 2d8fad40..2278f7e9 100644 --- a/cloudbaseinit/tests/osutils/test_windows.py +++ b/cloudbaseinit/tests/osutils/test_windows.py @@ -1284,6 +1284,46 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): def test_get_volume_label_no_return_value(self): self._test_get_volume_label(None) + def _test_get_volume_path_names_by_mount_point(self, err=False, + last_err=None, + volume_name=None): + mock_mount_point = mock.Mock() + self._windll_mock.kernel32.GetLastError.return_value = last_err + if err: + (self._windll_mock.kernel32. + GetVolumeNameForVolumeMountPointW.return_value) = None + if last_err in [self._winutils.ERROR_INVALID_NAME, + self._winutils.ERROR_PATH_NOT_FOUND]: + self.assertRaises( + exception.ItemNotFoundException, + self._winutils.get_volume_path_names_by_mount_point, + mock_mount_point) + else: + self.assertRaises( + exception.WindowsCloudbaseInitException, + self._winutils.get_volume_path_names_by_mount_point, + mock_mount_point) + return + (self._windll_mock.kernel32. + GetVolumePathNamesForVolumeNameW.return_value) = volume_name + if not volume_name: + if last_err != self._winutils.ERROR_MORE_DATA: + self.assertRaises( + exception.WindowsCloudbaseInitException, + self._winutils.get_volume_path_names_by_mount_point, + mock_mount_point) + + def test_get_volume_path_names_by_mount_point_not_found(self): + self._test_get_volume_path_names_by_mount_point(err=True) + + def test_get_volume_path_names_by_mount_point_failed(self): + self._test_get_volume_path_names_by_mount_point( + err=True, last_err=self._winutils.ERROR_INVALID_NAME) + + def test_get_volume_path_names_by_mount_point_error(self): + self._test_get_volume_path_names_by_mount_point( + volume_name=True, last_err=self._winutils.ERROR_MORE_DATA) + @mock.patch('re.search') @mock.patch('cloudbaseinit.osutils.base.BaseOSUtils.' 'generate_random_password') @@ -1309,9 +1349,9 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): if buf_length is None: with self.assert_raises_windows_message( "GetLogicalDriveStringsW failed: %r", last_error): - self._winutils._get_logical_drives() + self._winutils.get_logical_drives() else: - response = self._winutils._get_logical_drives() + response = self._winutils.get_logical_drives() self._ctypes_mock.create_unicode_buffer.assert_called_with(261) mock_get_drives.assert_called_with(260, mock_buf) @@ -1324,7 +1364,7 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): self._test_get_logical_drives(buf_length=2) @mock.patch('cloudbaseinit.osutils.windows.WindowsUtils.' - '_get_logical_drives') + 'get_logical_drives') @mock.patch('cloudbaseinit.osutils.windows.kernel32') def test_get_cdrom_drives(self, mock_kernel32, mock_get_logical_drives): mock_get_logical_drives.return_value = ['drive'] @@ -2268,3 +2308,31 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): def test_set_real_time_clock_utc(self): self._test_set_real_time_clock_utc(utc=1) + + def test_get_page_files(self): + mock_value = [u'?:\\pagefile.sys'] + expected_page_files = [(mock_value[0], 0, 0)] + self._winreg_mock.QueryValueEx.return_value = [mock_value] + res = self._winutils.get_page_files() + key = self._winreg_mock.OpenKey.return_value.__enter__.return_value + self._winreg_mock.OpenKey.assert_called_with( + self._winreg_mock.HKEY_LOCAL_MACHINE, + 'SYSTEM\\CurrentControlSet\\Control\\' + 'Session Manager\\Memory Management') + self._winreg_mock.QueryValueEx.assert_called_with(key, 'PagingFiles') + self.assertEqual(res, expected_page_files) + + def test_set_page_files(self): + mock_path = mock.Mock() + page_files = [(mock_path, 0, 0)] + self._winutils.set_page_files(page_files) + expected_values = ["%s %d %d" % (mock_path, 0, 0)] + key = self._winreg_mock.OpenKey.return_value.__enter__.return_value + self._winreg_mock.OpenKey.assert_called_with( + self._winreg_mock.HKEY_LOCAL_MACHINE, + 'SYSTEM\\CurrentControlSet\\Control\\' + 'Session Manager\\Memory Management', + 0, self._winreg_mock.KEY_ALL_ACCESS) + self._winreg_mock.SetValueEx.assert_called_with( + key, 'PagingFiles', 0, self._winreg_mock.REG_MULTI_SZ, + expected_values) diff --git a/cloudbaseinit/tests/plugins/windows/test_pagefiles.py b/cloudbaseinit/tests/plugins/windows/test_pagefiles.py new file mode 100644 index 00000000..e37a6454 --- /dev/null +++ b/cloudbaseinit/tests/plugins/windows/test_pagefiles.py @@ -0,0 +1,166 @@ +# Copyright 2017 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import unittest + +try: + import unittest.mock as mock +except ImportError: + import mock + +from cloudbaseinit import conf as cloudbaseinit_conf +from cloudbaseinit import exception +from cloudbaseinit.plugins.common import base +from cloudbaseinit.plugins.windows import pagefiles +from cloudbaseinit.tests import testutils + +CONF = cloudbaseinit_conf.CONF +MODPATH = "cloudbaseinit.plugins.windows.pagefiles.PageFilesPlugin" + + +class PageFilesPluginTest(unittest.TestCase): + + def setUp(self): + self._pagefiles = pagefiles.PageFilesPlugin() + self._logsnatcher = testutils.LogSnatcher("cloudbaseinit.plugins" + ".windows.pagefiles") + + def _test_get_page_file_volumes_by_mount_point( + self, mount_points=None, paths=None, + get_endpoint_side_effect=False): + mock_osutils = mock.MagicMock() + mock_get_volume = mock_osutils.get_volume_path_names_by_mount_point + if get_endpoint_side_effect: + mock_get_volume.side_effect = exception.ItemNotFoundException + else: + mock_get_volume.return_value = paths + if not mount_points: + res = self._pagefiles._get_page_file_volumes_by_mount_point( + mock_osutils) + self.assertEqual(res, []) + return + with testutils.ConfPatcher("page_file_volume_mount_points", + mount_points): + if not paths: + expected_logging = ("Mount point not found: %s" + % mount_points[0]) + with self._logsnatcher: + res = (self._pagefiles. + _get_page_file_volumes_by_mount_point(mock_osutils)) + self.assertEqual(self._logsnatcher.output[0], expected_logging) + return + res = self._pagefiles._get_page_file_volumes_by_mount_point( + mock_osutils) + self.assertEqual(res, paths) + + def test_get_page_file_volumes_by_mount_point_no_mount_point(self): + self._test_get_page_file_volumes_by_mount_point(mount_points=[]) + + def test_get_page_file_volumes_by_mount_point_no_endpoint(self): + mount_points = [mock.sentinel.path] + self._test_get_page_file_volumes_by_mount_point( + mount_points=mount_points, get_endpoint_side_effect=True) + + def test_get_page_file_volumes_by_mount_point(self): + mount_points = [mock.sentinel.path] + paths = [mock.sentinel.file_path] + self._test_get_page_file_volumes_by_mount_point( + mount_points=mount_points, paths=paths) + + def _test_get_page_file_volumes_by_label(self, drives=None, + labels=None, conf_labels=None): + mock_osutils = mock.MagicMock() + mock_osutils.get_logical_drives.return_value = drives + mock_osutils.get_volume_label.return_value = labels + if not labels: + res = self._pagefiles._get_page_file_volumes_by_label(mock_osutils) + self.assertEqual(res, []) + return + with testutils.ConfPatcher("page_file_volume_labels", + conf_labels): + res = self._pagefiles._get_page_file_volumes_by_label(mock_osutils) + self.assertEqual(res, drives) + + def test_get_page_file_volumes_by_label_no_labels(self): + mock_drives = [mock.sentinel.fake_drive] + self._test_get_page_file_volumes_by_label(drives=mock_drives) + + def test_get_page_file_volumes_by_label_drive_found(self): + fake_label = mock.Mock() + mock_drives = [mock.sentinel.fake_drive] + mock_labels = fake_label + mock_conf_labels = [fake_label] + self._test_get_page_file_volumes_by_label(drives=mock_drives, + labels=mock_labels, + conf_labels=mock_conf_labels) + + @mock.patch(MODPATH + "._get_page_file_volumes_by_label") + @mock.patch(MODPATH + "._get_page_file_volumes_by_mount_point") + def test_get_page_file_volumes(self, mock_get_by_mount_point, + mock_get_by_label): + mock_osutils = mock.MagicMock() + mock_get_by_mount_point.return_value = [mock.sentinel.path_mount] + mock_get_by_label.return_value = [mock.sentinel.path_label] + self._pagefiles._get_page_file_volumes(mock_osutils) + mock_get_by_label.assert_called_once_with(mock_osutils) + mock_get_by_mount_point.assert_called_once_with(mock_osutils) + + @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') + @mock.patch(MODPATH + "._get_page_file_volumes") + def _test_execute(self, mock_get_page_file_volumes, + mock_get_os_utils, file_volumes=None, + get_page_files_res=None): + mock_service = mock.Mock() + mock_shared_data = mock.Mock() + mock_osutils = mock.Mock() + mock_get_os_utils.return_value = mock_osutils + mock_get_page_file_volumes.return_value = file_volumes + mock_osutils.get_page_files.return_value = get_page_files_res + mock_osutils.set_page_files.return_value = file_volumes + if not file_volumes: + expected_res = (base.PLUGIN_EXECUTE_ON_NEXT_BOOT, False) + expected_logging = [ + "No page file volume found, skipping configuration"] + with self._logsnatcher: + res = self._pagefiles.execute(mock_service, mock_shared_data) + + self.assertEqual(res, expected_res) + self.assertEqual(self._logsnatcher.output, expected_logging) + mock_get_page_file_volumes.assert_called_once_with(mock_osutils) + return + with self._logsnatcher: + expected_file = [ + (os.path.join(file_volumes[0], "pagefile.sys"), 0, 0)] + expected_logging = [ + "Page file configuration set: %s" % expected_file] + expected_res = (base.PLUGIN_EXECUTE_ON_NEXT_BOOT, True) + res = self._pagefiles.execute(mock_service, mock_shared_data) + self.assertEqual(res, expected_res) + self.assertEqual(self._logsnatcher.output, expected_logging) + mock_osutils.get_page_files.assert_called_once_with() + mock_osutils.set_page_files.assert_called_once_with(expected_file) + + def test_execute_no_page_file(self): + self._test_execute(file_volumes=[]) + + def test_execute_page_files_found(self): + mock_file_volumes = [str(mock.sentinel.fake_file)] + self._test_execute(file_volumes=mock_file_volumes, + get_page_files_res=[]) + + def test_get_os_requirements(self): + res = self._pagefiles.get_os_requirements() + expected_res = ('win32', (5, 2)) + self.assertEqual(res, expected_res)