Fixed the unit tests using SQLite do not check foreign keys.

The foreign key constraints will be enabled for each SQLite
database connection.

By default the foreign key constraints are disabled in SQLite,
so some test cases failed after enabling the foreign key
constraints for unit tests. This fixings also fixed the failed
test cases because of the foreign key enforcement.

Fixes: bug #1021023
Change-Id: I89f0cbbd75bb685b50dfe6628116fa971c5e78cb
This commit is contained in:
Jason Zhang 2012-12-10 18:04:18 -08:00
parent df9186db24
commit a5c2e3006d
6 changed files with 72 additions and 54 deletions

View File

@ -22,6 +22,7 @@ import time
import sqlalchemy as sql import sqlalchemy as sql
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.exc import DisconnectionError from sqlalchemy.exc import DisconnectionError
from sqlalchemy.interfaces import PoolListener
from sqlalchemy.orm import sessionmaker, exc from sqlalchemy.orm import sessionmaker, exc
from quantum.db import model_base from quantum.db import model_base
@ -56,6 +57,18 @@ class MySQLPingListener(object):
raise raise
class SqliteForeignKeysListener(PoolListener):
"""
Ensures that the foreign key constraints are enforced in SQLite.
The foreign key constraints are disabled by default in SQLite,
so the foreign key constraints will be enabled here for every
database connection
"""
def connect(self, dbapi_con, con_record):
dbapi_con.execute('pragma foreign_keys=ON')
def configure_db(options): def configure_db(options):
""" """
Establish the database, create an engine if needed, and Establish the database, create an engine if needed, and
@ -74,6 +87,8 @@ def configure_db(options):
if 'mysql' in connection_dict.drivername: if 'mysql' in connection_dict.drivername:
engine_args['listeners'] = [MySQLPingListener()] engine_args['listeners'] = [MySQLPingListener()]
if 'sqlite' in connection_dict.drivername:
engine_args['listeners'] = [SqliteForeignKeysListener()]
_ENGINE = create_engine(options['sql_connection'], **engine_args) _ENGINE = create_engine(options['sql_connection'], **engine_args)
base = options.get('base', BASE) base = options.get('base', BASE)

View File

@ -18,6 +18,7 @@ import unittest2
from quantum.common import exceptions as q_exc from quantum.common import exceptions as q_exc
from quantum.db import api as db from quantum.db import api as db
from quantum.plugins.linuxbridge.db import l2network_db_v2 as lb_db from quantum.plugins.linuxbridge.db import l2network_db_v2 as lb_db
from quantum.tests.unit import test_db_plugin as test_plugin
PHYS_NET = 'physnet1' PHYS_NET = 'physnet1'
PHYS_NET_2 = 'physnet2' PHYS_NET_2 = 'physnet2'
@ -26,7 +27,6 @@ VLAN_MAX = 19
VLAN_RANGES = {PHYS_NET: [(VLAN_MIN, VLAN_MAX)]} VLAN_RANGES = {PHYS_NET: [(VLAN_MIN, VLAN_MAX)]}
UPDATED_VLAN_RANGES = {PHYS_NET: [(VLAN_MIN + 5, VLAN_MAX + 5)], UPDATED_VLAN_RANGES = {PHYS_NET: [(VLAN_MIN + 5, VLAN_MAX + 5)],
PHYS_NET_2: [(VLAN_MIN + 20, VLAN_MAX + 20)]} PHYS_NET_2: [(VLAN_MIN + 20, VLAN_MAX + 20)]}
TEST_NETWORK_ID = 'abcdefghijklmnopqrstuvwxyz'
class NetworkStatesTest(unittest2.TestCase): class NetworkStatesTest(unittest2.TestCase):
@ -144,21 +144,21 @@ class NetworkStatesTest(unittest2.TestCase):
self.assertIsNone(lb_db.get_network_state(PHYS_NET, vlan_id)) self.assertIsNone(lb_db.get_network_state(PHYS_NET, vlan_id))
class NetworkBindingsTest(unittest2.TestCase): class NetworkBindingsTest(test_plugin.QuantumDbPluginV2TestCase):
def setUp(self): def setUp(self):
super(NetworkBindingsTest, self).setUp()
lb_db.initialize() lb_db.initialize()
self.session = db.get_session() self.session = db.get_session()
def tearDown(self):
db.clear_db()
def test_add_network_binding(self): def test_add_network_binding(self):
self.assertIsNone(lb_db.get_network_binding(self.session, with self.network() as network:
TEST_NETWORK_ID)) TEST_NETWORK_ID = network['network']['id']
lb_db.add_network_binding(self.session, TEST_NETWORK_ID, PHYS_NET, self.assertIsNone(lb_db.get_network_binding(self.session,
1234) TEST_NETWORK_ID))
binding = lb_db.get_network_binding(self.session, TEST_NETWORK_ID) lb_db.add_network_binding(self.session, TEST_NETWORK_ID, PHYS_NET,
self.assertIsNotNone(binding) 1234)
self.assertEqual(binding.network_id, TEST_NETWORK_ID) binding = lb_db.get_network_binding(self.session, TEST_NETWORK_ID)
self.assertEqual(binding.physical_network, PHYS_NET) self.assertIsNotNone(binding)
self.assertEqual(binding.vlan_id, 1234) self.assertEqual(binding.network_id, TEST_NETWORK_ID)
self.assertEqual(binding.physical_network, PHYS_NET)
self.assertEqual(binding.vlan_id, 1234)

View File

@ -31,6 +31,7 @@ from quantum.extensions.flavor import (FLAVOR_NETWORK, FLAVOR_ROUTER)
from quantum.extensions import l3 from quantum.extensions import l3
from quantum.openstack.common import cfg from quantum.openstack.common import cfg
from quantum.openstack.common import uuidutils from quantum.openstack.common import uuidutils
from quantum.plugins.metaplugin.meta_quantum_plugin import FlavorNotFound
from quantum.plugins.metaplugin.meta_quantum_plugin import MetaPluginV2 from quantum.plugins.metaplugin.meta_quantum_plugin import MetaPluginV2
from quantum.plugins.metaplugin.proxy_quantum_plugin import ProxyPluginV2 from quantum.plugins.metaplugin.proxy_quantum_plugin import ProxyPluginV2
from quantum.tests.unit.metaplugin import fake_plugin from quantum.tests.unit.metaplugin import fake_plugin
@ -294,7 +295,7 @@ class MetaQuantumPluginV2Test(unittest.TestCase):
self.plugin.delete_router(self.context, router_ret1['id']) self.plugin.delete_router(self.context, router_ret1['id'])
self.plugin.delete_router(self.context, router_ret2['id']) self.plugin.delete_router(self.context, router_ret2['id'])
with self.assertRaises(l3.RouterNotFound): with self.assertRaises(FlavorNotFound):
self.plugin.get_router(self.context, router_ret1['id']) self.plugin.get_router(self.context, router_ret1['id'])
def test_extension_method(self): def test_extension_method(self):

View File

@ -18,6 +18,7 @@ import unittest2
from quantum.common import exceptions as q_exc from quantum.common import exceptions as q_exc
from quantum.db import api as db from quantum.db import api as db
from quantum.plugins.openvswitch import ovs_db_v2 from quantum.plugins.openvswitch import ovs_db_v2
from quantum.tests.unit import test_db_plugin as test_plugin
PHYS_NET = 'physnet1' PHYS_NET = 'physnet1'
PHYS_NET_2 = 'physnet2' PHYS_NET_2 = 'physnet2'
@ -30,7 +31,6 @@ TUN_MIN = 100
TUN_MAX = 109 TUN_MAX = 109
TUNNEL_RANGES = [(TUN_MIN, TUN_MAX)] TUNNEL_RANGES = [(TUN_MIN, TUN_MAX)]
UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)] UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)]
TEST_NETWORK_ID = 'abcdefghijklmnopqrstuvwxyz'
class VlanAllocationsTest(unittest2.TestCase): class VlanAllocationsTest(unittest2.TestCase):
@ -242,22 +242,23 @@ class TunnelAllocationsTest(unittest2.TestCase):
self.assertIsNone(ovs_db_v2.get_tunnel_allocation(tunnel_id)) self.assertIsNone(ovs_db_v2.get_tunnel_allocation(tunnel_id))
class NetworkBindingsTest(unittest2.TestCase): class NetworkBindingsTest(test_plugin.QuantumDbPluginV2TestCase):
def setUp(self): def setUp(self):
super(NetworkBindingsTest, self).setUp()
ovs_db_v2.initialize() ovs_db_v2.initialize()
self.session = db.get_session() self.session = db.get_session()
def tearDown(self):
db.clear_db()
def test_add_network_binding(self): def test_add_network_binding(self):
self.assertIsNone(ovs_db_v2.get_network_binding(self.session, with self.network() as network:
TEST_NETWORK_ID)) TEST_NETWORK_ID = network['network']['id']
ovs_db_v2.add_network_binding(self.session, TEST_NETWORK_ID, 'vlan', self.assertIsNone(ovs_db_v2.get_network_binding(self.session,
PHYS_NET, 1234) TEST_NETWORK_ID))
binding = ovs_db_v2.get_network_binding(self.session, TEST_NETWORK_ID) ovs_db_v2.add_network_binding(self.session, TEST_NETWORK_ID,
self.assertIsNotNone(binding) 'vlan', PHYS_NET, 1234)
self.assertEqual(binding.network_id, TEST_NETWORK_ID) binding = ovs_db_v2.get_network_binding(self.session,
self.assertEqual(binding.network_type, 'vlan') TEST_NETWORK_ID)
self.assertEqual(binding.physical_network, PHYS_NET) self.assertIsNotNone(binding)
self.assertEqual(binding.segmentation_id, 1234) self.assertEqual(binding.network_id, TEST_NETWORK_ID)
self.assertEqual(binding.network_type, 'vlan')
self.assertEqual(binding.physical_network, PHYS_NET)
self.assertEqual(binding.segmentation_id, 1234)

View File

@ -15,6 +15,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from contextlib import nested
import operator import operator
import unittest2 import unittest2
@ -25,26 +26,18 @@ from quantum.plugins.ryu.common import config
from quantum.plugins.ryu.db import api_v2 as db_api_v2 from quantum.plugins.ryu.db import api_v2 as db_api_v2
from quantum.plugins.ryu.db import models_v2 as ryu_models_v2 from quantum.plugins.ryu.db import models_v2 as ryu_models_v2
from quantum.plugins.ryu import ofp_service_type from quantum.plugins.ryu import ofp_service_type
from quantum.tests.unit import test_db_plugin as test_plugin
class RyuDBTest(unittest2.TestCase): class RyuDBTest(test_plugin.QuantumDbPluginV2TestCase):
def setUp(self): def setUp(self):
options = {"sql_connection": 'sqlite:///:memory:'} super(RyuDBTest, self).setUp()
options.update({'base': models_v2.model_base.BASEV2})
reconnect_interval = cfg.CONF.DATABASE.reconnect_interval
options.update({"reconnect_interval": reconnect_interval})
db.configure_db(options)
self.hosts = [(cfg.CONF.OVS.openflow_controller, self.hosts = [(cfg.CONF.OVS.openflow_controller,
ofp_service_type.CONTROLLER), ofp_service_type.CONTROLLER),
(cfg.CONF.OVS.openflow_rest_api, (cfg.CONF.OVS.openflow_rest_api,
ofp_service_type.REST_API)] ofp_service_type.REST_API)]
db_api_v2.set_ofp_servers(self.hosts) db_api_v2.set_ofp_servers(self.hosts)
def tearDown(self):
db.clear_db()
cfg.CONF.reset()
def test_ofp_server(self): def test_ofp_server(self):
session = db.get_session() session = db.get_session()
servers = session.query(ryu_models_v2.OFPServer).all() servers = session.query(ryu_models_v2.OFPServer).all()
@ -61,17 +54,25 @@ class RyuDBTest(unittest2.TestCase):
def test_key_allocation(self): def test_key_allocation(self):
tunnel_key = db_api_v2.TunnelKey() tunnel_key = db_api_v2.TunnelKey()
session = db.get_session() session = db.get_session()
network_id0 = u'network-id-0' with nested(self.network('network-0'),
key0 = tunnel_key.allocate(session, network_id0) self.network('network-1')
network_id1 = u'network-id-1' ) as (network_0,
key1 = tunnel_key.allocate(session, network_id1) network_1):
key_list = tunnel_key.all_list() network_id0 = network_0['network']['id']
self.assertEqual(len(key_list), 2) key0 = tunnel_key.allocate(session, network_id0)
network_id1 = network_1['network']['id']
key1 = tunnel_key.allocate(session, network_id1)
key_list = tunnel_key.all_list()
self.assertEqual(len(key_list), 2)
expected_list = [(network_id0, key0), (network_id1, key1)] expected_list = [(network_id0, key0), (network_id1, key1)]
self.assertEqual(self._tunnel_key_sort(key_list), expected_list) self.assertEqual(self._tunnel_key_sort(key_list),
expected_list)
tunnel_key.delete(session, network_id0) tunnel_key.delete(session, network_id0)
key_list = tunnel_key.all_list() key_list = tunnel_key.all_list()
self.assertEqual(self._tunnel_key_sort(key_list), self.assertEqual(self._tunnel_key_sort(key_list),
[(network_id1, key1)]) [(network_id1, key1)])
tunnel_key.delete(session, network_id1)
self.assertEqual(tunnel_key.all_list(), [])

View File

@ -203,8 +203,8 @@ class SecurityGroupTestPlugin(db_base_plugin_v2.QuantumDbPluginV2,
def delete_port(self, context, id): def delete_port(self, context, id):
session = context.session session = context.session
with session.begin(subtransactions=True): with session.begin(subtransactions=True):
super(SecurityGroupTestPlugin, self).delete_port(context, id)
self._delete_port_security_group_bindings(context, id) self._delete_port_security_group_bindings(context, id)
super(SecurityGroupTestPlugin, self).delete_port(context, id)
def create_network(self, context, network): def create_network(self, context, network):
tenant_id = self._get_tenant_id_for_create(context, network['network']) tenant_id = self._get_tenant_id_for_create(context, network['network'])