diff --git a/quantum/db/api.py b/quantum/db/api.py index 67480a2cf3..caba1a7200 100644 --- a/quantum/db/api.py +++ b/quantum/db/api.py @@ -22,6 +22,7 @@ import time import sqlalchemy as sql from sqlalchemy import create_engine from sqlalchemy.exc import DisconnectionError +from sqlalchemy.interfaces import PoolListener from sqlalchemy.orm import sessionmaker, exc from quantum.db import model_base @@ -56,6 +57,18 @@ class MySQLPingListener(object): raise +class SqliteForeignKeysListener(PoolListener): + """ + Ensures that the foreign key constraints are enforced in SQLite. + + The foreign key constraints are disabled by default in SQLite, + so the foreign key constraints will be enabled here for every + database connection + """ + def connect(self, dbapi_con, con_record): + dbapi_con.execute('pragma foreign_keys=ON') + + def configure_db(options): """ Establish the database, create an engine if needed, and @@ -74,6 +87,8 @@ def configure_db(options): if 'mysql' in connection_dict.drivername: engine_args['listeners'] = [MySQLPingListener()] + if 'sqlite' in connection_dict.drivername: + engine_args['listeners'] = [SqliteForeignKeysListener()] _ENGINE = create_engine(options['sql_connection'], **engine_args) base = options.get('base', BASE) diff --git a/quantum/tests/unit/linuxbridge/test_lb_db.py b/quantum/tests/unit/linuxbridge/test_lb_db.py index f24fad2f08..a497fea838 100644 --- a/quantum/tests/unit/linuxbridge/test_lb_db.py +++ b/quantum/tests/unit/linuxbridge/test_lb_db.py @@ -18,6 +18,7 @@ import unittest2 from quantum.common import exceptions as q_exc from quantum.db import api as db from quantum.plugins.linuxbridge.db import l2network_db_v2 as lb_db +from quantum.tests.unit import test_db_plugin as test_plugin PHYS_NET = 'physnet1' PHYS_NET_2 = 'physnet2' @@ -26,7 +27,6 @@ VLAN_MAX = 19 VLAN_RANGES = {PHYS_NET: [(VLAN_MIN, VLAN_MAX)]} UPDATED_VLAN_RANGES = {PHYS_NET: [(VLAN_MIN + 5, VLAN_MAX + 5)], PHYS_NET_2: [(VLAN_MIN + 20, VLAN_MAX + 20)]} -TEST_NETWORK_ID = 'abcdefghijklmnopqrstuvwxyz' class NetworkStatesTest(unittest2.TestCase): @@ -144,21 +144,21 @@ class NetworkStatesTest(unittest2.TestCase): self.assertIsNone(lb_db.get_network_state(PHYS_NET, vlan_id)) -class NetworkBindingsTest(unittest2.TestCase): +class NetworkBindingsTest(test_plugin.QuantumDbPluginV2TestCase): def setUp(self): + super(NetworkBindingsTest, self).setUp() lb_db.initialize() self.session = db.get_session() - def tearDown(self): - db.clear_db() - def test_add_network_binding(self): - self.assertIsNone(lb_db.get_network_binding(self.session, - TEST_NETWORK_ID)) - lb_db.add_network_binding(self.session, TEST_NETWORK_ID, PHYS_NET, - 1234) - binding = lb_db.get_network_binding(self.session, TEST_NETWORK_ID) - self.assertIsNotNone(binding) - self.assertEqual(binding.network_id, TEST_NETWORK_ID) - self.assertEqual(binding.physical_network, PHYS_NET) - self.assertEqual(binding.vlan_id, 1234) + with self.network() as network: + TEST_NETWORK_ID = network['network']['id'] + self.assertIsNone(lb_db.get_network_binding(self.session, + TEST_NETWORK_ID)) + lb_db.add_network_binding(self.session, TEST_NETWORK_ID, PHYS_NET, + 1234) + binding = lb_db.get_network_binding(self.session, TEST_NETWORK_ID) + self.assertIsNotNone(binding) + self.assertEqual(binding.network_id, TEST_NETWORK_ID) + self.assertEqual(binding.physical_network, PHYS_NET) + self.assertEqual(binding.vlan_id, 1234) diff --git a/quantum/tests/unit/metaplugin/test_metaplugin.py b/quantum/tests/unit/metaplugin/test_metaplugin.py index 3d42923177..ac55754521 100644 --- a/quantum/tests/unit/metaplugin/test_metaplugin.py +++ b/quantum/tests/unit/metaplugin/test_metaplugin.py @@ -31,6 +31,7 @@ from quantum.extensions.flavor import (FLAVOR_NETWORK, FLAVOR_ROUTER) from quantum.extensions import l3 from quantum.openstack.common import cfg from quantum.openstack.common import uuidutils +from quantum.plugins.metaplugin.meta_quantum_plugin import FlavorNotFound from quantum.plugins.metaplugin.meta_quantum_plugin import MetaPluginV2 from quantum.plugins.metaplugin.proxy_quantum_plugin import ProxyPluginV2 from quantum.tests.unit.metaplugin import fake_plugin @@ -294,7 +295,7 @@ class MetaQuantumPluginV2Test(unittest.TestCase): self.plugin.delete_router(self.context, router_ret1['id']) self.plugin.delete_router(self.context, router_ret2['id']) - with self.assertRaises(l3.RouterNotFound): + with self.assertRaises(FlavorNotFound): self.plugin.get_router(self.context, router_ret1['id']) def test_extension_method(self): diff --git a/quantum/tests/unit/openvswitch/test_ovs_db.py b/quantum/tests/unit/openvswitch/test_ovs_db.py index 8ded7d8884..b707c5a971 100644 --- a/quantum/tests/unit/openvswitch/test_ovs_db.py +++ b/quantum/tests/unit/openvswitch/test_ovs_db.py @@ -18,6 +18,7 @@ import unittest2 from quantum.common import exceptions as q_exc from quantum.db import api as db from quantum.plugins.openvswitch import ovs_db_v2 +from quantum.tests.unit import test_db_plugin as test_plugin PHYS_NET = 'physnet1' PHYS_NET_2 = 'physnet2' @@ -30,7 +31,6 @@ TUN_MIN = 100 TUN_MAX = 109 TUNNEL_RANGES = [(TUN_MIN, TUN_MAX)] UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)] -TEST_NETWORK_ID = 'abcdefghijklmnopqrstuvwxyz' class VlanAllocationsTest(unittest2.TestCase): @@ -242,22 +242,23 @@ class TunnelAllocationsTest(unittest2.TestCase): self.assertIsNone(ovs_db_v2.get_tunnel_allocation(tunnel_id)) -class NetworkBindingsTest(unittest2.TestCase): +class NetworkBindingsTest(test_plugin.QuantumDbPluginV2TestCase): def setUp(self): + super(NetworkBindingsTest, self).setUp() ovs_db_v2.initialize() self.session = db.get_session() - def tearDown(self): - db.clear_db() - def test_add_network_binding(self): - self.assertIsNone(ovs_db_v2.get_network_binding(self.session, - TEST_NETWORK_ID)) - ovs_db_v2.add_network_binding(self.session, TEST_NETWORK_ID, 'vlan', - PHYS_NET, 1234) - binding = ovs_db_v2.get_network_binding(self.session, TEST_NETWORK_ID) - self.assertIsNotNone(binding) - self.assertEqual(binding.network_id, TEST_NETWORK_ID) - self.assertEqual(binding.network_type, 'vlan') - self.assertEqual(binding.physical_network, PHYS_NET) - self.assertEqual(binding.segmentation_id, 1234) + with self.network() as network: + TEST_NETWORK_ID = network['network']['id'] + self.assertIsNone(ovs_db_v2.get_network_binding(self.session, + TEST_NETWORK_ID)) + ovs_db_v2.add_network_binding(self.session, TEST_NETWORK_ID, + 'vlan', PHYS_NET, 1234) + binding = ovs_db_v2.get_network_binding(self.session, + TEST_NETWORK_ID) + self.assertIsNotNone(binding) + self.assertEqual(binding.network_id, TEST_NETWORK_ID) + self.assertEqual(binding.network_type, 'vlan') + self.assertEqual(binding.physical_network, PHYS_NET) + self.assertEqual(binding.segmentation_id, 1234) diff --git a/quantum/tests/unit/ryu/test_ryu_db.py b/quantum/tests/unit/ryu/test_ryu_db.py index f09f95830f..b7ebf3080c 100644 --- a/quantum/tests/unit/ryu/test_ryu_db.py +++ b/quantum/tests/unit/ryu/test_ryu_db.py @@ -15,6 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. +from contextlib import nested import operator import unittest2 @@ -25,26 +26,18 @@ from quantum.plugins.ryu.common import config from quantum.plugins.ryu.db import api_v2 as db_api_v2 from quantum.plugins.ryu.db import models_v2 as ryu_models_v2 from quantum.plugins.ryu import ofp_service_type +from quantum.tests.unit import test_db_plugin as test_plugin -class RyuDBTest(unittest2.TestCase): +class RyuDBTest(test_plugin.QuantumDbPluginV2TestCase): def setUp(self): - options = {"sql_connection": 'sqlite:///:memory:'} - options.update({'base': models_v2.model_base.BASEV2}) - reconnect_interval = cfg.CONF.DATABASE.reconnect_interval - options.update({"reconnect_interval": reconnect_interval}) - db.configure_db(options) - + super(RyuDBTest, self).setUp() self.hosts = [(cfg.CONF.OVS.openflow_controller, ofp_service_type.CONTROLLER), (cfg.CONF.OVS.openflow_rest_api, ofp_service_type.REST_API)] db_api_v2.set_ofp_servers(self.hosts) - def tearDown(self): - db.clear_db() - cfg.CONF.reset() - def test_ofp_server(self): session = db.get_session() servers = session.query(ryu_models_v2.OFPServer).all() @@ -61,17 +54,25 @@ class RyuDBTest(unittest2.TestCase): def test_key_allocation(self): tunnel_key = db_api_v2.TunnelKey() session = db.get_session() - network_id0 = u'network-id-0' - key0 = tunnel_key.allocate(session, network_id0) - network_id1 = u'network-id-1' - key1 = tunnel_key.allocate(session, network_id1) - key_list = tunnel_key.all_list() - self.assertEqual(len(key_list), 2) + with nested(self.network('network-0'), + self.network('network-1') + ) as (network_0, + network_1): + network_id0 = network_0['network']['id'] + key0 = tunnel_key.allocate(session, network_id0) + network_id1 = network_1['network']['id'] + key1 = tunnel_key.allocate(session, network_id1) + key_list = tunnel_key.all_list() + self.assertEqual(len(key_list), 2) - expected_list = [(network_id0, key0), (network_id1, key1)] - self.assertEqual(self._tunnel_key_sort(key_list), expected_list) + expected_list = [(network_id0, key0), (network_id1, key1)] + self.assertEqual(self._tunnel_key_sort(key_list), + expected_list) - tunnel_key.delete(session, network_id0) - key_list = tunnel_key.all_list() - self.assertEqual(self._tunnel_key_sort(key_list), - [(network_id1, key1)]) + tunnel_key.delete(session, network_id0) + key_list = tunnel_key.all_list() + self.assertEqual(self._tunnel_key_sort(key_list), + [(network_id1, key1)]) + + tunnel_key.delete(session, network_id1) + self.assertEqual(tunnel_key.all_list(), []) diff --git a/quantum/tests/unit/test_extension_security_group.py b/quantum/tests/unit/test_extension_security_group.py index aec08503c3..833cc5f81b 100644 --- a/quantum/tests/unit/test_extension_security_group.py +++ b/quantum/tests/unit/test_extension_security_group.py @@ -203,8 +203,8 @@ class SecurityGroupTestPlugin(db_base_plugin_v2.QuantumDbPluginV2, def delete_port(self, context, id): session = context.session with session.begin(subtransactions=True): - super(SecurityGroupTestPlugin, self).delete_port(context, id) self._delete_port_security_group_bindings(context, id) + super(SecurityGroupTestPlugin, self).delete_port(context, id) def create_network(self, context, network): tenant_id = self._get_tenant_id_for_create(context, network['network'])