diff --git a/compass/actions/poll_switch.py b/compass/actions/poll_switch.py index cc3ef46b..dafb125b 100644 --- a/compass/actions/poll_switch.py +++ b/compass/actions/poll_switch.py @@ -135,7 +135,7 @@ def poll_switch(poller_email, ip_addr, credentials, for switch in switches: for machine_dict in machine_dicts: - print 'add machine: %s' % machine_dict + logging.debug('add machine: %s', machine_dict) switch_api.add_switch_machine( poller, switch['id'], False, **machine_dict ) diff --git a/compass/db/api/database.py b/compass/db/api/database.py index 67e5289e..fbc9ca2b 100644 --- a/compass/db/api/database.py +++ b/compass/db/api/database.py @@ -182,7 +182,7 @@ def _setup_switch_table(switch_session): from compass.db.api import switch switch.add_switch_internal( switch_session, long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP)), - True, filters=['deny ports all'] + True, filters=['allow ports all'] ) diff --git a/compass/db/api/switch.py b/compass/db/api/switch.py index 7f6a9dbd..93a0f8c0 100644 --- a/compass/db/api/switch.py +++ b/compass/db/api/switch.py @@ -179,6 +179,20 @@ def list_switches(session, lister, **filters): def del_switch(session, deleter, switch_id, **kwargs): """Delete a switch.""" switch = utils.get_db_object(session, models.Switch, id=switch_id) + default_switch_ip_int = long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP)) + default_switch = utils.get_db_object( + session, models.Switch, + ip_int=default_switch_ip_int + ) + for switch_machine in switch.switch_machines: + machine = switch_machine.machine + if len(machine.switch_machines) <= 1: + utils.add_db_object( + session, models.SwitchMachine, + False, + default_switch.id, machine.id, + port=switch_machine.port + ) return utils.del_db_object(session, switch) @@ -241,6 +255,7 @@ def _update_switch(session, updater, switch_id, **kwargs): ) @database.run_in_session() def update_switch(session, updater, switch_id, **kwargs): + """Update fields of a switch.""" return _update_switch(session, updater, switch_id, **kwargs) @@ -260,6 +275,7 @@ def update_switch(session, updater, switch_id, **kwargs): credentials=utils.check_switch_credentials ) def patch_switch(session, updater, switch_id, **kwargs): + """Patch fields of a switch.""" return _update_switch(session, updater, switch_id, **kwargs) @@ -270,7 +286,7 @@ def patch_switch(session, updater, switch_id, **kwargs): ) @utils.wrap_to_dict(RESP_FILTERS_FIELDS) def list_switch_filters(session, lister, **filters): - """list switch filters.""" + """List switch filters.""" return utils.list_db_objects( session, models.Switch, **filters ) @@ -392,14 +408,11 @@ def _filter_vlans(vlan_filter, obj): location=utils.general_filter_callback ) @utils.wrap_to_dict(RESP_MACHINES_FIELDS) -def _filter_switch_machines(session, user, switch_machines, **filters): - if 'ip_int' in filters: - return switch_machines - else: - return [ - switch_machine for switch_machine in switch_machines - if not switch_machine.filtered - ] +def _filter_switch_machines(session, user, switch_machines): + return [ + switch_machine for switch_machine in switch_machines + if not switch_machine.filtered + ] @user_api.check_user_permission_in_session( @@ -417,14 +430,11 @@ def _filter_switch_machines(session, user, switch_machines, **filters): RESP_MACHINES_HOSTS_FIELDS, clusters=RESP_CLUSTER_FIELDS ) -def _filter_switch_machines_hosts(session, user, switch_machines, **filters): - if 'ip_int' in filters: - filtered_switch_machines = switch_machines - else: - filtered_switch_machines = [ - switch_machine for switch_machine in switch_machines - if not switch_machine.filtered - ] +def _filter_switch_machines_hosts(session, user, switch_machines): + filtered_switch_machines = [ + switch_machine for switch_machine in switch_machines + if not switch_machine.filtered + ] switch_machines_hosts = [] for switch_machine in filtered_switch_machines: machine = switch_machine.machine @@ -449,7 +459,7 @@ def list_switch_machines(session, getter, switch_id, **filters): switch_machines = get_switch_machines_internal( session, switch_id=switch_id, **filters ) - return _filter_switch_machines(session, getter, switch_machines, **filters) + return _filter_switch_machines(session, getter, switch_machines) @utils.replace_filters( @@ -465,7 +475,7 @@ def list_switchmachines(session, lister, **filters): session, **filters ) return _filter_switch_machines( - session, lister, switch_machines, **filters + session, lister, switch_machines ) @@ -479,7 +489,7 @@ def list_switch_machines_hosts(session, getter, switch_id, **filters): session, switch_id=switch_id, **filters ) return _filter_switch_machines_hosts( - session, getter, switch_machines, **filters + session, getter, switch_machines ) @@ -500,10 +510,9 @@ def list_switchmachines_hosts(session, lister, **filters): else: filtered_switch_machines = [ switch_machine for switch_machine in switch_machines - if switch_machine.switch_ip != setting.DEFAULT_SWITCH_IP ] return _filter_switch_machines_hosts( - session, lister, filtered_switch_machines, **filters + session, lister, filtered_switch_machines ) @@ -537,23 +546,12 @@ def add_switch_machine( session, models.Machine, False, mac, **machine_dict) - switches = [switch] - if switch.ip != setting.DEFAULT_SWITCH_IP: - switches.append(utils.get_db_object( - session, models.Switch, - ip_int=long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP)) - )) - - switch_machines = [] - for machine_switch in switches: - switch_machines.append(utils.add_db_object( - session, models.SwitchMachine, - exception_when_existing, - machine_switch.id, machine.id, - **switch_machine_dict - )) - - return switch_machines[0] + return utils.add_db_object( + session, models.SwitchMachine, + exception_when_existing, + switch.id, machine.id, + **switch_machine_dict + ) @utils.supported_filters(optional_support_keys=['find_machines']) @@ -740,11 +738,24 @@ def patch_switchmachine(session, updater, switch_machine_id, **kwargs): ) @utils.wrap_to_dict(RESP_MACHINES_FIELDS) def del_switch_machine(session, deleter, switch_id, machine_id, **kwargs): - """Delete switch machines.""" + """Delete switch machine by switch id and machine id.""" switch_machine = utils.get_db_object( session, models.SwitchMachine, switch_id=switch_id, machine_id=machine_id ) + default_switch_ip_int = long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP)) + default_switch = utils.get_db_object( + session, models.Switch, + ip_int=default_switch_ip_int + ) + machine = switch_machine.machine + if len(machine.switch_machines) <= 1: + utils.add_db_object( + session, models.SwitchMachine, + False, + default_switch.id, machine.id, + port=switch_machine.port + ) return utils.del_db_object(session, switch_machine) @@ -755,11 +766,24 @@ def del_switch_machine(session, deleter, switch_id, machine_id, **kwargs): ) @utils.wrap_to_dict(RESP_MACHINES_FIELDS) def del_switchmachine(session, deleter, switch_machine_id, **kwargs): - """Delete switch machines.""" + """Delete switch machine by switch_machine_id.""" switch_machine = utils.get_db_object( session, models.SwitchMachine, switch_machine_id=switch_machine_id ) + default_switch_ip_int = long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP)) + default_switch = utils.get_db_object( + session, models.Switch, + ip_int=default_switch_ip_int + ) + machine = switch_machine.machine + if len(machine.switch_machines) <= 1: + utils.add_db_object( + session, models.SwitchMachine, + False, + default_switch.id, machine.id, + port=switch_machine.port + ) return utils.del_db_object(session, switch_machine) diff --git a/compass/tests/api/test_api.py b/compass/tests/api/test_api.py index 29c5e66b..88fb4f4d 100644 --- a/compass/tests/api/test_api.py +++ b/compass/tests/api/test_api.py @@ -577,6 +577,7 @@ class TestSwitchAPI(ApiTestCase): url = '/switches' return_value = self.get(url) resp = json.loads(return_value.get_data()) + print 'list switches: %s' % resp count = len(resp) self.assertEqual(count, 2) self.assertEqual(return_value.status_code, 200) diff --git a/compass/tests/db/api/test_user.py b/compass/tests/db/api/test_user.py index fd336663..8bfc7f05 100644 --- a/compass/tests/db/api/test_user.py +++ b/compass/tests/db/api/test_user.py @@ -224,7 +224,6 @@ class TestUpdateUser(BaseTest): is_admin=False ) user_object = user_api.get_user_object('dummy@abc.com') - print 'user object: %s' % user_object self.assertRaises( exception.Forbidden, user_api.update_user,