diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index a7d6acce..072da28d 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -28,21 +28,14 @@ from ctypes import wintypes from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import base +LOG = logging.getLogger(__name__) + advapi32 = windll.advapi32 kernel32 = windll.kernel32 netapi32 = windll.netapi32 userenv = windll.userenv - -kernel32.SetComputerNameExW.argtypes = [ctypes.c_int, wintypes.LPCWSTR] -kernel32.SetComputerNameExW.restype = wintypes.BOOL - -kernel32.GetLogicalDriveStringsW.argtypes = [wintypes.DWORD, wintypes.LPWSTR] -kernel32.GetLogicalDriveStringsW.restype = wintypes.DWORD - -kernel32.GetDriveTypeW.argtypes = [wintypes.LPCWSTR] -kernel32.GetDriveTypeW.restype = wintypes.UINT - -LOG = logging.getLogger(__name__) +iphlpapi = windll.iphlpapi +Ws2_32 = windll.Ws2_32 class Win32_PROFILEINFO(ctypes.Structure): @@ -64,9 +57,67 @@ class Win32_LOCALGROUP_MEMBERS_INFO_3(ctypes.Structure): ] +class Win32_MIB_IPFORWARDROW(ctypes.Structure): + _fields_ = [ + ('dwForwardDest', wintypes.DWORD), + ('dwForwardMask', wintypes.DWORD), + ('dwForwardPolicy', wintypes.DWORD), + ('dwForwardNextHop', wintypes.DWORD), + ('dwForwardIfIndex', wintypes.DWORD), + ('dwForwardType', wintypes.DWORD), + ('dwForwardProto', wintypes.DWORD), + ('dwForwardAge', wintypes.DWORD), + ('dwForwardNextHopAS', wintypes.DWORD), + ('dwForwardMetric1', wintypes.DWORD), + ('dwForwardMetric2', wintypes.DWORD), + ('dwForwardMetric3', wintypes.DWORD), + ('dwForwardMetric4', wintypes.DWORD), + ('dwForwardMetric5', wintypes.DWORD) + ] + + +class Win32_MIB_IPFORWARDTABLE(ctypes.Structure): + _fields_ = [ + ('dwNumEntries', wintypes.DWORD), + ('table', Win32_MIB_IPFORWARDROW * 1) + ] + + +kernel32.SetComputerNameExW.argtypes = [ctypes.c_int, wintypes.LPCWSTR] +kernel32.SetComputerNameExW.restype = wintypes.BOOL + +kernel32.GetLogicalDriveStringsW.argtypes = [wintypes.DWORD, wintypes.LPWSTR] +kernel32.GetLogicalDriveStringsW.restype = wintypes.DWORD + +kernel32.GetDriveTypeW.argtypes = [wintypes.LPCWSTR] +kernel32.GetDriveTypeW.restype = wintypes.UINT + +kernel32.GetProcessHeap.argtypes = [] +kernel32.GetProcessHeap.restype = wintypes.HANDLE + +# Note: wintypes.ULONG must be replaced with a 64 bit variable on x64 +kernel32.HeapAlloc.argtypes = [wintypes.HANDLE, wintypes.DWORD, + wintypes.ULONG] +kernel32.HeapAlloc.restype = wintypes.LPVOID + +kernel32.HeapFree.argtypes = [wintypes.HANDLE, wintypes.DWORD, + wintypes.LPVOID] +kernel32.HeapFree.restype = wintypes.BOOL + +iphlpapi.GetIpForwardTable.argtypes = [ + ctypes.POINTER(Win32_MIB_IPFORWARDTABLE), + ctypes.POINTER(wintypes.ULONG), + wintypes.BOOL] +iphlpapi.GetIpForwardTable.restype = wintypes.DWORD + +Ws2_32.inet_ntoa.restype = ctypes.c_char_p + + class WindowsUtils(base.BaseOSUtils): NERR_GroupNotFound = 2220 ERROR_ACCESS_DENIED = 5 + ERROR_INSUFFICIENT_BUFFER = 122 + ERROR_NO_DATA = 232 ERROR_NO_SUCH_MEMBER = 1387 ERROR_MEMBER_IN_ALIAS = 1378 ERROR_INVALID_MEMBER = 1388 @@ -345,16 +396,68 @@ class WindowsUtils(base.BaseOSUtils): self._stop_service(self._service_name) def get_default_gateway(self): - conn = wmi.WMI(moniker='//./root/cimv2') - for net_adapter_config in conn.Win32_NetworkAdapterConfiguration(): - if net_adapter_config.DefaultIPGateway: - return (net_adapter_config.InterfaceIndex, - net_adapter_config.DefaultIPGateway[0]) - return (None, None) + default_routes = [r for r in self._get_ipv4_routing_table() + if r[0] == '0.0.0.0'] + if default_routes: + return (default_routes[0][3], default_routes[0][2]) + else: + return (None, None) + + def _get_ipv4_routing_table(self): + routing_table = [] + + heap = kernel32.GetProcessHeap() + + size = wintypes.ULONG(ctypes.sizeof(Win32_MIB_IPFORWARDTABLE)) + p = kernel32.HeapAlloc(heap, 0, size) + if not p: + raise Exception('Unable to allocate memory for the IP forward ' + 'table') + p_forward_table = ctypes.cast( + p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE)) + + try: + err = iphlpapi.GetIpForwardTable(p_forward_table, + ctypes.byref(size), 0) + if err == self.ERROR_INSUFFICIENT_BUFFER: + kernel32.HeapFree(heap, 0, p_forward_table) + p = kernel32.HeapAlloc(heap, 0, size) + if not p: + raise Exception('Unable to allocate memory for the IP ' + 'forward table') + p_forward_table = ctypes.cast( + p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE)) + + err = iphlpapi.GetIpForwardTable(p_forward_table, + ctypes.byref(size), 0) + if err != self.ERROR_NO_DATA: + if err: + raise Exception('Unable to get IP forward table. ' + 'Error: %s' % err) + + forward_table = p_forward_table.contents + table = ctypes.cast(ctypes.addressof(forward_table.table), + ctypes.POINTER(Win32_MIB_IPFORWARDROW * + forward_table.dwNumEntries)).contents + + i = 0 + while i < forward_table.dwNumEntries: + row = table[i] + routing_table.append(( + Ws2_32.inet_ntoa(row.dwForwardDest), + Ws2_32.inet_ntoa(row.dwForwardMask), + Ws2_32.inet_ntoa(row.dwForwardNextHop), + row.dwForwardIfIndex, + row.dwForwardMetric1)) + i += 1 + + return routing_table + finally: + kernel32.HeapFree(heap, 0, p_forward_table) def check_static_route_exists(self, destination): - conn = wmi.WMI(moniker='//./root/cimv2') - return len(conn.Win32_IP4RouteTable(Destination=destination)) > 0 + return len([r for r in self._get_ipv4_routing_table() + if r[0] == destination]) > 0 def add_static_route(self, destination, mask, next_hop, interface_index, metric): @@ -364,27 +467,6 @@ class WindowsUtils(base.BaseOSUtils): if err: raise Exception('Unable to add route: %(err)s' % locals()) - # TODO(alexpilotti): The following code creates the route properly and - # "route print" shows the added route, but routing to the destination - # fails. This option would be preferable compared to spawning a - # "ROUTE ADD" process. - ''' - ROUTE_PROTOCOL_NETMGMT = 3 - ROUTE_TYPE_INDIRECT = 4 - - conn = wmi.WMI(moniker='//./root/cimv2') - - route = conn.Win32_IP4RouteTable.SpawnInstance_() - route.Destination = destination - route.Mask = mask - route.NextHop = next_hop - route.InterfaceIndex = interface_index - route.Metric1 = metric - route.Protocol = self.ROUTE_PROTOCOL_NETMGMT - route.Type = self.ROUTE_TYPE_INDIRECT - route.Put_() - ''' - def get_os_version(self): conn = wmi.WMI(moniker='//./root/cimv2') return conn.Win32_OperatingSystem()[0].Version