diff --git a/zun/common/exception.py b/zun/common/exception.py index 4492a633f..32a1c4883 100644 --- a/zun/common/exception.py +++ b/zun/common/exception.py @@ -522,23 +522,23 @@ class ResourcesUnavailable(ZunException): message = _("Insufficient compute resources: %(reason)s.") -class PciConfigInvalidWhitelist(Invalid): - msg_fmt = _("Invalid PCI devices Whitelist config %(reason)s") +class PciConfigInvalidWhitelist(ZunException): + message = _("Invalid PCI devices Whitelist config %(reason)s") class PciDeviceWrongAddressFormat(ZunException): - msg_fmt = _("The PCI address %(address)s has an incorrect format.") + message = _("The PCI address %(address)s has an incorrect format.") class PciDeviceInvalidDeviceName(ZunException): - msg_fmt = _("Invalid PCI Whitelist: " + message = _("Invalid PCI Whitelist: " "The PCI whitelist can specify devname or address," " but not both") class PciDeviceNotFoundById(NotFound): - msg_fmt = _("PCI device %(id)s not found") + message = _("PCI device %(id)s not found") class PciDeviceNotFound(NotFound): - msg_fmt = _("PCI Device %(node_id)s:%(address)s not found.") + message = _("PCI Device %(node_id)s:%(address)s not found.") diff --git a/zun/pci/devspec.py b/zun/pci/devspec.py new file mode 100644 index 000000000..364a2322b --- /dev/null +++ b/zun/pci/devspec.py @@ -0,0 +1,289 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc +import re +import string + +import six + +from zun.common import exception +from zun.pci import utils + +MAX_VENDOR_ID = 0xFFFF +MAX_PRODUCT_ID = 0xFFFF +MAX_FUNC = 0x7 +MAX_DOMAIN = 0xFFFF +MAX_BUS = 0xFF +MAX_SLOT = 0x1F +ANY = '*' +REGEX_ANY = '.*' + + +def get_pci_dev_info(pci_obj, property, max, hex_value): + a = getattr(pci_obj, property) + if a == ANY: + return + try: + v = int(a, 16) + except ValueError: + raise exception.PciConfigInvalidWhitelist( + reason="invalid %s %s" % (property, a)) + if v > max: + raise exception.PciConfigInvalidWhitelist( + reason=("invalid %(property)s %(attr)s" % + {'property': property, 'attr': a})) + setattr(pci_obj, property, hex_value % v) + + +@six.add_metaclass(abc.ABCMeta) +class PciAddressSpec(object): + """Abstract class for all PCI address spec styles + + This class checks the address fields of the pci.passthrough_whitelist + """ + + @abc.abstractmethod + def match(self, pci_addr): + pass + + def is_single_address(self): + return all([ + all(c in string.hexdigits for c in self.domain), + all(c in string.hexdigits for c in self.bus), + all(c in string.hexdigits for c in self.slot), + all(c in string.hexdigits for c in self.func)]) + + +class PhysicalPciAddress(PciAddressSpec): + """Manages the address fields for a fully-qualified PCI address. + + This function class will validate the address fields for a single + PCI device. + """ + def __init__(self, pci_addr): + try: + if isinstance(pci_addr, dict): + self.domain = pci_addr['domain'] + self.bus = pci_addr['bus'] + self.slot = pci_addr['slot'] + self.func = pci_addr['function'] + else: + self.domain, self.bus, self.slot, self.func = ( + utils.get_pci_address_fields(pci_addr)) + get_pci_dev_info(self, 'func', MAX_FUNC, '%1x') + get_pci_dev_info(self, 'domain', MAX_DOMAIN, '%04x') + get_pci_dev_info(self, 'bus', MAX_BUS, '%02x') + get_pci_dev_info(self, 'slot', MAX_SLOT, '%02x') + except (KeyError, ValueError): + raise exception.PciDeviceWrongAddressFormat(address=pci_addr) + + def match(self, phys_pci_addr): + conditions = [ + self.domain == phys_pci_addr.domain, + self.bus == phys_pci_addr.bus, + self.slot == phys_pci_addr.slot, + self.func == phys_pci_addr.func, + ] + return all(conditions) + + +class PciAddressGlobSpec(PciAddressSpec): + """Manages the address fields with glob style. + + This function class will validate the address fields with glob style, + check for wildcards, and insert wildcards where the field is left blank. + """ + + def __init__(self, pci_addr): + self.domain = ANY + self.bus = ANY + self.slot = ANY + self.func = ANY + + dbs, sep, func = pci_addr.partition('.') + if func: + self.func = func.strip() + get_pci_dev_info(self, 'func', MAX_FUNC, '%01x') + if dbs: + dbs_fields = dbs.split(':') + if len(dbs_fields) > 3: + raise exception.PciDeviceWrongAddressFormat(address=pci_addr) + # If we got a partial address like ":00.", we need to turn this + # into a domain of ANY, a bus of ANY, and a slot of 00. This code + # allows the address bus and/or domain to be left off + dbs_all = [ANY] * (3 - len(dbs_fields)) + dbs_all.extend(dbs_fields) + dbs_checked = [s.strip() or ANY for s in dbs_all] + self.domain, self.bus, self.slot = dbs_checked + get_pci_dev_info(self, 'domain', MAX_DOMAIN, '%04x') + get_pci_dev_info(self, 'bus', MAX_BUS, '%02x') + get_pci_dev_info(self, 'slot', MAX_SLOT, '%02x') + + def match(self, phys_pci_addr): + conditions = [ + self.domain in (ANY, phys_pci_addr.domain), + self.bus in (ANY, phys_pci_addr.bus), + self.slot in (ANY, phys_pci_addr.slot), + self.func in (ANY, phys_pci_addr.func) + ] + return all(conditions) + + +class PciAddressRegexSpec(PciAddressSpec): + """Manages the address fields with regex style. + + This function class will validate the address fields with regex style. + The validation includes check for all PCI address attributes and validate + their regex. + """ + def __init__(self, pci_addr): + try: + self.domain = pci_addr.get('domain', REGEX_ANY) + self.bus = pci_addr.get('bus', REGEX_ANY) + self.slot = pci_addr.get('slot', REGEX_ANY) + self.func = pci_addr.get('function', REGEX_ANY) + self.domain_regex = re.compile(self.domain) + self.bus_regex = re.compile(self.bus) + self.slot_regex = re.compile(self.slot) + self.func_regex = re.compile(self.func) + except re.error: + raise exception.PciDeviceWrongAddressFormat(address=pci_addr) + + def match(self, phys_pci_addr): + conditions = [ + bool(self.domain_regex.match(phys_pci_addr.domain)), + bool(self.bus_regex.match(phys_pci_addr.bus)), + bool(self.slot_regex.match(phys_pci_addr.slot)), + bool(self.func_regex.match(phys_pci_addr.func)) + ] + return all(conditions) + + +class WhitelistPciAddress(object): + """Manages the address fields of the whitelist. + + This class checks the address fields of the pci.passthrough_whitelist + configuration option, validating the address fields. + Example config are: + + | [pci] + | passthrough_whitelist = {"address":"*:0a:00.*", + | "physical_network":"physnet1"} + | passthrough_whitelist = {"address": {"domain": ".*", + "bus": "02", + "slot": "01", + "function": "[0-2]"}, + "physical_network":"net1"} + | passthrough_whitelist = {"vendor_id":"1137","product_id":"0071"} + + """ + def __init__(self, pci_addr, is_physical_function): + self.is_physical_function = is_physical_function + self._init_address_fields(pci_addr) + + def _check_physical_function(self): + if self.pci_address_spec.is_single_address(): + self.is_physical_function = ( + utils.is_physical_function( + self.pci_address_spec.domain, + self.pci_address_spec.bus, + self.pci_address_spec.slot, + self.pci_address_spec.func)) + + def _init_address_fields(self, pci_addr): + if not self.is_physical_function: + if isinstance(pci_addr, six.string_types): + self.pci_address_spec = PciAddressGlobSpec(pci_addr) + elif isinstance(pci_addr, dict): + self.pci_address_spec = PciAddressRegexSpec(pci_addr) + else: + raise exception.PciDeviceWrongAddressFormat(address=pci_addr) + self._check_physical_function() + else: + self.pci_address_spec = PhysicalPciAddress(pci_addr) + + def match(self, pci_addr, pci_phys_addr): + """Match a device to this PciAddress. + + Assume this is called given pci_addr and pci_phys_addr, no attempt is + made to verify if pci_addr is a VF of pci_phys_addr. + + :param pci_addr: PCI address of the device to match. + :param pci_phys_addr: PCI address of the parent of the device to match + (or None if the device is not a VF). + """ + + # Try to match on the parent PCI address if the PciDeviceSpec is a + # PF (sriov is available) and the device to match is a VF. This + # makes possible to specify the PCI address of a PF in the + # pci_passthrough_whitelist to match any of it's VFs PCI devices. + if self.is_physical_function and pci_phys_addr: + pci_phys_addr_obj = PhysicalPciAddress(pci_phys_addr) + if self.pci_address_spec.match(pci_phys_addr_obj): + return True + + # Try to match on the device PCI address only. + pci_addr_obj = PhysicalPciAddress(pci_addr) + return self.pci_address_spec.match(pci_addr_obj) + + +class PciDeviceSpec(object): + def __init__(self, dev_spec): + self.tags = dev_spec + self._init_dev_details() + + def _init_dev_details(self): + self.vendor_id = self.tags.pop("vendor_id", ANY) + self.product_id = self.tags.pop("product_id", ANY) + # Note(moshele): The address attribute can be a string or a dict. + # For glob syntax or specific pci it is a string and for regex syntax + # it is a dict. The WhitelistPciAddress class handles both types. + self.address = self.tags.pop("address", None) + self.dev_name = self.tags.pop("devname", None) + + self.vendor_id = self.vendor_id.strip() + get_pci_dev_info(self, 'vendor_id', MAX_VENDOR_ID, '%04x') + get_pci_dev_info(self, 'product_id', MAX_PRODUCT_ID, '%04x') + + if self.address and self.dev_name: + raise exception.PciDeviceInvalidDeviceName() + if not self.dev_name: + pci_address = self.address or "*:*:*.*" + self.address = WhitelistPciAddress(pci_address, False) + + def match(self, dev_dict): + if self.dev_name: + address_str, pf = utils.get_function_by_ifname( + self.dev_name) + if not address_str: + return False + # Note(moshele): In this case we always passing a string + # of the PF pci address + address_obj = WhitelistPciAddress(address_str, pf) + elif self.address: + address_obj = self.address + return all([ + self.vendor_id in (ANY, dev_dict['vendor_id']), + self.product_id in (ANY, dev_dict['product_id']), + address_obj.match(dev_dict['address'], + dev_dict.get('parent_addr'))]) + + def match_pci_obj(self, pci_obj): + return self.match({'vendor_id': pci_obj.vendor_id, + 'product_id': pci_obj.product_id, + 'address': pci_obj.address, + 'parent_addr': pci_obj.parent_addr}) + + def get_tags(self): + return self.tags diff --git a/zun/tests/unit/pci/test_devspec.py b/zun/tests/unit/pci/test_devspec.py new file mode 100644 index 000000000..0d3251b21 --- /dev/null +++ b/zun/tests/unit/pci/test_devspec.py @@ -0,0 +1,417 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import mock +import six + +from zun.common import exception +from zun.pci import devspec +from zun.tests import base + +dev = {"vendor_id": "8086", + "product_id": "5057", + "address": "0000:0a:00.5", + "parent_addr": "0000:0a:00.0"} + + +class PciAddressSpecTestCase(base.TestCase): + def test_pci_address_spec_abstact_instance_fail(self): + self.assertRaises(TypeError, devspec.PciAddressSpec) + + +class PhysicalPciAddressTestCase(base.TestCase): + pci_addr = {"domain": "0000", + "bus": "0a", + "slot": "00", + "function": "5"} + + def test_init_by_dict(self): + phys_addr = devspec.PhysicalPciAddress(self.pci_addr) + self.assertEqual(phys_addr.domain, self.pci_addr['domain']) + self.assertEqual(phys_addr.bus, self.pci_addr['bus']) + self.assertEqual(phys_addr.slot, self.pci_addr['slot']) + self.assertEqual(phys_addr.func, self.pci_addr['function']) + + def test_init_by_dict_invalid_address_values(self): + invalid_val_addr = {"domain": devspec.MAX_DOMAIN + 1, + "bus": devspec.MAX_BUS + 1, + "slot": devspec.MAX_SLOT + 1, + "function": devspec.MAX_FUNC + 1} + for component in invalid_val_addr: + address = dict(self.pci_addr) + address[component] = str(invalid_val_addr[component]) + self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PhysicalPciAddress, address) + + def test_init_by_dict_missing_values(self): + for component in self.pci_addr: + address = dict(self.pci_addr) + del address[component] + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PhysicalPciAddress, address) + + def test_init_by_string(self): + address_str = "0000:0a:00.5" + phys_addr = devspec.PhysicalPciAddress(address_str) + self.assertEqual(phys_addr.domain, "0000") + self.assertEqual(phys_addr.bus, "0a") + self.assertEqual(phys_addr.slot, "00") + self.assertEqual(phys_addr.func, "5") + + def test_init_by_string_invalid_values(self): + invalid_addresses = [str(devspec.MAX_DOMAIN + 1) + ":0a:00.5", + "0000:" + str(devspec.MAX_BUS + 1) + ":00.5", + "0000:0a:" + str(devspec.MAX_SLOT + 1) + ".5", + "0000:0a:00." + str(devspec.MAX_FUNC + 1)] + for address in invalid_addresses: + self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PhysicalPciAddress, address) + + def test_init_by_string_missing_values(self): + invalid_addresses = ["00:0000:0a:00.5", "0a:00.5", "0000:00.5"] + for address in invalid_addresses: + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PhysicalPciAddress, address) + + def test_match(self): + address_str = "0000:0a:00.5" + phys_addr1 = devspec.PhysicalPciAddress(address_str) + phys_addr2 = devspec.PhysicalPciAddress(address_str) + self.assertTrue(phys_addr1.match(phys_addr2)) + + def test_false_match(self): + address_str = "0000:0a:00.5" + phys_addr1 = devspec.PhysicalPciAddress(address_str) + addresses = ["0010:0a:00.5", "0000:0b:00.5", + "0000:0a:01.5", "0000:0a:00.4"] + for address in addresses: + phys_addr2 = devspec.PhysicalPciAddress(address) + self.assertFalse(phys_addr1.match(phys_addr2)) + + +class PciAddressGlobSpecTestCase(base.TestCase): + def test_init(self): + address_str = "0000:0a:00.5" + phys_addr = devspec.PciAddressGlobSpec(address_str) + self.assertEqual(phys_addr.domain, "0000") + self.assertEqual(phys_addr.bus, "0a") + self.assertEqual(phys_addr.slot, "00") + self.assertEqual(phys_addr.func, "5") + + def test_init_invalid_address(self): + invalid_addresses = ["00:0000:0a:00.5"] + for address in invalid_addresses: + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PciAddressGlobSpec, address) + + def test_init_invalid_values(self): + invalid_addresses = [str(devspec.MAX_DOMAIN + 1) + ":0a:00.5", + "0000:" + str(devspec.MAX_BUS + 1) + ":00.5", + "0000:0a:" + str(devspec.MAX_SLOT + 1) + ".5", + "0000:0a:00." + str(devspec.MAX_FUNC + 1)] + for address in invalid_addresses: + self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciAddressGlobSpec, address) + + def test_match(self): + address_str = "0000:0a:00.5" + phys_addr = devspec.PhysicalPciAddress(address_str) + addresses = ["0000:0a:00.5", "*:0a:00.5", "0000:*:00.5", + "0000:0a:*.5", "0000:0a:00.*"] + for address in addresses: + glob_addr = devspec.PciAddressGlobSpec(address) + self.assertTrue(glob_addr.match(phys_addr)) + + def test_false_match(self): + address_str = "0000:0a:00.5" + phys_addr = devspec.PhysicalPciAddress(address_str) + addresses = ["0010:0a:00.5", "0000:0b:00.5", + "*:0a:01.5", "0000:0a:*.4"] + for address in addresses: + glob_addr = devspec.PciAddressGlobSpec(address) + self.assertFalse(phys_addr.match(glob_addr)) + + +class PciAddressRegexSpecTestCase(base.TestCase): + def test_init(self): + address_regex = {"domain": ".*", + "bus": "02", + "slot": "01", + "function": "[0-2]"} + phys_addr = devspec.PciAddressRegexSpec(address_regex) + self.assertEqual(phys_addr.domain, ".*") + self.assertEqual(phys_addr.bus, "02") + self.assertEqual(phys_addr.slot, "01") + self.assertEqual(phys_addr.func, "[0-2]") + + def test_init_invalid_address(self): + invalid_addresses = [{"domain": "*", + "bus": "02", + "slot": "01", + "function": "[0-2]"}] + + for address in invalid_addresses: + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PciAddressRegexSpec, address) + + def test_match(self): + address_str = "0000:0a:00.5" + phys_addr = devspec.PhysicalPciAddress(address_str) + addresses = [{"domain": ".*", "bus": "0a", + "slot": "00", "function": "[5-6]"}, + {"domain": ".*", "bus": "0a", + "slot": ".*", "function": "[4-5]"}, + {"domain": ".*", "bus": "0a", + "slot": "[0-3]", "function": ".*"}] + for address in addresses: + regex_addr = devspec.PciAddressRegexSpec(address) + self.assertTrue(regex_addr.match(phys_addr)) + + def test_false_match(self): + address_str = "0000:0b:00.5" + phys_addr = devspec.PhysicalPciAddress(address_str) + addresses = [{"domain": ".*", "bus": "0a", + "slot": "00", "function": "[5-6]"}, + {"domain": ".*", "bus": "02", + "slot": ".*", "function": "[4-5]"}, + {"domain": ".*", "bus": "02", + "slot": "[0-3]", "function": ".*"}] + for address in addresses: + regex_addr = devspec.PciAddressRegexSpec(address) + self.assertFalse(regex_addr.match(phys_addr)) + + +class PciAddressTestCase(base.TestCase): + def test_wrong_address(self): + pci_info = {"vendor_id": "8086", "address": "*: *: *.6", + "product_id": "5057", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertFalse(pci.match(dev)) + + def test_address_too_big(self): + pci_info = {"address": "0000:0a:0b:00.5", + "physical_network": "hr_net"} + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PciDeviceSpec, pci_info) + + def test_address_invalid_character(self): + pci_info = {"address": "0000:h4.12:6", "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + msg = ('Invalid PCI devices Whitelist config invalid func 12:6') + self.assertEqual(msg, six.text_type(exc)) + + def test_max_func(self): + pci_info = {"address": "0000:0a:00.%s" % (devspec.MAX_FUNC + 1), + "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + msg = ('Invalid PCI devices Whitelist config invalid func %x' + % (devspec.MAX_FUNC + 1)) + self.assertEqual(msg, six.text_type(exc)) + + def test_max_domain(self): + pci_info = {"address": "%x:0a:00.5" % (devspec.MAX_DOMAIN + 1), + "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + msg = ('Invalid PCI devices Whitelist config invalid domain %x' + % (devspec.MAX_DOMAIN + 1)) + self.assertEqual(msg, six.text_type(exc)) + + def test_max_bus(self): + pci_info = {"address": "0000:%x:00.5" % (devspec.MAX_BUS + 1), + "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + msg = ('Invalid PCI devices Whitelist config invalid bus %x' + % (devspec.MAX_BUS + 1)) + self.assertEqual(msg, six.text_type(exc)) + + def test_max_slot(self): + pci_info = {"address": "0000:0a:%x.5" % (devspec.MAX_SLOT + 1), + "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + msg = ('Invalid PCI devices Whitelist config invalid slot %x' + % (devspec.MAX_SLOT + 1)) + self.assertEqual(msg, six.text_type(exc)) + + def test_address_is_undefined(self): + pci_info = {"vendor_id": "8086", "product_id": "5057"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + def test_partial_address(self): + pci_info = {"address": ":0a:00.", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + dev = {"vendor_id": "1137", + "product_id": "0071", + "address": "0000:0a:00.5", + "parent_addr": "0000:0a:00.0"} + self.assertTrue(pci.match(dev)) + + def test_partial_address_func(self): + pci_info = {"address": ".5", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + dev = {"vendor_id": "1137", + "product_id": "0071", + "address": "0000:0a:00.5", + "phys_function": "0000:0a:00.0"} + self.assertTrue(pci.match(dev)) + + @mock.patch('zun.pci.utils.is_physical_function', return_value=True) + def test_address_is_pf(self, mock_is_physical_function): + pci_info = {"address": "0000:0a:00.0", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + @mock.patch('zun.pci.utils.is_physical_function', return_value=True) + def test_address_pf_no_parent_addr(self, mock_is_physical_function): + _dev = dev.copy() + _dev.pop('parent_addr') + pci_info = {"address": "0000:0a:00.5", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(_dev)) + + def test_spec_regex_match(self): + pci_info = {"address": {"domain": ".*", + "bus": ".*", + "slot": "00", + "function": "[5-6]" + }, + "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + def test_spec_regex_no_match(self): + pci_info = {"address": {"domain": ".*", + "bus": ".*", + "slot": "00", + "function": "[6-7]" + }, + "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertFalse(pci.match(dev)) + + def test_spec_invalid_regex(self): + pci_info = {"address": {"domain": ".*", + "bus": ".*", + "slot": "00", + "function": "[6[-7]" + }, + "physical_network": "hr_net"} + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PciDeviceSpec, pci_info) + + def test_spec_invalid_regex2(self): + pci_info = {"address": {"domain": "*", + "bus": "*", + "slot": "00", + "function": "[6-7]" + }, + "physical_network": "hr_net"} + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PciDeviceSpec, pci_info) + + def test_spec_partial_bus_regex(self): + pci_info = {"address": {"domain": ".*", + "slot": "00", + "function": "[5-6]" + }, + "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + def test_spec_partial_address_regex(self): + pci_info = {"address": {"domain": ".*", + "bus": ".*", + "slot": "00", + }, + "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + def test_spec_invalid_address(self): + pci_info = {"address": [".*", ".*", "00", "[6-7]"], + "physical_network": "hr_net"} + self.assertRaises(exception.PciDeviceWrongAddressFormat, + devspec.PciDeviceSpec, pci_info) + + @mock.patch('zun.pci.utils.is_physical_function', return_value=True) + def test_address_is_pf_regex(self, mock_is_physical_function): + pci_info = {"address": {"domain": "0000", + "bus": "0a", + "slot": "00", + "function": "0" + }, + "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + +class PciDevSpecTestCase(base.TestCase): + def test_spec_match(self): + pci_info = {"vendor_id": "8086", "address": "*: *: *.5", + "product_id": "5057", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + def test_invalid_vendor_id(self): + pci_info = {"vendor_id": "8087", "address": "*: *: *.5", + "product_id": "5057", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertFalse(pci.match(dev)) + + def test_vendor_id_out_of_range(self): + pci_info = {"vendor_id": "80860", "address": "*:*:*.5", + "product_id": "5057", "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + self.assertEqual("Invalid PCI devices Whitelist config " + "invalid vendor_id 80860", six.text_type(exc)) + + def test_invalid_product_id(self): + pci_info = {"vendor_id": "8086", "address": "*: *: *.5", + "product_id": "5056", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertFalse(pci.match(dev)) + + def test_product_id_out_of_range(self): + pci_info = {"vendor_id": "8086", "address": "*:*:*.5", + "product_id": "50570", "physical_network": "hr_net"} + exc = self.assertRaises(exception.PciConfigInvalidWhitelist, + devspec.PciDeviceSpec, pci_info) + self.assertEqual("Invalid PCI devices Whitelist config " + "invalid product_id 50570", six.text_type(exc)) + + def test_devname_and_address(self): + pci_info = {"devname": "eth0", "vendor_id": "8086", + "address": "*:*:*.5", "physical_network": "hr_net"} + self.assertRaises(exception.PciDeviceInvalidDeviceName, + devspec.PciDeviceSpec, pci_info) + + @mock.patch('zun.pci.utils.get_function_by_ifname', + return_value=("0000:0a:00.0", True)) + def test_by_name(self, mock_get_function_by_ifname): + pci_info = {"devname": "eth0", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertTrue(pci.match(dev)) + + @mock.patch('zun.pci.utils.get_function_by_ifname', + return_value=(None, False)) + def test_invalid_name(self, mock_get_function_by_ifname): + pci_info = {"devname": "lo", "physical_network": "hr_net"} + pci = devspec.PciDeviceSpec(pci_info) + self.assertFalse(pci.match(dev))