Merge "Passing user variable as kwarg." into dev/experimental

This commit is contained in:
Jenkins 2015-02-07 07:01:41 +00:00 committed by Gerrit Code Review
commit be06a81171
30 changed files with 1425 additions and 1366 deletions

View File

@ -61,7 +61,7 @@ def clean_installers():
if package_installer if package_installer
] ]
user = user_api.get_user_object(setting.COMPASS_ADMIN_EMAIL) user = user_api.get_user_object(setting.COMPASS_ADMIN_EMAIL)
adapters = adapter_api.list_adapters(user) adapters = adapter_api.list_adapters(user=user)
filtered_os_installers = {} filtered_os_installers = {}
filtered_package_installers = {} filtered_package_installers = {}
for adapter in adapters: for adapter in adapters:

View File

@ -56,13 +56,13 @@ def delete_clusters():
if clusternames: if clusternames:
list_cluster_args['name'] = clusternames list_cluster_args['name'] = clusternames
clusters = cluster_api.list_clusters( clusters = cluster_api.list_clusters(
user, **list_cluster_args user=user, **list_cluster_args
) )
delete_underlying_host = flags.OPTIONS.delete_hosts delete_underlying_host = flags.OPTIONS.delete_hosts
for cluster in clusters: for cluster in clusters:
cluster_id = cluster['id'] cluster_id = cluster['id']
cluster_api.del_cluster( cluster_api.del_cluster(
user, cluster_id, True, False, delete_underlying_host cluster_id, True, False, delete_underlying_host, user=user
) )

View File

@ -148,7 +148,7 @@ def set_switch_machines():
switch_mapping = {} switch_mapping = {}
for switch in switches: for switch in switches:
added_switch = switch_api.add_switch( added_switch = switch_api.add_switch(
user, False, **switch False, user=user, **switch
) )
switch_mapping[switch['ip']] = added_switch['id'] switch_mapping[switch['ip']] = added_switch['id']
for switch_ip, machines in switch_machines.items(): for switch_ip, machines in switch_machines.items():
@ -158,7 +158,7 @@ def set_switch_machines():
switch_id = switch_mapping[switch_ip] switch_id = switch_mapping[switch_ip]
for machine in machines: for machine in machines:
switch_api.add_switch_machine( switch_api.add_switch_machine(
user, switch_id, False, **machine switch_id, False, user=user, **machine
) )

View File

@ -62,7 +62,7 @@ def pollswitches(switch_ips):
poll_switches = [] poll_switches = []
all_switches = dict([ all_switches = dict([
(switch['ip'], switch['credentials']) (switch['ip'], switch['credentials'])
for switch in switch_api.list_switches(user) for switch in switch_api.list_switches(user=user)
]) ])
if switch_ips: if switch_ips:
poll_switches = dict([ poll_switches = dict([

View File

@ -128,7 +128,7 @@ def poll_switch(poller_email, ip_addr, credentials,
switch_dict, machine_dicts = _poll_switch( switch_dict, machine_dicts = _poll_switch(
ip_addr, credentials, req_obj=req_obj, oper=oper ip_addr, credentials, req_obj=req_obj, oper=oper
) )
switches = switch_api.list_switches(poller, ip_int=ip_int) switches = switch_api.list_switches(ip_int=ip_int, user=poller)
if not switches: if not switches:
logging.error('no switch found for %s', ip_addr) logging.error('no switch found for %s', ip_addr)
return return
@ -137,6 +137,10 @@ def poll_switch(poller_email, ip_addr, credentials,
for machine_dict in machine_dicts: for machine_dict in machine_dicts:
logging.debug('add machine: %s', machine_dict) logging.debug('add machine: %s', machine_dict)
switch_api.add_switch_machine( switch_api.add_switch_machine(
poller, switch['id'], False, **machine_dict switch['id'], False, user=poller, **machine_dict
)
switch_api.update_switch(
switch['id'],
user=poller,
**switch_dict
) )
switch_api.update_switch(poller, switch['id'], **switch_dict)

View File

@ -55,7 +55,7 @@ def update_progress():
logging.info('update installing progress') logging.info('update installing progress')
user = user_api.get_user_object(setting.COMPASS_ADMIN_EMAIL) user = user_api.get_user_object(setting.COMPASS_ADMIN_EMAIL)
hosts = host_api.list_hosts(user) hosts = host_api.list_hosts(user=user)
host_mapping = {} host_mapping = {}
for host in hosts: for host in hosts:
if 'id' not in host: if 'id' not in host:
@ -74,13 +74,13 @@ def update_progress():
'%s is not in host %s', host_dirname, host '%s is not in host %s', host_dirname, host
) )
continue continue
host_state = host_api.get_host_state(user, host_id) host_state = host_api.get_host_state(host_id, user=user)
if 'state' not in host_state: if 'state' not in host_state:
logging.error('state is not in host state %s', host_state) logging.error('state is not in host state %s', host_state)
continue continue
if host_state['state'] == 'INSTALLING': if host_state['state'] == 'INSTALLING':
host_log_histories = host_api.get_host_log_histories( host_log_histories = host_api.get_host_log_histories(
user, host_id host_id, user=user
) )
host_log_history_mapping = {} host_log_history_mapping = {}
for host_log_history in host_log_histories: for host_log_history in host_log_histories:
@ -101,7 +101,7 @@ def update_progress():
'ignore host state %s since it is not in installing', 'ignore host state %s since it is not in installing',
host_state host_state
) )
adapters = adapter_api.list_adapters(user) adapters = adapter_api.list_adapters(user=user)
adapter_mapping = {} adapter_mapping = {}
for adapter in adapters: for adapter in adapters:
if 'id' not in adapter: if 'id' not in adapter:
@ -116,7 +116,7 @@ def update_progress():
continue continue
adapter_id = adapter['id'] adapter_id = adapter['id']
adapter_mapping[adapter_id] = adapter adapter_mapping[adapter_id] = adapter
clusters = cluster_api.list_clusters(user) clusters = cluster_api.list_clusters(user=user)
cluster_mapping = {} cluster_mapping = {}
for cluster in clusters: for cluster in clusters:
if 'id' not in cluster: if 'id' not in cluster:
@ -129,12 +129,15 @@ def update_progress():
cluster cluster
) )
continue continue
cluster_state = cluster_api.get_cluster_state(user, cluster_id) cluster_state = cluster_api.get_cluster_state(
cluster_id,
user=user
)
if 'state' not in cluster_state: if 'state' not in cluster_state:
logging.error('state not in cluster state %s', cluster_state) logging.error('state not in cluster state %s', cluster_state)
continue continue
cluster_mapping[cluster_id] = (cluster, cluster_state) cluster_mapping[cluster_id] = (cluster, cluster_state)
clusterhosts = cluster_api.list_clusterhosts(user) clusterhosts = cluster_api.list_clusterhosts(user=user)
clusterhost_mapping = {} clusterhost_mapping = {}
for clusterhost in clusterhosts: for clusterhost in clusterhosts:
if 'clusterhost_id' not in clusterhost: if 'clusterhost_id' not in clusterhost:
@ -194,7 +197,7 @@ def update_progress():
package_installer = adapter['package_installer'] package_installer = adapter['package_installer']
clusterhost['package_installer'] = package_installer clusterhost['package_installer'] = package_installer
clusterhost_state = cluster_api.get_clusterhost_self_state( clusterhost_state = cluster_api.get_clusterhost_self_state(
user, clusterhost_id clusterhost_id, user=user
) )
if 'state' not in clusterhost_state: if 'state' not in clusterhost_state:
logging.error( logging.error(
@ -205,7 +208,7 @@ def update_progress():
if clusterhost_state['state'] == 'INSTALLING': if clusterhost_state['state'] == 'INSTALLING':
clusterhost_log_histories = ( clusterhost_log_histories = (
cluster_api.get_clusterhost_log_histories( cluster_api.get_clusterhost_log_histories(
user, clusterhost_id clusterhost_id, user=user
) )
) )
clusterhost_log_history_mapping = {} clusterhost_log_history_mapping = {}
@ -236,7 +239,7 @@ def update_progress():
host_mapping.items() host_mapping.items()
): ):
host_api.update_host_state( host_api.update_host_state(
user, host_id, host_id, user=user,
percentage=host_state.get('percentage', 0), percentage=host_state.get('percentage', 0),
message=host_state.get('message', ''), message=host_state.get('message', ''),
severity=host_state.get('severity', 'INFO') severity=host_state.get('severity', 'INFO')
@ -245,7 +248,7 @@ def update_progress():
host_log_history_mapping.items() host_log_history_mapping.items()
): ):
host_api.add_host_log_history( host_api.add_host_log_history(
user, host_id, filename=filename, host_id, filename=filename, user=user,
position=host_log_history.get('position', 0), position=host_log_history.get('position', 0),
percentage=host_log_history.get('percentage', 0), percentage=host_log_history.get('percentage', 0),
partial_line=host_log_history.get('partial_line', ''), partial_line=host_log_history.get('partial_line', ''),
@ -264,7 +267,7 @@ def update_progress():
clusterhost_mapping.items() clusterhost_mapping.items()
): ):
cluster_api.update_clusterhost_state( cluster_api.update_clusterhost_state(
user, clusterhost_id, clusterhost_id, user=user,
percentage=clusterhost_state.get('percentage', 0), percentage=clusterhost_state.get('percentage', 0),
message=clusterhost_state.get('message', ''), message=clusterhost_state.get('message', ''),
severity=clusterhost_state.get('severity', 'INFO') severity=clusterhost_state.get('severity', 'INFO')
@ -273,7 +276,7 @@ def update_progress():
clusterhost_log_history_mapping.items() clusterhost_log_history_mapping.items()
): ):
cluster_api.add_clusterhost_log_history( cluster_api.add_clusterhost_log_history(
user, clusterhost_id, filename=filename, clusterhost_id, user=user, filename=filename,
position=clusterhost_log_history.get('position', 0), position=clusterhost_log_history.get('position', 0),
percentage=clusterhost_log_history.get('percentage', 0), percentage=clusterhost_log_history.get('percentage', 0),
partial_line=clusterhost_log_history.get( partial_line=clusterhost_log_history.get(
@ -290,5 +293,5 @@ def update_progress():
cluster_mapping) cluster_mapping)
for cluster_id, (cluster, cluster_state) in cluster_mapping.items(): for cluster_id, (cluster, cluster_state) in cluster_mapping.items():
cluster_api.update_cluster_state( cluster_api.update_cluster_state(
user, cluster_id cluster_id, user=user
) )

View File

@ -94,8 +94,8 @@ class ActionHelper(object):
} }
To view a complete output, please refer to backend doc. To view a complete output, please refer to backend doc.
""" """
adapter_info = adapter_db.get_adapter(user, adapter_id) adapter_info = adapter_db.get_adapter(adapter_id, user=user)
metadata = cluster_db.get_cluster_metadata(user, cluster_id) metadata = cluster_db.get_cluster_metadata(cluster_id, user=user)
adapter_info.update({const.METADATA: metadata}) adapter_info.update({const.METADATA: metadata})
for flavor_info in adapter_info[const.FLAVORS]: for flavor_info in adapter_info[const.FLAVORS]:
@ -128,7 +128,7 @@ class ActionHelper(object):
"owner": "xxx" "owner": "xxx"
} }
""" """
cluster_info = cluster_db.get_cluster(user, cluster_id) cluster_info = cluster_db.get_cluster(cluster_id, user=user)
# convert roles retrieved from db into a list of role names # convert roles retrieved from db into a list of role names
roles_info = cluster_info.setdefault( roles_info = cluster_info.setdefault(
@ -137,11 +137,11 @@ class ActionHelper(object):
ActionHelper._get_role_names(roles_info) ActionHelper._get_role_names(roles_info)
# get cluster config info # get cluster config info
cluster_config = cluster_db.get_cluster_config(user, cluster_id) cluster_config = cluster_db.get_cluster_config(cluster_id, user=user)
cluster_info.update(cluster_config) cluster_info.update(cluster_config)
deploy_config = cluster_db.get_cluster_deployed_config(user, deploy_config = cluster_db.get_cluster_deployed_config(cluster_id,
cluster_id) user=user)
cluster_info.update(deploy_config) cluster_info.update(deploy_config)
return cluster_info return cluster_info
@ -179,7 +179,7 @@ class ActionHelper(object):
""" """
hosts_info = {} hosts_info = {}
for host_id in hosts_id_list: for host_id in hosts_id_list:
info = cluster_db.get_cluster_host(user, cluster_id, host_id) info = cluster_db.get_cluster_host(cluster_id, host_id, user=user)
logging.debug("checking on info %r %r" % (host_id, info)) logging.debug("checking on info %r %r" % (host_id, info))
info[const.ROLES] = ActionHelper._get_role_names(info[const.ROLES]) info[const.ROLES] = ActionHelper._get_role_names(info[const.ROLES])
@ -187,9 +187,9 @@ class ActionHelper(object):
# TODO(grace): Is following line necessary?? # TODO(grace): Is following line necessary??
info.setdefault(const.ROLES, []) info.setdefault(const.ROLES, [])
config = cluster_db.get_cluster_host_config(user, config = cluster_db.get_cluster_host_config(cluster_id,
cluster_id, host_id,
host_id) user=user)
info.update(config) info.update(config)
networks = info[const.NETWORKS] networks = info[const.NETWORKS]
@ -220,26 +220,34 @@ class ActionHelper(object):
cluster_id = cluster_config[const.ID] cluster_id = cluster_config[const.ID]
del cluster_config[const.ID] del cluster_config[const.ID]
cluster_db.update_cluster_deployed_config(user, cluster_id, cluster_db.update_cluster_deployed_config(cluster_id, user=user,
**cluster_config) **cluster_config)
hosts_id_list = deployed_config[const.HOSTS].keys() hosts_id_list = deployed_config[const.HOSTS].keys()
for host_id in hosts_id_list: for host_id in hosts_id_list:
config = deployed_config[const.HOSTS][host_id] config = deployed_config[const.HOSTS][host_id]
cluster_db.update_cluster_host_deployed_config(user, cluster_db.update_cluster_host_deployed_config(cluster_id,
cluster_id,
host_id, host_id,
user=user,
**config) **config)
@staticmethod @staticmethod
def update_state(cluster_id, host_id_list, user): def update_state(cluster_id, host_id_list, user):
# update all clusterhosts state # update all clusterhosts state
for host_id in host_id_list: for host_id in host_id_list:
cluster_db.update_cluster_host_state(user, cluster_id, host_id, cluster_db.update_cluster_host_state(
state='INSTALLING') cluster_id,
host_id,
user=user,
state='INSTALLING'
)
# update cluster state # update cluster state
cluster_db.update_cluster_state(user, cluster_id, state='INSTALLING') cluster_db.update_cluster_state(
cluster_id,
user=user,
state='INSTALLING'
)
@staticmethod @staticmethod
def delete_cluster( def delete_cluster(
@ -251,7 +259,7 @@ class ActionHelper(object):
user, host_id, True, True user, host_id, True, True
) )
cluster_db.del_cluster( cluster_db.del_cluster(
user, cluster_id, True, True cluster_id, True, True, user=user
) )
@staticmethod @staticmethod
@ -260,16 +268,16 @@ class ActionHelper(object):
): ):
if delete_underlying_host: if delete_underlying_host:
host_db.del_host( host_db.del_host(
user, host_id, True, True host_id, True, True, user=user
) )
cluster_db.del_cluster_host( cluster_db.del_cluster_host(
user, cluster_id, host_id, True, True cluster_id, host_id, True, True, user=user
) )
@staticmethod @staticmethod
def delete_host(host_id, user): def delete_host(host_id, user):
host_db.del_host( host_db.del_host(
user, host_id, True, True host_id, True, True, user=user
) )
@staticmethod @staticmethod

File diff suppressed because it is too large Load Diff

View File

@ -95,7 +95,7 @@ def _filter_adapters(adapter_config, filter_name, filter_value):
roles=RESP_ROLES_FIELDS, roles=RESP_ROLES_FIELDS,
flavors=RESP_FLAVORS_FIELDS flavors=RESP_FLAVORS_FIELDS
) )
def list_adapters(lister, session=None, **filters): def list_adapters(user=None, session=None, **filters):
"""list adapters.""" """list adapters."""
if not ADAPTER_MAPPING: if not ADAPTER_MAPPING:
load_adapters_internal(session) load_adapters_internal(session)
@ -125,6 +125,6 @@ def get_adapter_internal(session, adapter_id):
roles=RESP_ROLES_FIELDS, roles=RESP_ROLES_FIELDS,
flavors=RESP_FLAVORS_FIELDS flavors=RESP_FLAVORS_FIELDS
) )
def get_adapter(getter, adapter_id, session=None, **kwargs): def get_adapter(adapter_id, user=None, session=None, **kwargs):
"""get adapter.""" """get adapter."""
return get_adapter_internal(session, adapter_id) return get_adapter_internal(session, adapter_id)

View File

@ -154,7 +154,7 @@ UPDATED_CLUSTERHOST_LOG_FIELDS = [
permission.PERMISSION_LIST_CLUSTERS permission.PERMISSION_LIST_CLUSTERS
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_clusters(lister, session=None, **filters): def list_clusters(user=None, session=None, **filters):
"""List clusters.""" """List clusters."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.Cluster, **filters session, models.Cluster, **filters
@ -168,8 +168,8 @@ def list_clusters(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_cluster( def get_cluster(
getter, cluster_id, cluster_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Get cluster info.""" """Get cluster info."""
return utils.get_db_object( return utils.get_db_object(
@ -243,14 +243,13 @@ def is_cluster_editable(
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def add_cluster( def add_cluster(
creator,
exception_when_existing=True, exception_when_existing=True,
name=None, session=None, **kwargs name=None, user=None, session=None, **kwargs
): ):
"""Create a cluster.""" """Create a cluster."""
return utils.add_db_object( return utils.add_db_object(
session, models.Cluster, exception_when_existing, session, models.Cluster, exception_when_existing,
name, creator_id=creator.id, name, creator_id=user.id,
**kwargs **kwargs
) )
@ -265,13 +264,13 @@ def add_cluster(
permission.PERMISSION_ADD_CLUSTER permission.PERMISSION_ADD_CLUSTER
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def update_cluster(updater, cluster_id, session=None, **kwargs): def update_cluster(cluster_id, user=None, session=None, **kwargs):
"""Update a cluster.""" """Update a cluster."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
) )
is_cluster_editable( is_cluster_editable(
session, cluster, updater, session, cluster, user,
reinstall_distributed_system_set=( reinstall_distributed_system_set=(
kwargs.get('reinstall_distributed_system', False) kwargs.get('reinstall_distributed_system', False)
) )
@ -301,9 +300,8 @@ def update_cluster(updater, cluster_id, session=None, **kwargs):
hosts=RESP_CLUSTERHOST_FIELDS hosts=RESP_CLUSTERHOST_FIELDS
) )
def del_cluster( def del_cluster(
deleter, cluster_id, cluster_id, force=False, from_database_only=False,
force=False, from_database_only=False, delete_underlying_host=False, user=None, session=None, **kwargs
delete_underlying_host=False, session=None, **kwargs
): ):
"""Delete a cluster.""" """Delete a cluster."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
@ -325,7 +323,7 @@ def del_cluster(
cluster.state.state = 'ERROR' cluster.state.state = 'ERROR'
is_cluster_editable( is_cluster_editable(
session, cluster, deleter, session, cluster, user,
reinstall_distributed_system_set=True reinstall_distributed_system_set=True
) )
@ -333,7 +331,7 @@ def del_cluster(
from compass.db.api import host as host_api from compass.db.api import host as host_api
host = clusterhost.host host = clusterhost.host
host_api.is_host_editable( host_api.is_host_editable(
session, host, deleter, reinstall_os_set=True session, host, user, reinstall_os_set=True
) )
if host.state.state == 'UNINITIALIZED' or from_database_only: if host.state.state == 'UNINITIALIZED' or from_database_only:
utils.del_db_object( utils.del_db_object(
@ -353,7 +351,7 @@ def del_cluster(
celery_client.celery.send_task( celery_client.celery.send_task(
'compass.tasks.delete_cluster', 'compass.tasks.delete_cluster',
( (
deleter.email, cluster_id, user.email, cluster_id,
[clusterhost.host_id for clusterhost in clusterhosts], [clusterhost.host_id for clusterhost in clusterhosts],
delete_underlying_host delete_underlying_host
) )
@ -371,7 +369,7 @@ def del_cluster(
permission.PERMISSION_LIST_CLUSTER_CONFIG permission.PERMISSION_LIST_CLUSTER_CONFIG
) )
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def get_cluster_config(getter, cluster_id, session=None, **kwargs): def get_cluster_config(cluster_id, user=None, session=None, **kwargs):
"""Get cluster config.""" """Get cluster config."""
return utils.get_db_object( return utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
@ -384,7 +382,7 @@ def get_cluster_config(getter, cluster_id, session=None, **kwargs):
permission.PERMISSION_LIST_CLUSTER_CONFIG permission.PERMISSION_LIST_CLUSTER_CONFIG
) )
@utils.wrap_to_dict(RESP_DEPLOYED_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_DEPLOYED_CONFIG_FIELDS)
def get_cluster_deployed_config(getter, cluster_id, session=None, **kwargs): def get_cluster_deployed_config(cluster_id, user=None, session=None, **kwargs):
"""Get cluster deployed config.""" """Get cluster deployed config."""
return utils.get_db_object( return utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
@ -397,7 +395,7 @@ def get_cluster_deployed_config(getter, cluster_id, session=None, **kwargs):
permission.PERMISSION_LIST_METADATAS permission.PERMISSION_LIST_METADATAS
) )
@utils.wrap_to_dict(RESP_METADATA_FIELDS) @utils.wrap_to_dict(RESP_METADATA_FIELDS)
def get_cluster_metadata(getter, cluster_id, session=None, **kwargs): def get_cluster_metadata(cluster_id, user=None, session=None, **kwargs):
"""Get cluster metadata.""" """Get cluster metadata."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
@ -419,9 +417,9 @@ def get_cluster_metadata(getter, cluster_id, session=None, **kwargs):
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def _update_cluster_config(session, updater, cluster, **kwargs): def _update_cluster_config(session, user, cluster, **kwargs):
"""Update a cluster config.""" """Update a cluster config."""
is_cluster_editable(session, cluster, updater) is_cluster_editable(session, cluster, user)
return utils.update_db_object( return utils.update_db_object(
session, cluster, **kwargs session, cluster, **kwargs
) )
@ -441,13 +439,13 @@ def _update_cluster_config(session, updater, cluster, **kwargs):
) )
@utils.wrap_to_dict(RESP_DEPLOYED_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_DEPLOYED_CONFIG_FIELDS)
def update_cluster_deployed_config( def update_cluster_deployed_config(
updater, cluster_id, session=None, **kwargs cluster_id, user=None, session=None, **kwargs
): ):
"""Update cluster deployed config.""" """Update cluster deployed config."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
) )
is_cluster_editable(session, cluster, updater) is_cluster_editable(session, cluster, user)
is_cluster_validated(session, cluster) is_cluster_validated(session, cluster)
return utils.update_db_object( return utils.update_db_object(
session, cluster, **kwargs session, cluster, **kwargs
@ -466,7 +464,7 @@ def update_cluster_deployed_config(
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_CLUSTER_CONFIG permission.PERMISSION_ADD_CLUSTER_CONFIG
) )
def update_cluster_config(updater, cluster_id, session=None, **kwargs): def update_cluster_config(cluster_id, user=None, session=None, **kwargs):
"""Update cluster config.""" """Update cluster config."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
@ -490,7 +488,7 @@ def update_cluster_config(updater, cluster_id, session=None, **kwargs):
cluster, **in_kwargs cluster, **in_kwargs
): ):
return _update_cluster_config( return _update_cluster_config(
session, updater, cluster, **in_kwargs session, user, cluster, **in_kwargs
) )
return update_config_internal( return update_config_internal(
@ -510,7 +508,7 @@ def update_cluster_config(updater, cluster_id, session=None, **kwargs):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_CLUSTER_CONFIG permission.PERMISSION_ADD_CLUSTER_CONFIG
) )
def patch_cluster_config(updater, cluster_id, session=None, **kwargs): def patch_cluster_config(cluster_id, user=None, session=None, **kwargs):
"""patch cluster config.""" """patch cluster config."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
@ -532,7 +530,7 @@ def patch_cluster_config(updater, cluster_id, session=None, **kwargs):
) )
def update_config_internal(cluster, **in_kwargs): def update_config_internal(cluster, **in_kwargs):
return _update_cluster_config( return _update_cluster_config(
session, updater, cluster, **in_kwargs session, user, cluster, **in_kwargs
) )
return update_config_internal( return update_config_internal(
@ -546,12 +544,12 @@ def patch_cluster_config(updater, cluster_id, session=None, **kwargs):
permission.PERMISSION_DEL_CLUSTER_CONFIG permission.PERMISSION_DEL_CLUSTER_CONFIG
) )
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def del_cluster_config(deleter, cluster_id, session=None): def del_cluster_config(cluster_id, user=None, session=None):
"""Delete a cluster config.""" """Delete a cluster config."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
) )
is_cluster_editable(session, cluster, deleter) is_cluster_editable(session, cluster, user)
return utils.update_db_object( return utils.update_db_object(
session, cluster, os_config={}, session, cluster, os_config={},
package_config={}, config_validated=False package_config={}, config_validated=False
@ -672,7 +670,7 @@ def _set_clusterhosts(session, cluster, machines):
permission.PERMISSION_LIST_CLUSTERHOSTS permission.PERMISSION_LIST_CLUSTERHOSTS
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS)
def list_cluster_hosts(lister, cluster_id, session=None, **filters): def list_cluster_hosts(cluster_id, user=None, session=None, **filters):
"""Get cluster host info.""" """Get cluster host info."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.ClusterHost, cluster_id=cluster_id, session, models.ClusterHost, cluster_id=cluster_id,
@ -686,7 +684,7 @@ def list_cluster_hosts(lister, cluster_id, session=None, **filters):
permission.PERMISSION_LIST_CLUSTERHOSTS permission.PERMISSION_LIST_CLUSTERHOSTS
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS)
def list_clusterhosts(lister, session=None, **filters): def list_clusterhosts(user=None, session=None, **filters):
"""Get cluster host info.""" """Get cluster host info."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.ClusterHost, **filters session, models.ClusterHost, **filters
@ -700,8 +698,8 @@ def list_clusterhosts(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS)
def get_cluster_host( def get_cluster_host(
getter, cluster_id, host_id, cluster_id, host_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Get clusterhost info.""" """Get clusterhost info."""
return utils.get_db_object( return utils.get_db_object(
@ -718,8 +716,8 @@ def get_cluster_host(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS)
def get_clusterhost( def get_clusterhost(
getter, clusterhost_id, clusterhost_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Get clusterhost info.""" """Get clusterhost info."""
return utils.get_db_object( return utils.get_db_object(
@ -735,14 +733,14 @@ def get_clusterhost(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS)
def add_cluster_host( def add_cluster_host(
creator, cluster_id, cluster_id, exception_when_existing=True,
exception_when_existing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Add cluster host.""" """Add cluster host."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
) )
is_cluster_editable(session, cluster, creator) is_cluster_editable(session, cluster, user)
return add_clusterhost_internal( return add_clusterhost_internal(
session, cluster, exception_when_existing, session, cluster, exception_when_existing,
**kwargs **kwargs
@ -750,7 +748,7 @@ def add_cluster_host(
@utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_FIELDS)
def _update_clusterhost(session, updater, clusterhost, **kwargs): def _update_clusterhost(session, user, clusterhost, **kwargs):
clusterhost_dict = {} clusterhost_dict = {}
host_dict = {} host_dict = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -824,7 +822,7 @@ def _update_clusterhost(session, updater, clusterhost, **kwargs):
session, clusterhost, **in_kwargs session, clusterhost, **in_kwargs
) )
is_cluster_editable(session, clusterhost.cluster, updater) is_cluster_editable(session, clusterhost.cluster, user)
return update_internal( return update_internal(
clusterhost, **kwargs clusterhost, **kwargs
) )
@ -839,14 +837,14 @@ def _update_clusterhost(session, updater, clusterhost, **kwargs):
permission.PERMISSION_UPDATE_CLUSTER_HOSTS permission.PERMISSION_UPDATE_CLUSTER_HOSTS
) )
def update_cluster_host( def update_cluster_host(
updater, cluster_id, host_id, cluster_id, host_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Update cluster host.""" """Update cluster host."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, cluster_id=cluster_id, host_id=host_id session, models.ClusterHost, cluster_id=cluster_id, host_id=host_id
) )
return _update_clusterhost(session, updater, clusterhost, **kwargs) return _update_clusterhost(session, user, clusterhost, **kwargs)
@utils.supported_filters( @utils.supported_filters(
@ -858,14 +856,14 @@ def update_cluster_host(
permission.PERMISSION_UPDATE_CLUSTER_HOSTS permission.PERMISSION_UPDATE_CLUSTER_HOSTS
) )
def update_clusterhost( def update_clusterhost(
updater, clusterhost_id, clusterhost_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Update cluster host.""" """Update cluster host."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
) )
return _update_clusterhost(session, updater, clusterhost, **kwargs) return _update_clusterhost(session, user, clusterhost, **kwargs)
@utils.replace_filters( @utils.replace_filters(
@ -880,14 +878,14 @@ def update_clusterhost(
permission.PERMISSION_UPDATE_CLUSTER_HOSTS permission.PERMISSION_UPDATE_CLUSTER_HOSTS
) )
def patch_cluster_host( def patch_cluster_host(
updater, cluster_id, host_id, session=None, cluster_id, host_id, user=None,
**kwargs session=None, **kwargs
): ):
"""Update cluster host.""" """Update cluster host."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, cluster_id=cluster_id, host_id=host_id session, models.ClusterHost, cluster_id=cluster_id, host_id=host_id
) )
return _update_clusterhost(session, updater, clusterhost, **kwargs) return _update_clusterhost(session, user, clusterhost, **kwargs)
@utils.replace_filters( @utils.replace_filters(
@ -902,14 +900,14 @@ def patch_cluster_host(
permission.PERMISSION_UPDATE_CLUSTER_HOSTS permission.PERMISSION_UPDATE_CLUSTER_HOSTS
) )
def patch_clusterhost( def patch_clusterhost(
updater, clusterhost_id, session=None, clusterhost_id, user=None, session=None,
**kwargs **kwargs
): ):
"""Update cluster host.""" """Update cluster host."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
) )
return _update_clusterhost(session, updater, clusterhost, **kwargs) return _update_clusterhost(session, user, clusterhost, **kwargs)
@utils.supported_filters([]) @utils.supported_filters([])
@ -922,9 +920,9 @@ def patch_clusterhost(
host=RESP_CLUSTERHOST_FIELDS host=RESP_CLUSTERHOST_FIELDS
) )
def del_cluster_host( def del_cluster_host(
deleter, cluster_id, host_id, cluster_id, host_id,
force=False, from_database_only=False, force=False, from_database_only=False,
delete_underlying_host=False, delete_underlying_host=False, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Delete cluster host.""" """Delete cluster host."""
@ -936,7 +934,7 @@ def del_cluster_host(
clusterhost.state.state = 'ERROR' clusterhost.state.state = 'ERROR'
if not force: if not force:
is_cluster_editable( is_cluster_editable(
session, clusterhost.cluster, deleter, session, clusterhost.cluster, user,
reinstall_distributed_system_set=True reinstall_distributed_system_set=True
) )
else: else:
@ -949,7 +947,7 @@ def del_cluster_host(
host.state.state = 'ERROR' host.state.state = 'ERROR'
import compass.db.api.host as host_api import compass.db.api.host as host_api
host_api.is_host_editable( host_api.is_host_editable(
session, host, deleter, session, host, user,
reinstall_os_set=True reinstall_os_set=True
) )
if host.state.state == 'UNINITIALIZED' or from_database_only: if host.state.state == 'UNINITIALIZED' or from_database_only:
@ -970,7 +968,7 @@ def del_cluster_host(
celery_client.celery.send_task( celery_client.celery.send_task(
'compass.tasks.delete_cluster_host', 'compass.tasks.delete_cluster_host',
( (
deleter.email, cluster_id, host_id, user.email, cluster_id, host_id,
delete_underlying_host delete_underlying_host
) )
) )
@ -990,9 +988,9 @@ def del_cluster_host(
host=RESP_CLUSTERHOST_FIELDS host=RESP_CLUSTERHOST_FIELDS
) )
def del_clusterhost( def del_clusterhost(
deleter, clusterhost_id, clusterhost_id,
force=False, from_database_only=False, force=False, from_database_only=False,
delete_underlying_host=False, delete_underlying_host=False, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Delete cluster host.""" """Delete cluster host."""
@ -1004,7 +1002,7 @@ def del_clusterhost(
clusterhost.state.state = 'ERROR' clusterhost.state.state = 'ERROR'
if not force: if not force:
is_cluster_editable( is_cluster_editable(
session, clusterhost.cluster, deleter, session, clusterhost.cluster, user,
reinstall_distributed_system_set=True reinstall_distributed_system_set=True
) )
if delete_underlying_host: if delete_underlying_host:
@ -1013,7 +1011,7 @@ def del_clusterhost(
host.state.state = 'ERROR' host.state.state = 'ERROR'
import compass.db.api.host as host_api import compass.db.api.host as host_api
host_api.is_host_editable( host_api.is_host_editable(
session, host, deleter, session, host, user,
reinstall_os_set=True reinstall_os_set=True
) )
if host.state.state == 'UNINITIALIZED' or from_database_only: if host.state.state == 'UNINITIALIZED' or from_database_only:
@ -1034,7 +1032,7 @@ def del_clusterhost(
celery_client.celery.send_task( celery_client.celery.send_task(
'compass.tasks.delete_cluster_host', 'compass.tasks.delete_cluster_host',
( (
deleter.email, clusterhost.cluster_id, user.email, clusterhost.cluster_id,
clusterhost.host_id, clusterhost.host_id,
delete_underlying_host delete_underlying_host
) )
@ -1052,8 +1050,8 @@ def del_clusterhost(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS)
def get_cluster_host_config( def get_cluster_host_config(
getter, cluster_id, cluster_id, host_id, user=None,
host_id, session=None, **kwargs session=None, **kwargs
): ):
"""Get clusterhost config.""" """Get clusterhost config."""
return utils.get_db_object( return utils.get_db_object(
@ -1069,7 +1067,7 @@ def get_cluster_host_config(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_DEPLOYED_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_DEPLOYED_CONFIG_FIELDS)
def get_cluster_host_deployed_config( def get_cluster_host_deployed_config(
getter, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Get clusterhost deployed config.""" """Get clusterhost deployed config."""
return utils.get_db_object( return utils.get_db_object(
@ -1084,7 +1082,7 @@ def get_cluster_host_deployed_config(
permission.PERMISSION_LIST_CLUSTERHOST_CONFIG permission.PERMISSION_LIST_CLUSTERHOST_CONFIG
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS)
def get_clusterhost_config(getter, clusterhost_id, session=None, **kwargs): def get_clusterhost_config(clusterhost_id, user=None, session=None, **kwargs):
"""Get clusterhost config.""" """Get clusterhost config."""
return utils.get_db_object( return utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
@ -1098,7 +1096,7 @@ def get_clusterhost_config(getter, clusterhost_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_DEPLOYED_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_DEPLOYED_CONFIG_FIELDS)
def get_clusterhost_deployed_config( def get_clusterhost_deployed_config(
getter, clusterhost_id, clusterhost_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Get clusterhost deployed config.""" """Get clusterhost deployed config."""
@ -1108,11 +1106,11 @@ def get_clusterhost_deployed_config(
@utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS)
def _update_clusterhost_config(session, updater, clusterhost, **kwargs): def _update_clusterhost_config(session, user, clusterhost, **kwargs):
from compass.db.api import host as host_api from compass.db.api import host as host_api
ignore_keys = [] ignore_keys = []
if not host_api.is_host_editable( if not host_api.is_host_editable(
session, clusterhost.host, updater, session, clusterhost.host, user,
exception_when_not_editable=False exception_when_not_editable=False
): ):
ignore_keys.append('put_os_config') ignore_keys.append('put_os_config')
@ -1124,7 +1122,7 @@ def _update_clusterhost_config(session, updater, clusterhost, **kwargs):
def package_config_validates(package_config): def package_config_validates(package_config):
cluster = clusterhost.cluster cluster = clusterhost.cluster
is_cluster_editable(session, cluster, updater) is_cluster_editable(session, cluster, user)
metadata_api.validate_package_config( metadata_api.validate_package_config(
session, package_config, cluster.adapter_id session, package_config, cluster.adapter_id
) )
@ -1149,12 +1147,12 @@ def _update_clusterhost_config(session, updater, clusterhost, **kwargs):
@utils.wrap_to_dict(RESP_CLUSTERHOST_DEPLOYED_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_DEPLOYED_CONFIG_FIELDS)
def _update_clusterhost_deployed_config( def _update_clusterhost_deployed_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
): ):
from compass.db.api import host as host_api from compass.db.api import host as host_api
ignore_keys = [] ignore_keys = []
if not host_api.is_host_editable( if not host_api.is_host_editable(
session, clusterhost.host, updater, session, clusterhost.host, user,
exception_when_not_editable=False exception_when_not_editable=False
): ):
ignore_keys.append('deployed_os_config') ignore_keys.append('deployed_os_config')
@ -1165,7 +1163,7 @@ def _update_clusterhost_deployed_config(
def package_config_validates(package_config): def package_config_validates(package_config):
cluster = clusterhost.cluster cluster = clusterhost.cluster
is_cluster_editable(session, cluster, updater) is_cluster_editable(session, cluster, user)
is_clusterhost_validated(session, clusterhost) is_clusterhost_validated(session, clusterhost)
@utils.supported_filters( @utils.supported_filters(
@ -1195,7 +1193,7 @@ def _update_clusterhost_deployed_config(
permission.PERMISSION_ADD_CLUSTERHOST_CONFIG permission.PERMISSION_ADD_CLUSTERHOST_CONFIG
) )
def update_cluster_host_config( def update_cluster_host_config(
updater, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Update clusterhost config.""" """Update clusterhost config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1203,7 +1201,7 @@ def update_cluster_host_config(
cluster_id=cluster_id, host_id=host_id cluster_id=cluster_id, host_id=host_id
) )
return _update_clusterhost_config( return _update_clusterhost_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
) )
@ -1216,7 +1214,7 @@ def update_cluster_host_config(
permission.PERMISSION_ADD_CLUSTERHOST_CONFIG permission.PERMISSION_ADD_CLUSTERHOST_CONFIG
) )
def update_cluster_host_deployed_config( def update_cluster_host_deployed_config(
updater, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Update clusterhost deployed config.""" """Update clusterhost deployed config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1224,7 +1222,7 @@ def update_cluster_host_deployed_config(
cluster_id=cluster_id, host_id=host_id cluster_id=cluster_id, host_id=host_id
) )
return _update_clusterhost_deployed_config( return _update_clusterhost_deployed_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
) )
@ -1237,14 +1235,14 @@ def update_cluster_host_deployed_config(
permission.PERMISSION_ADD_CLUSTERHOST_CONFIG permission.PERMISSION_ADD_CLUSTERHOST_CONFIG
) )
def update_clusterhost_config( def update_clusterhost_config(
updater, clusterhost_id, session=None, **kwargs clusterhost_id, user=None, session=None, **kwargs
): ):
"""Update clusterhost config.""" """Update clusterhost config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
) )
return _update_clusterhost_config( return _update_clusterhost_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
) )
@ -1257,23 +1255,23 @@ def update_clusterhost_config(
permission.PERMISSION_ADD_CLUSTERHOST_CONFIG permission.PERMISSION_ADD_CLUSTERHOST_CONFIG
) )
def update_clusterhost_deployed_config( def update_clusterhost_deployed_config(
updater, clusterhost_id, session=None, **kwargs clusterhost_id, user=None, session=None, **kwargs
): ):
"""Update clusterhost deployed config.""" """Update clusterhost deployed config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
) )
return _update_clusterhost_deployed_config( return _update_clusterhost_deployed_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS)
def _patch_clusterhost_config(session, updater, clusterhost, **kwargs): def _patch_clusterhost_config(session, user, clusterhost, **kwargs):
from compass.db.api import host as host_api from compass.db.api import host as host_api
ignore_keys = [] ignore_keys = []
if not host_api.is_host_editable( if not host_api.is_host_editable(
session, clusterhost.host, updater, session, clusterhost.host, user,
exception_when_not_editable=False exception_when_not_editable=False
): ):
ignore_keys.append('patched_os_config') ignore_keys.append('patched_os_config')
@ -1284,7 +1282,7 @@ def _patch_clusterhost_config(session, updater, clusterhost, **kwargs):
def package_config_validates(package_config): def package_config_validates(package_config):
cluster = clusterhost.cluster cluster = clusterhost.cluster
is_cluster_editable(session, cluster, updater) is_cluster_editable(session, cluster, user)
metadata_api.validate_package_config( metadata_api.validate_package_config(
session, package_config, cluster.adapter_id session, package_config, cluster.adapter_id
) )
@ -1316,7 +1314,7 @@ def _patch_clusterhost_config(session, updater, clusterhost, **kwargs):
permission.PERMISSION_ADD_CLUSTERHOST_CONFIG permission.PERMISSION_ADD_CLUSTERHOST_CONFIG
) )
def patch_cluster_host_config( def patch_cluster_host_config(
updater, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""patch clusterhost config.""" """patch clusterhost config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1324,7 +1322,7 @@ def patch_cluster_host_config(
cluster_id=cluster_id, host_id=host_id cluster_id=cluster_id, host_id=host_id
) )
return _patch_clusterhost_config( return _patch_clusterhost_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
) )
@ -1337,31 +1335,31 @@ def patch_cluster_host_config(
permission.PERMISSION_ADD_CLUSTERHOST_CONFIG permission.PERMISSION_ADD_CLUSTERHOST_CONFIG
) )
def patch_clusterhost_config( def patch_clusterhost_config(
updater, clusterhost_id, session=None, **kwargs clusterhost_id, user=None, session=None, **kwargs
): ):
"""patch clusterhost config.""" """patch clusterhost config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
) )
return _patch_clusterhost_config( return _patch_clusterhost_config(
session, updater, clusterhost, **kwargs session, user, clusterhost, **kwargs
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS)
def _delete_clusterhost_config( def _delete_clusterhost_config(
session, deleter, clusterhost session, user, clusterhost
): ):
from compass.db.api import host as host_api from compass.db.api import host as host_api
ignore_keys = [] ignore_keys = []
if not host_api.is_host_editable( if not host_api.is_host_editable(
session, clusterhost.host, deleter, session, clusterhost.host, user,
exception_when_not_editable=False exception_when_not_editable=False
): ):
ignore_keys.append('os_config') ignore_keys.append('os_config')
def package_config_validates(package_config): def package_config_validates(package_config):
is_cluster_editable(session, clusterhost.cluster, deleter) is_cluster_editable(session, clusterhost.cluster, user)
@utils.supported_filters( @utils.supported_filters(
optional_support_keys=['os_config', 'package_config'], optional_support_keys=['os_config', 'package_config'],
@ -1388,7 +1386,7 @@ def _delete_clusterhost_config(
permission.PERMISSION_DEL_CLUSTERHOST_CONFIG permission.PERMISSION_DEL_CLUSTERHOST_CONFIG
) )
def delete_cluster_host_config( def delete_cluster_host_config(
deleter, cluster_id, host_id, session=None cluster_id, host_id, user=None, session=None
): ):
"""Delete a clusterhost config.""" """Delete a clusterhost config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1396,7 +1394,7 @@ def delete_cluster_host_config(
cluster_id=cluster_id, host_id=host_id cluster_id=cluster_id, host_id=host_id
) )
return _delete_clusterhost_config( return _delete_clusterhost_config(
session, deleter, clusterhost session, user, clusterhost
) )
@ -1406,13 +1404,13 @@ def delete_cluster_host_config(
permission.PERMISSION_DEL_CLUSTERHOST_CONFIG permission.PERMISSION_DEL_CLUSTERHOST_CONFIG
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_CONFIG_FIELDS)
def delete_clusterhost_config(deleter, clusterhost_id, session=None): def delete_clusterhost_config(clusterhost_id, user=None, session=None):
"""Delet a clusterhost config.""" """Delet a clusterhost config."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
session, models.ClusterHost, clusterhost_id=clusterhost_id session, models.ClusterHost, clusterhost_id=clusterhost_id
) )
return _delete_clusterhost_config( return _delete_clusterhost_config(
session, deleter, clusterhost session, user, clusterhost
) )
@ -1428,14 +1426,14 @@ def delete_clusterhost_config(deleter, clusterhost_id, session=None):
hosts=RESP_CLUSTERHOST_FIELDS hosts=RESP_CLUSTERHOST_FIELDS
) )
def update_cluster_hosts( def update_cluster_hosts(
updater, cluster_id, add_hosts={}, set_hosts=None, cluster_id, add_hosts={}, set_hosts=None,
remove_hosts={}, session=None remove_hosts={}, user=None, session=None
): ):
"""Update cluster hosts.""" """Update cluster hosts."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
) )
is_cluster_editable(session, cluster, updater) is_cluster_editable(session, cluster, user)
if remove_hosts: if remove_hosts:
_remove_clusterhosts(session, cluster, **remove_hosts) _remove_clusterhosts(session, cluster, **remove_hosts)
if add_hosts: if add_hosts:
@ -1514,13 +1512,13 @@ def validate_cluster(session, cluster):
cluster=RESP_CONFIG_FIELDS, cluster=RESP_CONFIG_FIELDS,
hosts=RESP_CLUSTERHOST_CONFIG_FIELDS hosts=RESP_CLUSTERHOST_CONFIG_FIELDS
) )
def review_cluster(reviewer, cluster_id, review={}, session=None, **kwargs): def review_cluster(cluster_id, review={}, user=None, session=None, **kwargs):
"""review cluster.""" """review cluster."""
from compass.db.api import host as host_api from compass.db.api import host as host_api
cluster = utils.get_db_object( cluster = utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
) )
is_cluster_editable(session, cluster, reviewer) is_cluster_editable(session, cluster, user)
host_ids = review.get('hosts', []) host_ids = review.get('hosts', [])
clusterhost_ids = review.get('clusterhosts', []) clusterhost_ids = review.get('clusterhosts', [])
clusterhosts = [] clusterhosts = []
@ -1542,7 +1540,7 @@ def review_cluster(reviewer, cluster_id, review={}, session=None, **kwargs):
for clusterhost in clusterhosts: for clusterhost in clusterhosts:
host = clusterhost.host host = clusterhost.host
if not host_api.is_host_editable( if not host_api.is_host_editable(
session, host, reviewer, False session, host, user, False
): ):
logging.info( logging.info(
'ignore update host %s config ' 'ignore update host %s config '
@ -1616,7 +1614,7 @@ def review_cluster(reviewer, cluster_id, review={}, session=None, **kwargs):
hosts=RESP_CLUSTERHOST_FIELDS hosts=RESP_CLUSTERHOST_FIELDS
) )
def deploy_cluster( def deploy_cluster(
deployer, cluster_id, deploy={}, session=None, **kwargs cluster_id, deploy={}, user=None, session=None, **kwargs
): ):
"""deploy cluster.""" """deploy cluster."""
from compass.db.api import host as host_api from compass.db.api import host as host_api
@ -1633,13 +1631,13 @@ def deploy_cluster(
clusterhost.host_id in host_ids clusterhost.host_id in host_ids
): ):
clusterhosts.append(clusterhost) clusterhosts.append(clusterhost)
is_cluster_editable(session, cluster, deployer) is_cluster_editable(session, cluster, user)
is_cluster_validated(session, cluster) is_cluster_validated(session, cluster)
utils.update_db_object(session, cluster.state, state='INITIALIZED') utils.update_db_object(session, cluster.state, state='INITIALIZED')
for clusterhost in clusterhosts: for clusterhost in clusterhosts:
host = clusterhost.host host = clusterhost.host
if host_api.is_host_editable( if host_api.is_host_editable(
session, host, deployer, session, host, user,
exception_when_not_editable=False exception_when_not_editable=False
): ):
host_api.is_host_validated( host_api.is_host_validated(
@ -1655,7 +1653,7 @@ def deploy_cluster(
celery_client.celery.send_task( celery_client.celery.send_task(
'compass.tasks.deploy_cluster', 'compass.tasks.deploy_cluster',
( (
deployer.email, cluster_id, user.email, cluster_id,
[clusterhost.host_id for clusterhost in clusterhosts] [clusterhost.host_id for clusterhost in clusterhosts]
) )
) )
@ -1672,7 +1670,7 @@ def deploy_cluster(
permission.PERMISSION_GET_CLUSTER_STATE permission.PERMISSION_GET_CLUSTER_STATE
) )
@utils.wrap_to_dict(RESP_STATE_FIELDS) @utils.wrap_to_dict(RESP_STATE_FIELDS)
def get_cluster_state(getter, cluster_id, session=None, **kwargs): def get_cluster_state(cluster_id, user=None, session=None, **kwargs):
"""Get cluster state info.""" """Get cluster state info."""
return utils.get_db_object( return utils.get_db_object(
session, models.Cluster, id=cluster_id session, models.Cluster, id=cluster_id
@ -1686,7 +1684,7 @@ def get_cluster_state(getter, cluster_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS)
def get_cluster_host_state( def get_cluster_host_state(
getter, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Get clusterhost state info.""" """Get clusterhost state info."""
return utils.get_db_object( return utils.get_db_object(
@ -1702,7 +1700,7 @@ def get_cluster_host_state(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS)
def get_cluster_host_self_state( def get_cluster_host_self_state(
getter, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Get clusterhost state info.""" """Get clusterhost state info."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1722,7 +1720,7 @@ def get_cluster_host_self_state(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS)
def get_clusterhost_state( def get_clusterhost_state(
getter, clusterhost_id, session=None, **kwargs clusterhost_id, user=None, session=None, **kwargs
): ):
"""Get clusterhost state info.""" """Get clusterhost state info."""
return utils.get_db_object( return utils.get_db_object(
@ -1738,7 +1736,7 @@ def get_clusterhost_state(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS)
def get_clusterhost_self_state( def get_clusterhost_self_state(
getter, clusterhost_id, session=None, **kwargs clusterhost_id, user=None, session=None, **kwargs
): ):
"""Get clusterhost state info.""" """Get clusterhost state info."""
return utils.get_db_object( return utils.get_db_object(
@ -1757,7 +1755,7 @@ def get_clusterhost_self_state(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS)
def update_cluster_host_state( def update_cluster_host_state(
updater, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Update a clusterhost state.""" """Update a clusterhost state."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1778,7 +1776,7 @@ def update_cluster_host_state(
) )
@utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_STATE_FIELDS)
def update_clusterhost_state( def update_clusterhost_state(
updater, clusterhost_id, session=None, **kwargs clusterhost_id, user=None, session=None, **kwargs
): ):
"""Update a clusterhost state.""" """Update a clusterhost state."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(
@ -1799,7 +1797,7 @@ def update_clusterhost_state(
) )
@utils.wrap_to_dict(RESP_STATE_FIELDS) @utils.wrap_to_dict(RESP_STATE_FIELDS)
def update_cluster_state( def update_cluster_state(
updater, cluster_id, session=None, **kwargs cluster_id, user=None, session=None, **kwargs
): ):
"""Update a cluster state.""" """Update a cluster state."""
cluster = utils.get_db_object( cluster = utils.get_db_object(
@ -1813,7 +1811,7 @@ def update_cluster_state(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def get_cluster_host_log_histories( def get_cluster_host_log_histories(
getter, cluster_id, host_id, session=None, **kwargs cluster_id, host_id, user=None, session=None, **kwargs
): ):
"""Get clusterhost log history.""" """Get clusterhost log history."""
return utils.list_db_objects( return utils.list_db_objects(
@ -1826,7 +1824,7 @@ def get_cluster_host_log_histories(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def get_clusterhost_log_histories( def get_clusterhost_log_histories(
getter, clusterhost_id, clusterhost_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Get clusterhost log history.""" """Get clusterhost log history."""
@ -1839,7 +1837,7 @@ def get_clusterhost_log_histories(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def get_cluster_host_log_history( def get_cluster_host_log_history(
getter, cluster_id, host_id, filename, session=None, **kwargs cluster_id, host_id, filename, user=None, session=None, **kwargs
): ):
"""Get clusterhost log history.""" """Get clusterhost log history."""
return utils.get_db_object( return utils.get_db_object(
@ -1852,7 +1850,7 @@ def get_cluster_host_log_history(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def get_clusterhost_log_history( def get_clusterhost_log_history(
getter, clusterhost_id, filename, session=None, **kwargs clusterhost_id, filename, user=None, session=None, **kwargs
): ):
"""Get host log history.""" """Get host log history."""
return utils.get_db_object( return utils.get_db_object(
@ -1868,7 +1866,7 @@ def get_clusterhost_log_history(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def update_cluster_host_log_history( def update_cluster_host_log_history(
updater, cluster_id, host_id, filename, session=None, **kwargs cluster_id, host_id, filename, user=None, session=None, **kwargs
): ):
"""Update a host log history.""" """Update a host log history."""
cluster_host_log_history = utils.get_db_object( cluster_host_log_history = utils.get_db_object(
@ -1885,7 +1883,7 @@ def update_cluster_host_log_history(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def update_clusterhost_log_history( def update_clusterhost_log_history(
updater, clusterhost_id, filename, session=None, **kwargs clusterhost_id, filename, user=None, session=None, **kwargs
): ):
"""Update a host log history.""" """Update a host log history."""
clusterhost_log_history = utils.get_db_object( clusterhost_log_history = utils.get_db_object(
@ -1903,8 +1901,8 @@ def update_clusterhost_log_history(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def add_clusterhost_log_history( def add_clusterhost_log_history(
creator, clusterhost_id, exception_when_existing=False, clusterhost_id, exception_when_existing=False,
filename=None, session=None, **kwargs filename=None, user=None, session=None, **kwargs
): ):
"""add a host log history.""" """add a host log history."""
return utils.add_db_object( return utils.add_db_object(
@ -1921,8 +1919,8 @@ def add_clusterhost_log_history(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS) @utils.wrap_to_dict(RESP_CLUSTERHOST_LOG_FIELDS)
def add_cluster_host_log_history( def add_cluster_host_log_history(
creator, cluster_id, host_id, exception_when_existing=False, cluster_id, host_id, exception_when_existing=False,
filename=None, session=None, **kwargs filename=None, user=None, session=None, **kwargs
): ):
"""add a host log history.""" """add a host log history."""
clusterhost = utils.get_db_object( clusterhost = utils.get_db_object(

View File

@ -149,9 +149,7 @@ def run_in_session():
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if args is not () and 'session' in str(args[-1]): if 'session' in kwargs.keys():
return func(*args, **kwargs)
elif 'session' in kwargs.keys():
return func(*args, **kwargs) return func(*args, **kwargs)
else: else:
with session() as my_session: with session() as my_session:

View File

@ -107,7 +107,7 @@ UPDATED_LOG_FIELDS = [
permission.PERMISSION_LIST_HOSTS permission.PERMISSION_LIST_HOSTS
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_hosts(lister, session=None, **filters): def list_hosts(user=None, session=None, **filters):
"""List hosts.""" """List hosts."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.Host, **filters session, models.Host, **filters
@ -128,7 +128,7 @@ def list_hosts(lister, session=None, **filters):
os_id=utils.general_filter_callback os_id=utils.general_filter_callback
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_machines_or_hosts(lister, session=None, **filters): def list_machines_or_hosts(user=None, session=None, **filters):
"""List hosts.""" """List hosts."""
machines = utils.list_db_objects( machines = utils.list_db_objects(
session, models.Machine, **filters session, models.Machine, **filters
@ -150,8 +150,8 @@ def list_machines_or_hosts(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_host( def get_host(
getter, host_id, host_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""get host info.""" """get host info."""
return utils.get_db_object( return utils.get_db_object(
@ -167,8 +167,8 @@ def get_host(
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_machine_or_host( def get_machine_or_host(
getter, host_id, host_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""get host info.""" """get host info."""
machine = utils.get_db_object( machine = utils.get_db_object(
@ -190,7 +190,7 @@ def get_machine_or_host(
permission.PERMISSION_LIST_HOST_CLUSTERS permission.PERMISSION_LIST_HOST_CLUSTERS
) )
@utils.wrap_to_dict(RESP_CLUSTER_FIELDS) @utils.wrap_to_dict(RESP_CLUSTER_FIELDS)
def get_host_clusters(getter, host_id, session=None, **kwargs): def get_host_clusters(host_id, user=None, session=None, **kwargs):
"""get host clusters.""" """get host clusters."""
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
@ -276,13 +276,13 @@ def validate_host(session, host):
) )
@utils.input_validates(name=utils.check_name) @utils.input_validates(name=utils.check_name)
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def _update_host(session, updater, host_id, **kwargs): def _update_host(session, user, host_id, **kwargs):
"""Update a host internal.""" """Update a host internal."""
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
) )
is_host_editable( is_host_editable(
session, host, updater, session, host, user,
reinstall_os_set=kwargs.get('reinstall_os', False) reinstall_os_set=kwargs.get('reinstall_os', False)
) )
if 'name' in kwargs: if 'name' in kwargs:
@ -303,19 +303,19 @@ def _update_host(session, updater, host_id, **kwargs):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_UPDATE_HOST permission.PERMISSION_UPDATE_HOST
) )
def update_host(updater, host_id, session=None, **kwargs): def update_host(host_id, user=None, session=None, **kwargs):
"""Update a host.""" """Update a host."""
return _update_host(session, updater, host_id=host_id, **kwargs) return _update_host(session, user, host_id=host_id, **kwargs)
@database.run_in_session() @database.run_in_session()
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_UPDATE_HOST permission.PERMISSION_UPDATE_HOST
) )
def update_hosts(updater, data=[], session=None): def update_hosts(data=[], user=None, session=None):
hosts = [] hosts = []
for host_data in data: for host_data in data:
hosts.append(_update_host(session, updater, **host_data)) hosts.append(_update_host(session, user, **host_data))
return hosts return hosts
@ -329,8 +329,8 @@ def update_hosts(updater, data=[], session=None):
host=RESP_FIELDS host=RESP_FIELDS
) )
def del_host( def del_host(
deleter, host_id, host_id, force=False, from_database_only=False,
force=False, from_database_only=False, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Delete a host.""" """Delete a host."""
from compass.db.api import cluster as cluster_api from compass.db.api import cluster as cluster_api
@ -340,7 +340,7 @@ def del_host(
if host.state.state != 'UNINITIALIZED' and force: if host.state.state != 'UNINITIALIZED' and force:
host.state.state = 'ERROR' host.state.state = 'ERROR'
is_host_editable( is_host_editable(
session, host, deleter, session, host, user,
reinstall_os_set=True reinstall_os_set=True
) )
cluster_ids = [] cluster_ids = []
@ -348,7 +348,7 @@ def del_host(
if clusterhost.state.state != 'UNINITIALIZED' and force: if clusterhost.state.state != 'UNINITIALIZED' and force:
clusterhost.state.state = 'ERROR' clusterhost.state.state = 'ERROR'
cluster_api.is_cluster_editable( cluster_api.is_cluster_editable(
session, clusterhost.cluster, deleter, session, clusterhost.cluster, user,
reinstall_distributed_system_set=True reinstall_distributed_system_set=True
) )
cluster_ids.append(clusterhost.cluster_id) cluster_ids.append(clusterhost.cluster_id)
@ -363,7 +363,7 @@ def del_host(
celery_client.celery.send_task( celery_client.celery.send_task(
'compass.tasks.delete_host', 'compass.tasks.delete_host',
( (
deleter.email, host_id, cluster_ids user.email, host_id, cluster_ids
) )
) )
return { return {
@ -378,7 +378,7 @@ def del_host(
permission.PERMISSION_LIST_HOST_CONFIG permission.PERMISSION_LIST_HOST_CONFIG
) )
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def get_host_config(getter, host_id, session=None, **kwargs): def get_host_config(host_id, user=None, session=None, **kwargs):
"""Get host config.""" """Get host config."""
return utils.get_db_object( return utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
@ -391,7 +391,7 @@ def get_host_config(getter, host_id, session=None, **kwargs):
permission.PERMISSION_LIST_HOST_CONFIG permission.PERMISSION_LIST_HOST_CONFIG
) )
@utils.wrap_to_dict(RESP_DEPLOYED_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_DEPLOYED_CONFIG_FIELDS)
def get_host_deployed_config(getter, host_id, session=None, **kwargs): def get_host_deployed_config(host_id, user=None, session=None, **kwargs):
"""Get host deployed config.""" """Get host deployed config."""
return utils.get_db_object( return utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
@ -410,20 +410,20 @@ def get_host_deployed_config(getter, host_id, session=None, **kwargs):
permission.PERMISSION_ADD_HOST_CONFIG permission.PERMISSION_ADD_HOST_CONFIG
) )
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def update_host_deployed_config(updater, host_id, session=None, **kwargs): def update_host_deployed_config(host_id, user=None, session=None, **kwargs):
"""Update host deployed config.""" """Update host deployed config."""
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
) )
is_host_editable(session, host, updater) is_host_editable(session, host, user)
is_host_validated(session, host) is_host_validated(session, host)
return utils.update_db_object(session, host, **kwargs) return utils.update_db_object(session, host, **kwargs)
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def _update_host_config(session, updater, host, **kwargs): def _update_host_config(session, user, host, **kwargs):
"""Update host config.""" """Update host config."""
is_host_editable(session, host, updater) is_host_editable(session, host, user)
return utils.update_db_object(session, host, **kwargs) return utils.update_db_object(session, host, **kwargs)
@ -438,7 +438,7 @@ def _update_host_config(session, updater, host, **kwargs):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_HOST_CONFIG permission.PERMISSION_ADD_HOST_CONFIG
) )
def update_host_config(updater, host_id, session=None, **kwargs): def update_host_config(host_id, user=None, session=None, **kwargs):
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
) )
@ -453,7 +453,7 @@ def update_host_config(updater, host_id, session=None, **kwargs):
) )
def update_config_internal(host, **in_kwargs): def update_config_internal(host, **in_kwargs):
return _update_host_config( return _update_host_config(
session, updater, host, **kwargs session, user, host, **kwargs
) )
return update_config_internal( return update_config_internal(
@ -472,7 +472,7 @@ def update_host_config(updater, host_id, session=None, **kwargs):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_HOST_CONFIG permission.PERMISSION_ADD_HOST_CONFIG
) )
def patch_host_config(updater, host_id, session=None, **kwargs): def patch_host_config(host_id, user=None, session=None, **kwargs):
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
) )
@ -487,7 +487,7 @@ def patch_host_config(updater, host_id, session=None, **kwargs):
) )
def patch_config_internal(host, **in_kwargs): def patch_config_internal(host, **in_kwargs):
return _update_host_config( return _update_host_config(
session, updater, host, **in_kwargs session, user, host, **in_kwargs
) )
return patch_config_internal( return patch_config_internal(
@ -501,12 +501,12 @@ def patch_host_config(updater, host_id, session=None, **kwargs):
permission.PERMISSION_DEL_HOST_CONFIG permission.PERMISSION_DEL_HOST_CONFIG
) )
@utils.wrap_to_dict(RESP_CONFIG_FIELDS) @utils.wrap_to_dict(RESP_CONFIG_FIELDS)
def del_host_config(deleter, host_id, session=None): def del_host_config(host_id, user=None, session=None):
"""delete a host config.""" """delete a host config."""
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
) )
is_host_editable(session, host, deleter) is_host_editable(session, host, user)
return utils.update_db_object( return utils.update_db_object(
session, host, os_config={}, config_validated=False session, host, os_config={}, config_validated=False
) )
@ -520,7 +520,7 @@ def del_host_config(deleter, host_id, session=None):
permission.PERMISSION_LIST_HOST_NETWORKS permission.PERMISSION_LIST_HOST_NETWORKS
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def list_host_networks(lister, host_id, session=None, **filters): def list_host_networks(host_id, user=None, session=None, **filters):
"""Get host networks.""" """Get host networks."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.HostNetwork, session, models.HostNetwork,
@ -536,7 +536,7 @@ def list_host_networks(lister, host_id, session=None, **filters):
permission.PERMISSION_LIST_HOST_NETWORKS permission.PERMISSION_LIST_HOST_NETWORKS
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def list_hostnetworks(lister, session=None, **filters): def list_hostnetworks(user=None, session=None, **filters):
"""Get host networks.""" """Get host networks."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.HostNetwork, **filters session, models.HostNetwork, **filters
@ -550,8 +550,8 @@ def list_hostnetworks(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def get_host_network( def get_host_network(
getter, host_id, host_id, host_network_id,
host_network_id, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Get host network.""" """Get host network."""
host_network = utils.get_db_object( host_network = utils.get_db_object(
@ -573,7 +573,7 @@ def get_host_network(
permission.PERMISSION_LIST_HOST_NETWORKS permission.PERMISSION_LIST_HOST_NETWORKS
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def get_hostnetwork(getter, host_network_id, session=None, **kwargs): def get_hostnetwork(host_network_id, user=None, session=None, **kwargs):
"""Get host network.""" """Get host network."""
return utils.get_db_object( return utils.get_db_object(
session, models.HostNetwork, session, models.HostNetwork,
@ -591,7 +591,7 @@ def get_hostnetwork(getter, host_network_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def _add_host_network( def _add_host_network(
session, creator, host_id, exception_when_existing=True, session, user, host_id, exception_when_existing=True,
interface=None, ip=None, **kwargs interface=None, ip=None, **kwargs
): ):
host = utils.get_db_object( host = utils.get_db_object(
@ -613,7 +613,7 @@ def _add_host_network(
ip, host_network.id ip, host_network.id
) )
) )
is_host_editable(session, host, creator) is_host_editable(session, host, user)
return utils.add_db_object( return utils.add_db_object(
session, models.HostNetwork, session, models.HostNetwork,
exception_when_existing, exception_when_existing,
@ -626,13 +626,12 @@ def _add_host_network(
permission.PERMISSION_ADD_HOST_NETWORK permission.PERMISSION_ADD_HOST_NETWORK
) )
def add_host_network( def add_host_network(
creator, host_id, host_id, exception_when_existing=True,
exception_when_existing=True, interface=None, user=None, session=None, **kwargs
interface=None, session=None, **kwargs
): ):
"""Create a host network.""" """Create a host network."""
return _add_host_network( return _add_host_network(
session, creator, host_id, exception_when_existing, session, user, host_id, exception_when_existing,
interface=interface, **kwargs interface=interface, **kwargs
) )
@ -642,9 +641,8 @@ def add_host_network(
permission.PERMISSION_ADD_HOST_NETWORK permission.PERMISSION_ADD_HOST_NETWORK
) )
def add_host_networks( def add_host_networks(
creator,
exception_when_existing=False, exception_when_existing=False,
data=[], session=None data=[], user=None, session=None
): ):
"""Create host networks.""" """Create host networks."""
hosts = [] hosts = []
@ -657,7 +655,7 @@ def add_host_networks(
for network in networks: for network in networks:
try: try:
host_networks.append(_add_host_network( host_networks.append(_add_host_network(
session, creator, host_id, exception_when_existing, session, user, host_id, exception_when_existing,
**network **network
)) ))
except exception.DatabaseException as error: except exception.DatabaseException as error:
@ -677,7 +675,7 @@ def add_host_networks(
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def _update_host_network( def _update_host_network(
session, updater, host_network, **kwargs session, user, host_network, **kwargs
): ):
if 'interface' in kwargs: if 'interface' in kwargs:
interface = kwargs['interface'] interface = kwargs['interface']
@ -708,7 +706,7 @@ def _update_host_network(
ip, host_network_by_ip.id ip, host_network_by_ip.id
) )
) )
is_host_editable(session, host_network.host, updater) is_host_editable(session, host_network.host, user)
return utils.update_db_object(session, host_network, **kwargs) return utils.update_db_object(session, host_network, **kwargs)
@ -724,7 +722,7 @@ def _update_host_network(
permission.PERMISSION_ADD_HOST_NETWORK permission.PERMISSION_ADD_HOST_NETWORK
) )
def update_host_network( def update_host_network(
updater, host_id, host_network_id, session=None, **kwargs host_id, host_network_id, user=None, session=None, **kwargs
): ):
"""Update a host network.""" """Update a host network."""
host_network = utils.get_db_object( host_network = utils.get_db_object(
@ -738,7 +736,7 @@ def update_host_network(
) )
) )
return _update_host_network( return _update_host_network(
session, updater, host_network, **kwargs session, user, host_network, **kwargs
) )
@ -753,13 +751,13 @@ def update_host_network(
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_HOST_NETWORK permission.PERMISSION_ADD_HOST_NETWORK
) )
def update_hostnetwork(updater, host_network_id, session=None, **kwargs): def update_hostnetwork(host_network_id, user=None, session=None, **kwargs):
"""Update a host network.""" """Update a host network."""
host_network = utils.get_db_object( host_network = utils.get_db_object(
session, models.HostNetwork, id=host_network_id session, models.HostNetwork, id=host_network_id
) )
return _update_host_network( return _update_host_network(
session, updater, host_network, **kwargs session, user, host_network, **kwargs
) )
@ -770,7 +768,7 @@ def update_hostnetwork(updater, host_network_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def del_host_network( def del_host_network(
deleter, host_id, host_network_id, host_id, host_network_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Delete a host network.""" """Delete a host network."""
@ -784,7 +782,7 @@ def del_host_network(
host_id, host_network_id host_id, host_network_id
) )
) )
is_host_editable(session, host_network.host, deleter) is_host_editable(session, host_network.host, user)
return utils.del_db_object(session, host_network) return utils.del_db_object(session, host_network)
@ -794,12 +792,12 @@ def del_host_network(
permission.PERMISSION_DEL_HOST_NETWORK permission.PERMISSION_DEL_HOST_NETWORK
) )
@utils.wrap_to_dict(RESP_NETWORK_FIELDS) @utils.wrap_to_dict(RESP_NETWORK_FIELDS)
def del_hostnetwork(deleter, host_network_id, session=None, **kwargs): def del_hostnetwork(host_network_id, user=None, session=None, **kwargs):
"""Delete a host network.""" """Delete a host network."""
host_network = utils.get_db_object( host_network = utils.get_db_object(
session, models.HostNetwork, id=host_network_id session, models.HostNetwork, id=host_network_id
) )
is_host_editable(session, host_network.host, deleter) is_host_editable(session, host_network.host, user)
return utils.del_db_object(session, host_network) return utils.del_db_object(session, host_network)
@ -809,7 +807,7 @@ def del_hostnetwork(deleter, host_network_id, session=None, **kwargs):
permission.PERMISSION_GET_HOST_STATE permission.PERMISSION_GET_HOST_STATE
) )
@utils.wrap_to_dict(RESP_STATE_FIELDS) @utils.wrap_to_dict(RESP_STATE_FIELDS)
def get_host_state(getter, host_id, session=None, **kwargs): def get_host_state(host_id, user=None, session=None, **kwargs):
"""Get host state info.""" """Get host state info."""
return utils.get_db_object( return utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
@ -825,7 +823,7 @@ def get_host_state(getter, host_id, session=None, **kwargs):
permission.PERMISSION_UPDATE_HOST_STATE permission.PERMISSION_UPDATE_HOST_STATE
) )
@utils.wrap_to_dict(RESP_STATE_FIELDS) @utils.wrap_to_dict(RESP_STATE_FIELDS)
def update_host_state(updater, host_id, session=None, **kwargs): def update_host_state(host_id, user=None, session=None, **kwargs):
"""Update a host state.""" """Update a host state."""
host = utils.get_db_object( host = utils.get_db_object(
session, models.Host, id=host_id session, models.Host, id=host_id
@ -837,7 +835,7 @@ def update_host_state(updater, host_id, session=None, **kwargs):
@utils.supported_filters([]) @utils.supported_filters([])
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_LOG_FIELDS) @utils.wrap_to_dict(RESP_LOG_FIELDS)
def get_host_log_histories(getter, host_id, session=None, **kwargs): def get_host_log_histories(host_id, user=None, session=None, **kwargs):
"""Get host log history.""" """Get host log history."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.HostLogHistory, id=host_id session, models.HostLogHistory, id=host_id
@ -847,7 +845,7 @@ def get_host_log_histories(getter, host_id, session=None, **kwargs):
@utils.supported_filters([]) @utils.supported_filters([])
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_LOG_FIELDS) @utils.wrap_to_dict(RESP_LOG_FIELDS)
def get_host_log_history(getter, host_id, filename, session=None, **kwargs): def get_host_log_history(host_id, filename, user=None, session=None, **kwargs):
"""Get host log history.""" """Get host log history."""
return utils.get_db_object( return utils.get_db_object(
session, models.HostLogHistory, id=host_id, filename=filename session, models.HostLogHistory, id=host_id, filename=filename
@ -861,7 +859,7 @@ def get_host_log_history(getter, host_id, filename, session=None, **kwargs):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_LOG_FIELDS) @utils.wrap_to_dict(RESP_LOG_FIELDS)
def update_host_log_history( def update_host_log_history(
updater, host_id, filename, host_id, filename, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Update a host log history.""" """Update a host log history."""
@ -879,8 +877,8 @@ def update_host_log_history(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_LOG_FIELDS) @utils.wrap_to_dict(RESP_LOG_FIELDS)
def add_host_log_history( def add_host_log_history(
creator, host_id, exception_when_existing=False, host_id, exception_when_existing=False,
filename=None, session=None, **kwargs filename=None, user=None, session=None, **kwargs
): ):
"""add a host log history.""" """add a host log history."""
return utils.add_db_object( return utils.add_db_object(
@ -899,7 +897,7 @@ def add_host_log_history(
host=RESP_CONFIG_FIELDS host=RESP_CONFIG_FIELDS
) )
def poweron_host( def poweron_host(
deployer, host_id, poweron={}, session=None, **kwargs host_id, poweron={}, user=None, session=None, **kwargs
): ):
"""power on host.""" """power on host."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client
@ -927,7 +925,7 @@ def poweron_host(
host=RESP_CONFIG_FIELDS host=RESP_CONFIG_FIELDS
) )
def poweroff_host( def poweroff_host(
deployer, host_id, poweroff={}, session=None, **kwargs host_id, poweroff={}, user=None, session=None, **kwargs
): ):
"""power off host.""" """power off host."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client
@ -955,7 +953,7 @@ def poweroff_host(
host=RESP_CONFIG_FIELDS host=RESP_CONFIG_FIELDS
) )
def reset_host( def reset_host(
deployer, host_id, reset={}, session=None, **kwargs host_id, reset={}, user=None, session=None, **kwargs
): ):
"""reset host.""" """reset host."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client

View File

@ -50,9 +50,8 @@ RESP_DEPLOY_FIELDS = [
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_machine( def get_machine(
getter, machine_id, machine_id, exception_when_missing=True,
exception_when_missing=True, session=None, user=None, session=None, **kwargs
**kwargs
): ):
"""get field dict of a machine.""" """get field dict of a machine."""
return utils.get_db_object( return utils.get_db_object(
@ -73,7 +72,7 @@ def get_machine(
location=utils.general_filter_callback location=utils.general_filter_callback
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_machines(lister, session=None, **filters): def list_machines(user=None, session=None, **filters):
"""List machines.""" """List machines."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.Machine, **filters session, models.Machine, **filters
@ -81,7 +80,7 @@ def list_machines(lister, session=None, **filters):
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def _update_machine(session, updater, machine_id, **kwargs): def _update_machine(session, machine_id, **kwargs):
"""Update a machine.""" """Update a machine."""
machine = utils.get_db_object(session, models.Machine, id=machine_id) machine = utils.get_db_object(session, models.Machine, id=machine_id)
return utils.update_db_object(session, machine, **kwargs) return utils.update_db_object(session, machine, **kwargs)
@ -96,9 +95,9 @@ def _update_machine(session, updater, machine_id, **kwargs):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_MACHINE permission.PERMISSION_ADD_MACHINE
) )
def update_machine(updater, machine_id, session=None, **kwargs): def update_machine(machine_id, user=None, session=None, **kwargs):
return _update_machine( return _update_machine(
session, updater, machine_id, **kwargs session, machine_id, **kwargs
) )
@ -113,9 +112,12 @@ def update_machine(updater, machine_id, session=None, **kwargs):
) )
@database.run_in_session() @database.run_in_session()
@utils.output_validates(ipmi_credentials=utils.check_ipmi_credentials) @utils.output_validates(ipmi_credentials=utils.check_ipmi_credentials)
def patch_machine(updater, machine_id, session=None, **kwargs): @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_MACHINE
)
def patch_machine(machine_id, user=None, session=None, **kwargs):
return _update_machine( return _update_machine(
session, updater, machine_id, **kwargs session, machine_id, **kwargs
) )
@ -125,7 +127,7 @@ def patch_machine(updater, machine_id, session=None, **kwargs):
permission.PERMISSION_DEL_MACHINE permission.PERMISSION_DEL_MACHINE
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def del_machine(deleter, machine_id, session=None, **kwargs): def del_machine(machine_id, user=None, session=None, **kwargs):
"""Delete a machine.""" """Delete a machine."""
machine = utils.get_db_object(session, models.Machine, id=machine_id) machine = utils.get_db_object(session, models.Machine, id=machine_id)
if machine.host: if machine.host:
@ -148,7 +150,7 @@ def del_machine(deleter, machine_id, session=None, **kwargs):
machine=RESP_FIELDS machine=RESP_FIELDS
) )
def poweron_machine( def poweron_machine(
deployer, machine_id, poweron={}, session=None, **kwargs machine_id, poweron={}, user=None, session=None, **kwargs
): ):
"""power on machine.""" """power on machine."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client
@ -175,7 +177,7 @@ def poweron_machine(
machine=RESP_FIELDS machine=RESP_FIELDS
) )
def poweroff_machine( def poweroff_machine(
deployer, machine_id, poweroff={}, session=None, **kwargs machine_id, poweroff={}, user=None, session=None, **kwargs
): ):
"""power off machine.""" """power off machine."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client
@ -202,7 +204,7 @@ def poweroff_machine(
machine=RESP_FIELDS machine=RESP_FIELDS
) )
def reset_machine( def reset_machine(
deployer, machine_id, reset={}, session=None, **kwargs machine_id, reset={}, user=None, session=None, **kwargs
): ):
"""reset machine.""" """reset machine."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client

View File

@ -51,7 +51,7 @@ def _check_subnet(subnet):
permission.PERMISSION_LIST_SUBNETS permission.PERMISSION_LIST_SUBNETS
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_subnets(lister, session=None, **filters): def list_subnets(user=None, session=None, **filters):
"""List subnets.""" """List subnets."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.Subnet, **filters session, models.Subnet, **filters
@ -65,8 +65,8 @@ def list_subnets(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_subnet( def get_subnet(
getter, subnet_id, subnet_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Get subnet info.""" """Get subnet info."""
return utils.get_db_object( return utils.get_db_object(
@ -86,8 +86,8 @@ def get_subnet(
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def add_subnet( def add_subnet(
creator, exception_when_existing=True, exception_when_existing=True, subnet=None,
subnet=None, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Create a subnet.""" """Create a subnet."""
return utils.add_db_object( return utils.add_db_object(
@ -106,7 +106,7 @@ def add_subnet(
permission.PERMISSION_ADD_SUBNET permission.PERMISSION_ADD_SUBNET
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def update_subnet(updater, subnet_id, session=None, **kwargs): def update_subnet(subnet_id, user=None, session=None, **kwargs):
"""Update a subnet.""" """Update a subnet."""
subnet = utils.get_db_object( subnet = utils.get_db_object(
session, models.Subnet, id=subnet_id session, models.Subnet, id=subnet_id
@ -120,7 +120,7 @@ def update_subnet(updater, subnet_id, session=None, **kwargs):
permission.PERMISSION_DEL_SUBNET permission.PERMISSION_DEL_SUBNET
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def del_subnet(deleter, subnet_id, session=None, **kwargs): def del_subnet(subnet_id, user=None, session=None, **kwargs):
"""Delete a subnet.""" """Delete a subnet."""
subnet = utils.get_db_object( subnet = utils.get_db_object(
session, models.Subnet, id=subnet_id session, models.Subnet, id=subnet_id

View File

@ -300,7 +300,7 @@ def list_permissions_internal(session, **filters):
@database.run_in_session() @database.run_in_session()
@user_api.check_user_permission_in_session(PERMISSION_LIST_PERMISSIONS) @user_api.check_user_permission_in_session(PERMISSION_LIST_PERMISSIONS)
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_permissions(lister, session=None, **filters): def list_permissions(user=None, session=None, **filters):
"""list permissions.""" """list permissions."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.Permission, **filters session, models.Permission, **filters
@ -312,8 +312,8 @@ def list_permissions(lister, session=None, **filters):
@user_api.check_user_permission_in_session(PERMISSION_LIST_PERMISSIONS) @user_api.check_user_permission_in_session(PERMISSION_LIST_PERMISSIONS)
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_permission( def get_permission(
getter, permission_id, permission_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""get permissions.""" """get permissions."""
return utils.get_db_object( return utils.get_db_object(

View File

@ -140,8 +140,8 @@ def get_switch_internal(
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_switch( def get_switch(
getter, switch_id, switch_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""get field dict of a switch.""" """get field dict of a switch."""
return utils.get_db_object( return utils.get_db_object(
@ -156,7 +156,7 @@ def get_switch(
permission.PERMISSION_LIST_SWITCHES permission.PERMISSION_LIST_SWITCHES
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_switches(lister, session=None, **filters): def list_switches(user=None, session=None, **filters):
"""List switches.""" """List switches."""
switches = utils.list_db_objects( switches = utils.list_db_objects(
session, models.Switch, **filters session, models.Switch, **filters
@ -176,7 +176,7 @@ def list_switches(lister, session=None, **filters):
permission.PERMISSION_DEL_SWITCH permission.PERMISSION_DEL_SWITCH
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def del_switch(deleter, switch_id, session=None, **kwargs): def del_switch(switch_id, user=None, session=None, **kwargs):
"""Delete a switch.""" """Delete a switch."""
switch = utils.get_db_object(session, models.Switch, id=switch_id) switch = utils.get_db_object(session, models.Switch, id=switch_id)
default_switch_ip_int = long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP)) default_switch_ip_int = long(netaddr.IPAddress(setting.DEFAULT_SWITCH_IP))
@ -212,8 +212,8 @@ def del_switch(deleter, switch_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def add_switch( def add_switch(
creator, exception_when_existing=True, exception_when_existing=True, ip=None,
ip=None, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Create a switch.""" """Create a switch."""
ip_int = long(netaddr.IPAddress(ip)) ip_int = long(netaddr.IPAddress(ip))
@ -231,7 +231,7 @@ def update_switch_internal(session, switch, **kwargs):
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def _update_switch(session, updater, switch_id, **kwargs): def _update_switch(session, switch_id, **kwargs):
"""Update a switch.""" """Update a switch."""
switch = utils.get_db_object( switch = utils.get_db_object(
session, models.Switch, id=switch_id session, models.Switch, id=switch_id
@ -254,9 +254,9 @@ def _update_switch(session, updater, switch_id, **kwargs):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_SWITCH permission.PERMISSION_ADD_SWITCH
) )
def update_switch(updater, switch_id, session=None, **kwargs): def update_switch(switch_id, user=None, session=None, **kwargs):
"""Update fields of a switch.""" """Update fields of a switch."""
return _update_switch(session, updater, switch_id, **kwargs) return _update_switch(session, switch_id, **kwargs)
@utils.replace_filters( @utils.replace_filters(
@ -274,9 +274,12 @@ def update_switch(updater, switch_id, session=None, **kwargs):
@utils.output_validates( @utils.output_validates(
credentials=utils.check_switch_credentials credentials=utils.check_switch_credentials
) )
def patch_switch(updater, switch_id, session=None, **kwargs): @user_api.check_user_permission_in_session(
permission.PERMISSION_ADD_SWITCH
)
def patch_switch(switch_id, user=None, session=None, **kwargs):
"""Patch fields of a switch.""" """Patch fields of a switch."""
return _update_switch(session, updater, switch_id, **kwargs) return _update_switch(session, switch_id, **kwargs)
@utils.supported_filters(optional_support_keys=SUPPORTED_FILTER_FIELDS) @utils.supported_filters(optional_support_keys=SUPPORTED_FILTER_FIELDS)
@ -285,7 +288,7 @@ def patch_switch(updater, switch_id, session=None, **kwargs):
permission.PERMISSION_LIST_SWITCH_FILTERS permission.PERMISSION_LIST_SWITCH_FILTERS
) )
@utils.wrap_to_dict(RESP_FILTERS_FIELDS) @utils.wrap_to_dict(RESP_FILTERS_FIELDS)
def list_switch_filters(lister, session=None, **filters): def list_switch_filters(user=None, session=None, **filters):
"""List switch filters.""" """List switch filters."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.Switch, **filters session, models.Switch, **filters
@ -299,7 +302,7 @@ def list_switch_filters(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_FILTERS_FIELDS) @utils.wrap_to_dict(RESP_FILTERS_FIELDS)
def get_switch_filters( def get_switch_filters(
getter, switch_id, session=None, **kwargs switch_id, user=None, session=None, **kwargs
): ):
"""get switch filter.""" """get switch filter."""
return utils.get_db_object( return utils.get_db_object(
@ -320,7 +323,7 @@ def get_switch_filters(
permission.PERMISSION_UPDATE_SWITCH_FILTERS permission.PERMISSION_UPDATE_SWITCH_FILTERS
) )
@utils.wrap_to_dict(RESP_FILTERS_FIELDS) @utils.wrap_to_dict(RESP_FILTERS_FIELDS)
def update_switch_filters(updater, switch_id, session=None, **kwargs): def update_switch_filters(switch_id, user=None, session=None, **kwargs):
"""Update a switch filter.""" """Update a switch filter."""
switch = utils.get_db_object(session, models.Switch, id=switch_id) switch = utils.get_db_object(session, models.Switch, id=switch_id)
return utils.update_db_object(session, switch, **kwargs) return utils.update_db_object(session, switch, **kwargs)
@ -339,7 +342,7 @@ def update_switch_filters(updater, switch_id, session=None, **kwargs):
permission.PERMISSION_UPDATE_SWITCH_FILTERS permission.PERMISSION_UPDATE_SWITCH_FILTERS
) )
@utils.wrap_to_dict(RESP_FILTERS_FIELDS) @utils.wrap_to_dict(RESP_FILTERS_FIELDS)
def patch_switch_filter(updater, switch_id, session=None, **kwargs): def patch_switch_filter(switch_id, user=None, session=None, **kwargs):
"""Patch a switch filter.""" """Patch a switch filter."""
switch = utils.get_db_object(session, models.Switch, id=switch_id) switch = utils.get_db_object(session, models.Switch, id=switch_id)
return utils.update_db_object(session, switch, **kwargs) return utils.update_db_object(session, switch, **kwargs)
@ -405,7 +408,7 @@ def _filter_vlans(vlan_filter, obj):
location=utils.general_filter_callback location=utils.general_filter_callback
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def _filter_switch_machines(session, user, switch_machines): def _filter_switch_machines(session, switch_machines):
return [ return [
switch_machine for switch_machine in switch_machines switch_machine for switch_machine in switch_machines
if not switch_machine.filtered if not switch_machine.filtered
@ -424,7 +427,7 @@ def _filter_switch_machines(session, user, switch_machines):
RESP_MACHINES_HOSTS_FIELDS, RESP_MACHINES_HOSTS_FIELDS,
clusters=RESP_CLUSTER_FIELDS clusters=RESP_CLUSTER_FIELDS
) )
def _filter_switch_machines_hosts(session, user, switch_machines): def _filter_switch_machines_hosts(session, switch_machines):
filtered_switch_machines = [ filtered_switch_machines = [
switch_machine for switch_machine in switch_machines switch_machine for switch_machine in switch_machines
if not switch_machine.filtered if not switch_machine.filtered
@ -451,12 +454,12 @@ def _filter_switch_machines_hosts(session, user, switch_machines):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_LIST_SWITCH_MACHINES permission.PERMISSION_LIST_SWITCH_MACHINES
) )
def list_switch_machines(getter, switch_id, session=None, **filters): def list_switch_machines(switch_id, user=None, session=None, **filters):
"""Get switch machines.""" """Get switch machines."""
switch_machines = get_switch_machines_internal( switch_machines = get_switch_machines_internal(
session, switch_id=switch_id, **filters session, switch_id=switch_id, **filters
) )
return _filter_switch_machines(session, getter, switch_machines) return _filter_switch_machines(session, switch_machines)
@utils.replace_filters( @utils.replace_filters(
@ -469,13 +472,13 @@ def list_switch_machines(getter, switch_id, session=None, **filters):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_LIST_SWITCH_MACHINES permission.PERMISSION_LIST_SWITCH_MACHINES
) )
def list_switchmachines(lister, session=None, **filters): def list_switchmachines(user=None, session=None, **filters):
"""List switch machines.""" """List switch machines."""
switch_machines = get_switch_machines_internal( switch_machines = get_switch_machines_internal(
session, **filters session, **filters
) )
return _filter_switch_machines( return _filter_switch_machines(
session, lister, switch_machines session, switch_machines
) )
@ -486,13 +489,13 @@ def list_switchmachines(lister, session=None, **filters):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_LIST_SWITCH_MACHINES permission.PERMISSION_LIST_SWITCH_MACHINES
) )
def list_switch_machines_hosts(getter, switch_id, session=None, **filters): def list_switch_machines_hosts(switch_id, user=None, session=None, **filters):
"""Get switch machines hosts.""" """Get switch machines hosts."""
switch_machines = get_switch_machines_internal( switch_machines = get_switch_machines_internal(
session, switch_id=switch_id, **filters session, switch_id=switch_id, **filters
) )
return _filter_switch_machines_hosts( return _filter_switch_machines_hosts(
session, getter, switch_machines session, switch_machines
) )
@ -506,7 +509,7 @@ def list_switch_machines_hosts(getter, switch_id, session=None, **filters):
@user_api.check_user_permission_in_session( @user_api.check_user_permission_in_session(
permission.PERMISSION_LIST_SWITCH_MACHINES permission.PERMISSION_LIST_SWITCH_MACHINES
) )
def list_switchmachines_hosts(lister, session=None, **filters): def list_switchmachines_hosts(user=None, session=None, **filters):
"""List switch machines hosts.""" """List switch machines hosts."""
switch_machines = get_switch_machines_internal( switch_machines = get_switch_machines_internal(
session, **filters session, **filters
@ -518,7 +521,7 @@ def list_switchmachines_hosts(lister, session=None, **filters):
switch_machine for switch_machine in switch_machines switch_machine for switch_machine in switch_machines
] ]
return _filter_switch_machines_hosts( return _filter_switch_machines_hosts(
session, lister, filtered_switch_machines session, filtered_switch_machines
) )
@ -534,9 +537,8 @@ def list_switchmachines_hosts(lister, session=None, **filters):
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def add_switch_machine( def add_switch_machine(
creator, switch_id, switch_id, exception_when_existing=True,
exception_when_existing=True, mac=None, user=None, session=None, **kwargs
mac=None, session=None, **kwargs
): ):
"""Add switch machine.""" """Add switch machine."""
switch = utils.get_db_object( switch = utils.get_db_object(
@ -566,13 +568,13 @@ def add_switch_machine(
permission.PERMISSION_UPDATE_SWITCH_MACHINES permission.PERMISSION_UPDATE_SWITCH_MACHINES
) )
@utils.wrap_to_dict(RESP_ACTION_FIELDS) @utils.wrap_to_dict(RESP_ACTION_FIELDS)
def poll_switch_machines(poller, switch_id, session=None, **kwargs): def poll_switch_machines(switch_id, user=None, session=None, **kwargs):
"""poll switch machines.""" """poll switch machines."""
from compass.tasks import client as celery_client from compass.tasks import client as celery_client
switch = utils.get_db_object(session, models.Switch, id=switch_id) switch = utils.get_db_object(session, models.Switch, id=switch_id)
celery_client.celery.send_task( celery_client.celery.send_task(
'compass.tasks.pollswitch', 'compass.tasks.pollswitch',
(poller.email, switch.ip, switch.credentials) (user.email, switch.ip, switch.credentials)
) )
return { return {
'status': 'action %s sent' % kwargs, 'status': 'action %s sent' % kwargs,
@ -588,8 +590,8 @@ def poll_switch_machines(poller, switch_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def get_switch_machine( def get_switch_machine(
getter, switch_id, machine_id, switch_id, machine_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""get field dict of a switch machine.""" """get field dict of a switch machine."""
return utils.get_db_object( return utils.get_db_object(
@ -606,9 +608,8 @@ def get_switch_machine(
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def get_switchmachine( def get_switchmachine(
getter, switch_machine_id, switch_machine_id, exception_when_missing=True,
exception_when_missing=True, session=None, user=None, session=None, **kwargs
**kwargs
): ):
"""get field dict of a switch machine.""" """get field dict of a switch machine."""
return utils.get_db_object( return utils.get_db_object(
@ -648,7 +649,7 @@ def update_switch_machine_internal(
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def update_switch_machine( def update_switch_machine(
updater, switch_id, machine_id, switch_id, machine_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Update switch machine.""" """Update switch machine."""
@ -672,7 +673,7 @@ def update_switch_machine(
permission.PERMISSION_ADD_SWITCH_MACHINE permission.PERMISSION_ADD_SWITCH_MACHINE
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def update_switchmachine(updater, switch_machine_id, session=None, **kwargs): def update_switchmachine(switch_machine_id, user=None, session=None, **kwargs):
"""Update switch machine.""" """Update switch machine."""
switch_machine = utils.get_db_object( switch_machine = utils.get_db_object(
session, models.SwitchMachine, session, models.SwitchMachine,
@ -701,7 +702,7 @@ def update_switchmachine(updater, switch_machine_id, session=None, **kwargs):
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def patch_switch_machine( def patch_switch_machine(
updater, switch_id, machine_id, switch_id, machine_id, user=None,
session=None, **kwargs session=None, **kwargs
): ):
"""Patch switch machine.""" """Patch switch machine."""
@ -731,7 +732,7 @@ def patch_switch_machine(
permission.PERMISSION_ADD_SWITCH_MACHINE permission.PERMISSION_ADD_SWITCH_MACHINE
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def patch_switchmachine(updater, switch_machine_id, session=None, **kwargs): def patch_switchmachine(switch_machine_id, user=None, session=None, **kwargs):
"""Patch switch machine.""" """Patch switch machine."""
switch_machine = utils.get_db_object( switch_machine = utils.get_db_object(
session, models.SwitchMachine, session, models.SwitchMachine,
@ -749,7 +750,10 @@ def patch_switchmachine(updater, switch_machine_id, session=None, **kwargs):
permission.PERMISSION_DEL_SWITCH_MACHINE permission.PERMISSION_DEL_SWITCH_MACHINE
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def del_switch_machine(deleter, switch_id, machine_id, session=None, **kwargs): def del_switch_machine(
switch_id, machine_id, user=None,
session=None, **kwargs
):
"""Delete switch machine by switch id and machine id.""" """Delete switch machine by switch id and machine id."""
switch_machine = utils.get_db_object( switch_machine = utils.get_db_object(
session, models.SwitchMachine, session, models.SwitchMachine,
@ -777,7 +781,7 @@ def del_switch_machine(deleter, switch_id, machine_id, session=None, **kwargs):
permission.PERMISSION_DEL_SWITCH_MACHINE permission.PERMISSION_DEL_SWITCH_MACHINE
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def del_switchmachine(deleter, switch_machine_id, session=None, **kwargs): def del_switchmachine(switch_machine_id, user=None, session=None, **kwargs):
"""Delete switch machine by switch_machine_id.""" """Delete switch machine by switch_machine_id."""
switch_machine = utils.get_db_object( switch_machine = utils.get_db_object(
session, models.SwitchMachine, session, models.SwitchMachine,
@ -847,9 +851,8 @@ def _set_machines(session, switch, machines):
) )
@utils.wrap_to_dict(RESP_MACHINES_FIELDS) @utils.wrap_to_dict(RESP_MACHINES_FIELDS)
def update_switch_machines( def update_switch_machines(
updater, switch_id, switch_id, add_machines=[], remove_machines=[],
add_machines=[], remove_machines=[], set_machines=None, user=None, session=None, **kwargs
set_machines=None, session=None, **kwargs
): ):
"""update switch machines.""" """update switch machines."""
switch = utils.get_db_object( switch = utils.get_db_object(

View File

@ -101,13 +101,14 @@ def _check_user_permission(session, user, permission):
def check_user_permission_in_session(permission): def check_user_permission_in_session(permission):
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(user, *args, **kwargs): def wrapper(*args, **kwargs):
if 'session' in kwargs.keys(): if 'user' in kwargs.keys() and 'session' in kwargs.keys():
session = kwargs['session'] session = kwargs['session']
user = kwargs['user']
_check_user_permission(session, user, permission)
return func(*args, **kwargs)
else: else:
session = args[-1] return func(*args, **kwargs)
_check_user_permission(session, user, permission)
return func(user, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -115,14 +116,18 @@ def check_user_permission_in_session(permission):
def check_user_admin(): def check_user_admin():
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(user, *args, **kwargs): def wrapper(*args, **kwargs):
if not user.is_admin: if 'user' in kwargs.keys():
raise exception.Forbidden( user = kwargs['user']
'User %s is not admin.' % ( if not user.is_admin:
user.email raise exception.Forbidden(
'User %s is not admin.' % (
user.email
)
) )
) return func(*args, **kwargs)
return func(user, *args, **kwargs) else:
return func(*args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -130,14 +135,18 @@ def check_user_admin():
def check_user_admin_or_owner(): def check_user_admin_or_owner():
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(user, user_id, *args, **kwargs): def wrapper(user_id, *args, **kwargs):
if not user.is_admin and user.id != user_id: if 'user' in kwargs.keys():
raise exception.Forbidden( user = kwargs['user']
'User %s is not admin or the owner of user id %s.' % ( if not user.is_admin and user.id != user_id:
user.email, user_id raise exception.Forbidden(
'User %s is not admin or the owner of user id %s.' % (
user.email, user_id
)
) )
) return func(user_id, *args, **kwargs)
return func(user, user_id, *args, **kwargs) else:
return func(user_id, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
@ -266,7 +275,7 @@ def get_user_object_from_token(token, session=None):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_TOKEN_FIELDS) @utils.wrap_to_dict(RESP_TOKEN_FIELDS)
def record_user_token( def record_user_token(
user, token, expire_timestamp, session=None token, expire_timestamp, user=None, session=None
): ):
"""record user token in database.""" """record user token in database."""
user_token = utils.get_db_object( user_token = utils.get_db_object(
@ -289,7 +298,7 @@ def record_user_token(
@utils.supported_filters() @utils.supported_filters()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_TOKEN_FIELDS) @utils.wrap_to_dict(RESP_TOKEN_FIELDS)
def clean_user_token(user, token, session=None): def clean_user_token(token, user=None, session=None):
"""clean user token in database.""" """clean user token in database."""
return utils.del_db_objects( return utils.del_db_objects(
session, models.UserToken, session, models.UserToken,
@ -302,8 +311,8 @@ def clean_user_token(user, token, session=None):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_user( def get_user(
getter, user_id, user_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""get field dict of a user.""" """get field dict of a user."""
return utils.get_db_object( return utils.get_db_object(
@ -315,12 +324,12 @@ def get_user(
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def get_current_user( def get_current_user(
getter, exception_when_missing=True, user=None,
exception_when_missing=True, session=None, **kwargs session=None, **kwargs
): ):
"""get field dict of a user.""" """get field dict of a user."""
return utils.get_db_object( return utils.get_db_object(
session, models.User, exception_when_missing, id=getter.id session, models.User, exception_when_missing, id=user.id
) )
@ -330,7 +339,7 @@ def get_current_user(
@check_user_admin() @check_user_admin()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_users(lister, session=None, **filters): def list_users(user=None, session=None, **filters):
"""List fields of all users by some fields.""" """List fields of all users by some fields."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.User, **filters session, models.User, **filters
@ -347,8 +356,7 @@ def list_users(lister, session=None, **filters):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def add_user( def add_user(
creator, exception_when_existing=True, user=None,
exception_when_existing=True,
session=None, **kwargs session=None, **kwargs
): ):
"""Create a user and return created user object.""" """Create a user and return created user object."""
@ -361,7 +369,7 @@ def add_user(
@check_user_admin() @check_user_admin()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def del_user(deleter, user_id, session=None, **kwargs): def del_user(user_id, user=None, session=None, **kwargs):
"""delete a user and return the deleted user object.""" """delete a user and return the deleted user object."""
user = utils.get_db_object(session, models.User, id=user_id) user = utils.get_db_object(session, models.User, id=user_id)
return utils.del_db_object(session, user) return utils.del_db_object(session, user)
@ -374,22 +382,22 @@ def del_user(deleter, user_id, session=None, **kwargs):
@utils.input_validates(email=_check_email) @utils.input_validates(email=_check_email)
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def update_user(updater, user_id, session=None, **kwargs): def update_user(user_id, user=None, session=None, **kwargs):
"""Update a user and return the updated user object.""" """Update a user and return the updated user object."""
user = utils.get_db_object( user = utils.get_db_object(
session, models.User, id=user_id session, models.User, id=user_id
) )
allowed_fields = set() allowed_fields = set()
if updater.is_admin: if user.is_admin:
allowed_fields |= set(ADMIN_UPDATED_FIELDS) allowed_fields |= set(ADMIN_UPDATED_FIELDS)
if updater.id == user_id: if user.id == user_id:
allowed_fields |= set(SELF_UPDATED_FIELDS) allowed_fields |= set(SELF_UPDATED_FIELDS)
unsupported_fields = set(kwargs) - allowed_fields unsupported_fields = set(kwargs) - allowed_fields
if unsupported_fields: if unsupported_fields:
# The user is not allowed to update a user. # The user is not allowed to update a user.
raise exception.Forbidden( raise exception.Forbidden(
'User %s has no permission to update user %s fields %s.' % ( 'User %s has no permission to update user %s fields %s.' % (
updater.email, user.email, unsupported_fields user.email, user.email, unsupported_fields
) )
) )
return utils.update_db_object(session, user, **kwargs) return utils.update_db_object(session, user, **kwargs)
@ -399,7 +407,7 @@ def update_user(updater, user_id, session=None, **kwargs):
@check_user_admin_or_owner() @check_user_admin_or_owner()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(PERMISSION_RESP_FIELDS) @utils.wrap_to_dict(PERMISSION_RESP_FIELDS)
def get_permissions(lister, user_id, session=None, **kwargs): def get_permissions(user_id, user=None, session=None, **kwargs):
"""List permissions of a user.""" """List permissions of a user."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.UserPermission, user_id=user_id, **kwargs session, models.UserPermission, user_id=user_id, **kwargs
@ -411,8 +419,8 @@ def get_permissions(lister, user_id, session=None, **kwargs):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(PERMISSION_RESP_FIELDS) @utils.wrap_to_dict(PERMISSION_RESP_FIELDS)
def get_permission( def get_permission(
getter, user_id, permission_id, user_id, permission_id, exception_when_missing=True,
exception_when_missing=True, session=None, **kwargs user=None, session=None, **kwargs
): ):
"""Get a specific user permission.""" """Get a specific user permission."""
return utils.get_db_object( return utils.get_db_object(
@ -427,7 +435,7 @@ def get_permission(
@check_user_admin_or_owner() @check_user_admin_or_owner()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(PERMISSION_RESP_FIELDS) @utils.wrap_to_dict(PERMISSION_RESP_FIELDS)
def del_permission(deleter, user_id, permission_id, session=None, **kwargs): def del_permission(user_id, permission_id, user=None, session=None, **kwargs):
"""Delete a specific user permission.""" """Delete a specific user permission."""
user_permission = utils.get_db_object( user_permission = utils.get_db_object(
session, models.UserPermission, session, models.UserPermission,
@ -445,8 +453,8 @@ def del_permission(deleter, user_id, permission_id, session=None, **kwargs):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(PERMISSION_RESP_FIELDS) @utils.wrap_to_dict(PERMISSION_RESP_FIELDS)
def add_permission( def add_permission(
creator, user_id, user_id, exception_when_missing=True,
exception_when_missing=True, permission_id=None, session=None permission_id=None, user=None, session=None
): ):
"""Add an user permission.""" """Add an user permission."""
return utils.add_db_object( return utils.add_db_object(
@ -471,9 +479,8 @@ def _get_permission_filters(permission_ids):
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(PERMISSION_RESP_FIELDS) @utils.wrap_to_dict(PERMISSION_RESP_FIELDS)
def update_permissions( def update_permissions(
updater, user_id, user_id, add_permissions=[], remove_permissions=[],
add_permissions=[], remove_permissions=[], set_permissions=None, user=None, session=None, **kwargs
set_permissions=None, session=None, **kwargs
): ):
"""update user permissions.""" """update user permissions."""
user = utils.get_db_object(session, models.User, id=user_id) user = utils.get_db_object(session, models.User, id=user_id)

View File

@ -24,7 +24,7 @@ from compass.db import models
SUPPORTED_FIELDS = ['user_email', 'timestamp'] SUPPORTED_FIELDS = ['user_email', 'timestamp']
USER_SUPPORTED_FIELDS = ['timestamp'] USER_SUPPORTED_FIELDS = ['timestamp']
RESP_FIELDS = ['user_id', 'logs', 'timestamp'] RESP_FIELDS = ['user_id', 'action', 'timestamp']
@database.run_in_session() @database.run_in_session()
@ -39,7 +39,7 @@ def log_user_action(user_id, action, session=None):
@user_api.check_user_admin_or_owner() @user_api.check_user_admin_or_owner()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_user_actions(lister, user_id, session=None, **filters): def list_user_actions(user_id, user=None, session=None, **filters):
"""list user actions.""" """list user actions."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.UserLog, order_by=['timestamp'], session, models.UserLog, order_by=['timestamp'],
@ -51,7 +51,7 @@ def list_user_actions(lister, user_id, session=None, **filters):
@user_api.check_user_admin() @user_api.check_user_admin()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def list_actions(lister, session=None, **filters): def list_actions(user=None, session=None, **filters):
"""list actions.""" """list actions."""
return utils.list_db_objects( return utils.list_db_objects(
session, models.UserLog, order_by=['timestamp'], **filters session, models.UserLog, order_by=['timestamp'], **filters
@ -62,7 +62,7 @@ def list_actions(lister, session=None, **filters):
@user_api.check_user_admin_or_owner() @user_api.check_user_admin_or_owner()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def del_user_actions(deleter, user_id, session=None, **filters): def del_user_actions(user_id, user=None, session=None, **filters):
"""delete user actions.""" """delete user actions."""
return utils.del_db_objects( return utils.del_db_objects(
session, models.UserLog, user_id=user_id, **filters session, models.UserLog, user_id=user_id, **filters
@ -73,7 +73,7 @@ def del_user_actions(deleter, user_id, session=None, **filters):
@user_api.check_user_admin() @user_api.check_user_admin()
@database.run_in_session() @database.run_in_session()
@utils.wrap_to_dict(RESP_FIELDS) @utils.wrap_to_dict(RESP_FIELDS)
def del_actions(deleter, session=None, **filters): def del_actions(user=None, session=None, **filters):
"""delete actions.""" """delete actions."""
return utils.del_db_objects( return utils.del_db_objects(
session, models.UserLog, **filters session, models.UserLog, **filters

View File

@ -85,7 +85,7 @@ class TestProgressCalculator(unittest2.TestCase):
self.cluster_id = None self.cluster_id = None
# get adapter information # get adapter information
list_adapters = adapter.list_adapters(self.user_object) list_adapters = adapter.list_adapters(user=self.user_object)
for adptr in list_adapters: for adptr in list_adapters:
if ('package_installer' in adptr.keys() and if ('package_installer' in adptr.keys() and
adptr['flavors'] != [] and adptr['flavors'] != [] and
@ -102,13 +102,13 @@ class TestProgressCalculator(unittest2.TestCase):
#add cluster #add cluster
cluster.add_cluster( cluster.add_cluster(
self.user_object,
adapter_id=self.adapter_id, adapter_id=self.adapter_id,
os_id=self.os_id, os_id=self.os_id,
flavor_id=self.flavor_id, flavor_id=self.flavor_id,
name='test_cluster' name='test_cluster',
user=self.user_object,
) )
list_clusters = cluster.list_clusters(self.user_object) list_clusters = cluster.list_clusters(user=self.user_object)
for list_cluster in list_clusters: for list_cluster in list_clusters:
if list_cluster['name'] == 'test_cluster': if list_cluster['name'] == 'test_cluster':
self.cluster_id = list_cluster['id'] self.cluster_id = list_cluster['id']
@ -118,51 +118,51 @@ class TestProgressCalculator(unittest2.TestCase):
#add switch #add switch
switch.add_switch( switch.add_switch(
self.user_object, ip=SWITCH_IP,
ip=SWITCH_IP user=self.user_object,
) )
list_switches = switch.list_switches(self.user_object) list_switches = switch.list_switches(user=self.user_object)
for list_switch in list_switches: for list_switch in list_switches:
self.switch_id = list_switch['id'] self.switch_id = list_switch['id']
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
self.switch_id, self.switch_id,
user=self.user_object,
mac=MACHINE_MAC, mac=MACHINE_MAC,
port='1' port='1'
) )
#get machine information #get machine information
list_machines = machine.list_machines(self.user_object) list_machines = machine.list_machines(user=self.user_object)
for list_machine in list_machines: for list_machine in list_machines:
self.machine_id = list_machine['id'] self.machine_id = list_machine['id']
#add cluster host #add cluster host
cluster.add_cluster_host( cluster.add_cluster_host(
self.user_object,
self.cluster_id, self.cluster_id,
user=self.user_object,
machine_id=self.machine_id, machine_id=self.machine_id,
name='test_clusterhost' name='test_clusterhost'
) )
list_clusterhosts = cluster.list_clusterhosts(self.user_object) list_clusterhosts = cluster.list_clusterhosts(user=self.user_object)
for list_clusterhost in list_clusterhosts: for list_clusterhost in list_clusterhosts:
self.host_id = list_clusterhost['host_id'] self.host_id = list_clusterhost['host_id']
self.clusterhost_id = list_clusterhost['clusterhost_id'] self.clusterhost_id = list_clusterhost['clusterhost_id']
#add subnet #add subnet
network.add_subnet( network.add_subnet(
self.user_object, subnet=SUBNET,
subnet=SUBNET user=self.user_object,
) )
list_subnets = network.list_subnets( list_subnets = network.list_subnets(
self.user_object user=self.user_object
) )
for list_subnet in list_subnets: for list_subnet in list_subnets:
self.subnet_id = list_subnet['id'] self.subnet_id = list_subnet['id']
#add host network #add host network
host.add_host_network( host.add_host_network(
self.user_object,
self.host_id, self.host_id,
user=self.user_object,
interface='eth0', interface='eth0',
ip=HOST_IP, ip=HOST_IP,
subnet_id=self.subnet_id, subnet_id=self.subnet_id,
@ -171,32 +171,32 @@ class TestProgressCalculator(unittest2.TestCase):
#get clusterhost #get clusterhost
list_clusterhosts = cluster.list_clusterhosts( list_clusterhosts = cluster.list_clusterhosts(
self.user_object user=self.user_object
) )
for list_clusterhost in list_clusterhosts: for list_clusterhost in list_clusterhosts:
self.clusterhost_id = list_clusterhost['id'] self.clusterhost_id = list_clusterhost['id']
#update host state #update host state
self.list_hosts = host.list_hosts(self.user_object) self.list_hosts = host.list_hosts(user=self.user_object)
for list_host in self.list_hosts: for list_host in self.list_hosts:
self.host_id = list_host['id'] self.host_id = list_host['id']
self.host_state = host.update_host_state( self.host_state = host.update_host_state(
self.user_object,
self.host_id, self.host_id,
user=self.user_object,
state='INSTALLING' state='INSTALLING'
) )
#update cluster state #update cluster state
cluster.update_cluster_state( cluster.update_cluster_state(
self.user_object,
self.cluster_id, self.cluster_id,
user=self.user_object,
state='INSTALLING' state='INSTALLING'
) )
#update clusterhost state #update clusterhost state
cluster.update_clusterhost_state( cluster.update_clusterhost_state(
self.user_object,
self.clusterhost_id, self.clusterhost_id,
user=self.user_object,
state='INSTALLING' state='INSTALLING'
) )
@ -433,8 +433,8 @@ class TestProgressCalculator(unittest2.TestCase):
self._file_generator('check_point_1') self._file_generator('check_point_1')
update_progress.update_progress() update_progress.update_progress()
clusterhost_state = cluster.get_clusterhost_state( clusterhost_state = cluster.get_clusterhost_state(
self.user_object, self.clusterhost_id,
self.clusterhost_id user=self.user_object,
) )
self.assertAlmostEqual( self.assertAlmostEqual(
clusterhost_state['percentage'], clusterhost_state['percentage'],
@ -446,8 +446,8 @@ class TestProgressCalculator(unittest2.TestCase):
self._file_generator('check_point_2') self._file_generator('check_point_2')
update_progress.update_progress() update_progress.update_progress()
clusterhost_state = cluster.get_clusterhost_state( clusterhost_state = cluster.get_clusterhost_state(
self.user_object, self.clusterhost_id,
self.clusterhost_id user=self.user_object,
) )
self.assertAlmostEqual( self.assertAlmostEqual(
clusterhost_state['percentage'], clusterhost_state['percentage'],
@ -459,8 +459,8 @@ class TestProgressCalculator(unittest2.TestCase):
self._file_generator('check_point_3') self._file_generator('check_point_3')
update_progress.update_progress() update_progress.update_progress()
clusterhost_state = cluster.get_clusterhost_state( clusterhost_state = cluster.get_clusterhost_state(
self.user_object, self.clusterhost_id,
self.clusterhost_id user=self.user_object,
) )
self.assertAlmostEqual( self.assertAlmostEqual(
clusterhost_state['percentage'], clusterhost_state['percentage'],
@ -472,8 +472,8 @@ class TestProgressCalculator(unittest2.TestCase):
self._file_generator('check_point_4') self._file_generator('check_point_4')
update_progress.update_progress() update_progress.update_progress()
clusterhost_state = cluster.get_clusterhost_state( clusterhost_state = cluster.get_clusterhost_state(
self.user_object, self.clusterhost_id,
self.clusterhost_id user=self.user_object,
) )
self.assertAlmostEqual( self.assertAlmostEqual(
clusterhost_state['percentage'], clusterhost_state['percentage'],
@ -485,8 +485,8 @@ class TestProgressCalculator(unittest2.TestCase):
self._file_generator('check_point_5') self._file_generator('check_point_5')
update_progress.update_progress() update_progress.update_progress()
clusterhost_state = cluster.get_clusterhost_state( clusterhost_state = cluster.get_clusterhost_state(
self.user_object, self.clusterhost_id,
self.clusterhost_id user=self.user_object,
) )
self.assertEqual( self.assertEqual(
clusterhost_state['percentage'], clusterhost_state['percentage'],

View File

@ -375,9 +375,9 @@ class TestClusterAPI(ApiTestCase):
) )
) )
cluster_api.update_cluster_state( cluster_api.update_cluster_state(
self.user_object,
1, 1,
state='INSTALLING' state='INSTALLING',
user=self.user_object,
) )
url = '/clusters/1' url = '/clusters/1'
return_value = self.delete(url) return_value = self.delete(url)
@ -577,7 +577,6 @@ class TestSwitchAPI(ApiTestCase):
url = '/switches' url = '/switches'
return_value = self.get(url) return_value = self.get(url)
resp = json.loads(return_value.get_data()) resp = json.loads(return_value.get_data())
print 'list switches: %s' % resp
count = len(resp) count = len(resp)
self.assertEqual(count, 2) self.assertEqual(count, 2)
self.assertEqual(return_value.status_code, 200) self.assertEqual(return_value.status_code, 200)

View File

@ -154,7 +154,7 @@ class TestHealthCheckAPI(ApiTestCase):
# Cluster has been deployed successfully. # Cluster has been deployed successfully.
user = models.User.query.filter_by(email='admin@huawei.com').first() user = models.User.query.filter_by(email='admin@huawei.com').first()
cluster_db.update_cluster_state( cluster_db.update_cluster_state(
user, self.cluster_id, state='SUCCESSFUL' self.cluster_id, user=user, state='SUCCESSFUL'
) )
return_value = self.test_client.post(url, data=request_data) return_value = self.test_client.post(url, data=request_data)
self.assertEqual(202, return_value.status_code) self.assertEqual(202, return_value.status_code)

View File

@ -73,7 +73,7 @@ class AdapterTestCase(unittest2.TestCase):
with database.session() as session: with database.session() as session:
adapter_api.add_adapters_internal(session) adapter_api.add_adapters_internal(session)
adapter.load_adapters() adapter.load_adapters()
self.adapter_object = adapter.list_adapters(self.user_object) self.adapter_object = adapter.list_adapters(user=self.user_object)
for adapter_obj in self.adapter_object: for adapter_obj in self.adapter_object:
if adapter_obj['name'] == 'openstack_icehouse': if adapter_obj['name'] == 'openstack_icehouse':
self.adapter_id = adapter_obj['id'] self.adapter_id = adapter_obj['id']
@ -97,7 +97,7 @@ class TestListAdapters(AdapterTestCase):
def test_list_adapters(self): def test_list_adapters(self):
adapters = adapter.list_adapters( adapters = adapter.list_adapters(
self.user_object user=self.user_object
) )
result = [] result = []
for item in adapters: for item in adapters:
@ -124,8 +124,8 @@ class TestGetAdapter(AdapterTestCase):
def test_get_adapter(self): def test_get_adapter(self):
get_adapter = adapter.get_adapter( get_adapter = adapter.get_adapter(
self.user_object, self.adapter_id,
self.adapter_id user=self.user_object,
) )
name = None name = None
for k, v in get_adapter.items(): for k, v in get_adapter.items():

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -45,16 +45,17 @@ class TestGetMachine(BaseTest):
def test_get_machine(self): def test_get_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
get_machine = machine.get_machine( get_machine = machine.get_machine(
self.user_object, 1,
1 user=self.user_object,
) )
self.assertIsNotNone(get_machine) self.assertIsNotNone(get_machine)
self.assertEqual(get_machine['mac'], '28:6e:d4:46:c4:25')
class TestListMachines(BaseTest): class TestListMachines(BaseTest):
@ -68,13 +69,14 @@ class TestListMachines(BaseTest):
def test_list_machines(self): def test_list_machines(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_machine = machine.list_machines(self.user_object) list_machine = machine.list_machines(self.user_object)
self.assertIsNotNone(list_machine) self.assertIsNotNone(list_machine)
self.assertEqual(list_machine[0]['mac'], '28:6e:d4:46:c4:25')
class TestUpdateMachine(BaseTest): class TestUpdateMachine(BaseTest):
@ -88,20 +90,26 @@ class TestUpdateMachine(BaseTest):
def test_update_machine(self): def test_update_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
machine.update_machine( machine.update_machine(
self.user_object,
1, 1,
tag='test' tag='test',
user=self.user_object,
) )
update_machine = machine.list_machines(self.user_object) update_machine = machine.list_machines(self.user_object)
expected = {'tag': 'test'} expected = {
'id': 1,
'mac': '28:6e:d4:46:c4:25',
'tag': 'test',
'switch_ip': '0.0.0.0',
'port': '1'
}
self.assertTrue( self.assertTrue(
item in update_machine[0].items() for item in expected.items() all(item in update_machine[0].items() for item in expected.items())
) )
@ -116,20 +124,20 @@ class TestPatchMachine(BaseTest):
def test_patch_machine(self): def test_patch_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
machine.patch_machine( machine.patch_machine(
self.user_object,
1, 1,
user=self.user_object,
tag={'patched_tag': 'test'} tag={'patched_tag': 'test'}
) )
patch_machine = machine.list_machines(self.user_object) patch_machine = machine.list_machines(self.user_object)
expected = {'patched_tag': 'test'} expected = {'tag': {'patched_tag': 'test'}}
self.assertTrue( self.assertTrue(
item in patch_machine[0].items() for item in expected.items() all(item in patch_machine[0].items() for item in expected.items())
) )
@ -144,14 +152,14 @@ class TestDelMachine(BaseTest):
def test_del_machine(self): def test_del_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
machine.del_machine( machine.del_machine(
self.user_object, 1,
1 user=self.user_object,
) )
del_machine = machine.list_machines(self.user_object) del_machine = machine.list_machines(self.user_object)
self.assertEqual([], del_machine) self.assertEqual([], del_machine)

View File

@ -46,15 +46,19 @@ class TestListSubnets(BaseTest):
def test_list_subnets(self): def test_list_subnets(self):
network.add_subnet( network.add_subnet(
self.user_object, subnet='10.145.89.0/24',
subnet='10.145.89.0/24' user=self.user_object,
) )
list_subnet = network.list_subnets( list_subnet = network.list_subnets(
self.user_object user=self.user_object
) )
expected = '10.145.89.0/24' expected = {
'subnet': '10.145.89.0/24',
'id': 1,
'name': '10.145.89.0/24'
}
self.assertTrue( self.assertTrue(
item in list_subnet[0].items() for item in expected all(item in list_subnet[0].items() for item in expected.items())
) )
@ -69,12 +73,12 @@ class TestGetSubnet(BaseTest):
def test_get_subnet(self): def test_get_subnet(self):
network.add_subnet( network.add_subnet(
self.user_object, subnet='10.145.89.0/24',
subnet='10.145.89.0/24' user=self.user_object,
) )
get_subnet = network.get_subnet( get_subnet = network.get_subnet(
self.user_object, 1,
1 user=self.user_object,
) )
self.assertEqual( self.assertEqual(
'10.145.89.0/24', '10.145.89.0/24',
@ -83,8 +87,8 @@ class TestGetSubnet(BaseTest):
def tset_get_subnet_no_exist(self): def tset_get_subnet_no_exist(self):
get_subnet_no_exist = network.get_subnet( get_subnet_no_exist = network.get_subnet(
self.user_object, 2,
2 user=self.user_object,
) )
self.assertEqual([], get_subnet_no_exist) self.assertEqual([], get_subnet_no_exist)
@ -100,11 +104,11 @@ class TestAddSubnet(BaseTest):
def test_add_subnet(self): def test_add_subnet(self):
network.add_subnet( network.add_subnet(
self.user_object, subnet='10.145.89.0/24',
subnet='10.145.89.0/24' user=self.user_object,
) )
add_subnets = network.list_subnets( add_subnets = network.list_subnets(
self.user_object user=self.user_object
) )
expected = '10.145.89.0/24' expected = '10.145.89.0/24'
for add_subnet in add_subnets: for add_subnet in add_subnets:
@ -112,12 +116,12 @@ class TestAddSubnet(BaseTest):
def test_add_subnet_position(self): def test_add_subnet_position(self):
network.add_subnet( network.add_subnet(
self.user_object,
True, True,
'10.145.89.0/23' '10.145.89.0/23',
user=self.user_object,
) )
add_subnets = network.list_subnets( add_subnets = network.list_subnets(
self.user_object user=self.user_object
) )
expected = '10.145.89.0/23' expected = '10.145.89.0/23'
for add_subnet in add_subnets: for add_subnet in add_subnets:
@ -128,30 +132,16 @@ class TestAddSubnet(BaseTest):
network.add_subnet( network.add_subnet(
self.user_object, self.user_object,
subnet='10.145.89.0/22', subnet='10.145.89.0/22',
user=self.user_object,
session=session session=session
) )
add_subnets = network.list_subnets( add_subnets = network.list_subnets(
self.user_object user=self.user_object
) )
expected = '10.145.89.0/22' expected = '10.145.89.0/22'
for add_subnet in add_subnets: for add_subnet in add_subnets:
self.assertEqual(expected, add_subnet['subnet']) self.assertEqual(expected, add_subnet['subnet'])
def test_add_subnet_position_session(self):
with database.session() as session:
network.add_subnet(
self.user_object,
True,
'10.145.89.0/21',
session
)
add_subnets = network.list_subnets(
self.user_object
)
expected = '10.145.89.0/21'
for add_subnet in add_subnets:
self.assertEqual(expected, add_subnet['subnet'])
class TestUpdateSubnet(BaseTest): class TestUpdateSubnet(BaseTest):
"""Test update subnet.""" """Test update subnet."""
@ -164,28 +154,32 @@ class TestUpdateSubnet(BaseTest):
def test_update_subnet(self): def test_update_subnet(self):
network.add_subnet( network.add_subnet(
self.user_object, subnet='10.145.89.0/24',
subnet='10.145.89.0/24' user=self.user_object,
) )
network.update_subnet( network.update_subnet(
self.user_object,
1, 1,
user=self.user_object,
subnet='192.168.100.0/24' subnet='192.168.100.0/24'
) )
update_subnet = network.list_subnets( update_subnet = network.list_subnets(
self.user_object user=self.user_object
) )
expected = '192.168.100.0/24' expected = {
'subnet': '192.168.100.0/24',
'id': 1,
'name': '192.168.100.0/24'
}
self.assertTrue( self.assertTrue(
item in update_subnet[0].items() for item in expected all(item in update_subnet[0].items() for item in expected.items())
) )
def test_update_subnet_no_exist(self): def test_update_subnet_no_exist(self):
self.assertRaises( self.assertRaises(
exception.DatabaseException, exception.DatabaseException,
network.update_subnet, network.update_subnet,
self.user_object, 2,
2 user=self.user_object,
) )
@ -200,15 +194,15 @@ class TestDelSubnet(BaseTest):
def test_del_subnet(self): def test_del_subnet(self):
network.add_subnet( network.add_subnet(
self.user_object, user=self.user_object,
subnet='10.145.89.0/24' subnet='10.145.89.0/24'
) )
network.del_subnet( network.del_subnet(
self.user_object, 1,
1 user=self.user_object,
) )
del_subnet = network.list_subnets( del_subnet = network.list_subnets(
self.user_object user=self.user_object
) )
self.assertEqual([], del_subnet) self.assertEqual([], del_subnet)
@ -216,8 +210,8 @@ class TestDelSubnet(BaseTest):
self.assertRaises( self.assertRaises(
exception.RecordNotExists, exception.RecordNotExists,
network.del_subnet, network.del_subnet,
self.user_object, 2,
2 user=self.user_object,
) )

View File

@ -44,7 +44,7 @@ class TestListPermissions(BaseTest):
super(TestListPermissions, self).tearDown() super(TestListPermissions, self).tearDown()
def test_list_permissions(self): def test_list_permissions(self):
permissions = permission.list_permissions(self.user_object) permissions = permission.list_permissions(user=self.user_object)
self.assertIsNotNone(permissions) self.assertIsNotNone(permissions)
self.assertEqual(54, len(permissions)) self.assertEqual(54, len(permissions))
@ -59,7 +59,9 @@ class TestGetPermission(BaseTest):
super(TestGetPermission, self).tearDown() super(TestGetPermission, self).tearDown()
def test_get_permission(self): def test_get_permission(self):
get_permission = permission.get_permission(self.user_object, 1) get_permission = permission.get_permission(
1,
user=self.user_object)
self.assertIsNotNone(get_permission) self.assertIsNotNone(get_permission)
expected = { expected = {
'alias': 'list permissions', 'alias': 'list permissions',

View File

@ -46,10 +46,11 @@ class TestGetSwitch(BaseTest):
def test_get_switch(self): def test_get_switch(self):
get_switch = switch.get_switch( get_switch = switch.get_switch(
self.user_object, 1,
1 user=self.user_object,
) )
self.assertIsNotNone(get_switch) self.assertIsNotNone(get_switch)
self.assertEqual(get_switch['ip'], '0.0.0.0')
class TestAddSwitch(BaseTest): class TestAddSwitch(BaseTest):
@ -63,43 +64,31 @@ class TestAddSwitch(BaseTest):
def test_add_switch(self): def test_add_switch(self):
add_switch = switch.add_switch( add_switch = switch.add_switch(
self.user_object,
ip='2887583784', ip='2887583784',
user=self.user_object,
) )
expected = '172.29.8.40' expected = '172.29.8.40'
self.assertEqual(expected, add_switch['ip']) self.assertEqual(expected, add_switch['ip'])
def test_add_switch_position_args(self): def test_add_switch_position_args(self):
add_switch = switch.add_switch( add_switch = switch.add_switch(
self.user_object,
True, True,
'2887583784', '2887583784',
user=self.user_object,
) )
print add_switch
expected = '172.29.8.40' expected = '172.29.8.40'
self.assertEqual(expected, add_switch['ip']) self.assertEqual(expected, add_switch['ip'])
def test_add_switch_session(self): def test_add_switch_session(self):
with database.session() as session: with database.session() as session:
add_switch = switch.add_switch( add_switch = switch.add_switch(
self.user_object,
ip='2887583784', ip='2887583784',
user=self.user_object,
session=session session=session
) )
expected = '172.29.8.40' expected = '172.29.8.40'
self.assertEqual(expected, add_switch['ip']) self.assertEqual(expected, add_switch['ip'])
def test_add_switch_position_args_session(self):
with database.session() as session:
add_switch = switch.add_switch(
self.user_object,
True,
'2887583784',
session
)
expected = '172.29.8.40'
self.assertEqual(expected, add_switch['ip'])
class TestListSwitches(BaseTest): class TestListSwitches(BaseTest):
"""Test list switch.""" """Test list switch."""
@ -112,38 +101,39 @@ class TestListSwitches(BaseTest):
def test_list_switches_ip_int_invalid(self): def test_list_switches_ip_int_invalid(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
list_switches = switch.list_switches( list_switches = switch.list_switches(
self.user_object, ip_int='test',
ip_int='test' user=self.user_object,
) )
self.assertEqual(list_switches, []) self.assertEqual(list_switches, [])
def test_list_switches_with_ip_int(self): def test_list_switches_with_ip_int(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
list_switches = switch.list_switches( list_switches = switch.list_switches(
self.user_object, ip_int='2887583784',
ip_int='2887583784' user=self.user_object,
)
expected = '2887583784'
self.assertTrue(
item in expected.items() for item in list_switches[0].items()
) )
expected = '172.29.8.40'
self.assertIsNotNone(list_switches)
self.assertEqual(expected, list_switches[0]['ip'])
def test_list_switches(self): def test_list_switches(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
list_switches = switch.list_switches( list_switches = switch.list_switches(
self.user_object user=self.user_object
) )
expected = '172.29.8.40'
self.assertIsNotNone(list_switches) self.assertIsNotNone(list_switches)
self.assertEqual(expected, list_switches[0]['ip'])
class TestDelSwitch(BaseTest): class TestDelSwitch(BaseTest):
@ -157,11 +147,11 @@ class TestDelSwitch(BaseTest):
def test_del_switch(self): def test_del_switch(self):
switch.del_switch( switch.del_switch(
self.user_object, 1,
1 user=self.user_object,
) )
del_switch = switch.list_switches( del_switch = switch.list_switches(
self.user_object user=self.user_object
) )
self.assertEqual([], del_switch) self.assertEqual([], del_switch)
@ -177,13 +167,13 @@ class TestUpdateSwitch(BaseTest):
def test_update_switch(self): def test_update_switch(self):
switch.update_switch( switch.update_switch(
self.user_object,
1, 1,
user=self.user_object,
vendor='test_update' vendor='test_update'
) )
update_switch = switch.get_switch( update_switch = switch.get_switch(
self.user_object, 1,
1 user=self.user_object,
) )
expected = 'test_update' expected = 'test_update'
self.assertEqual(expected, update_switch['vendor']) self.assertEqual(expected, update_switch['vendor'])
@ -200,23 +190,24 @@ class TestPatchSwitch(BaseTest):
def test_patch_switch(self): def test_patch_switch(self):
switch.patch_switch( switch.patch_switch(
self.user_object,
1, 1,
user=self.user_object,
patched_credentials={ patched_credentials={
'version': '2c', 'version': '2c',
'community': 'public' 'community': 'public'
} }
) )
patch_switch = switch.get_switch( patch_switch = switch.get_switch(
self.user_object, 1,
1 user=self.user_object,
) )
expected = { expected = {
'credentials': {
'version': '2c', 'version': '2c',
'community': 'public' 'community': 'public'}
} }
self.assertTrue( self.assertTrue(
item in expected.items() for item in patch_switch.items() all(item in patch_switch.items() for item in expected.items())
) )
@ -231,9 +222,17 @@ class TestListSwitchFilters(BaseTest):
def test_list_switch_filters(self): def test_list_switch_filters(self):
list_switch_filters = switch.list_switch_filters( list_switch_filters = switch.list_switch_filters(
self.user_object user=self.user_object
) )
expected = {
'ip': '0.0.0.0',
'id': 1,
'filters': 'allow ports all',
}
self.assertIsNotNone(list_switch_filters) self.assertIsNotNone(list_switch_filters)
self.assertTrue(
all(item in list_switch_filters[0].items()
for item in expected.items()))
class TestGetSwitchFilters(BaseTest): class TestGetSwitchFilters(BaseTest):
@ -247,10 +246,18 @@ class TestGetSwitchFilters(BaseTest):
def test_get_swtich_filters(self): def test_get_swtich_filters(self):
get_switch_filter = switch.get_switch_filters( get_switch_filter = switch.get_switch_filters(
self.user_object, 1,
1 user=self.user_object,
) )
expected = {
'ip': '0.0.0.0',
'id': 1,
'filters': 'allow ports all',
}
self.assertIsNotNone(get_switch_filter) self.assertIsNotNone(get_switch_filter)
self.assertTrue(
all(item in get_switch_filter.items()
for item in expected.items()))
class TestUpdateSwitchFilters(BaseTest): class TestUpdateSwitchFilters(BaseTest):
@ -264,8 +271,8 @@ class TestUpdateSwitchFilters(BaseTest):
def test_update_switch_filters(self): def test_update_switch_filters(self):
switch.update_switch_filters( switch.update_switch_filters(
self.user_object,
1, 1,
user=self.user_object,
filters=[ filters=[
{ {
'filter_type': 'allow' 'filter_type': 'allow'
@ -273,15 +280,15 @@ class TestUpdateSwitchFilters(BaseTest):
] ]
) )
update_switch_filters = switch.get_switch_filters( update_switch_filters = switch.get_switch_filters(
self.user_object, 1,
1 user=self.user_object,
) )
expected = { expected = {
'filter_type': 'allow' 'filters': 'allow'
} }
self.assertTrue( self.assertTrue(
item in update_switch_filters[0].items() all(item in update_switch_filters.items()
for item in expected.items() for item in expected.items())
) )
@ -295,9 +302,13 @@ class TestPatchSwitchFilter(BaseTest):
super(TestPatchSwitchFilter, self).tearDown() super(TestPatchSwitchFilter, self).tearDown()
def test_patch_switch_filter(self): def test_patch_switch_filter(self):
switch.add_switch(
ip='2887583784',
user=self.user_object,
)
switch.patch_switch_filter( switch.patch_switch_filter(
self.user_object, 2,
1, user=self.user_object,
patched_filters=[ patched_filters=[
{ {
'filter_type': 'allow' 'filter_type': 'allow'
@ -305,14 +316,15 @@ class TestPatchSwitchFilter(BaseTest):
] ]
) )
patch_switch_filter = switch.get_switch_filters( patch_switch_filter = switch.get_switch_filters(
self.user_object, 2,
1 user=self.user_object,
) )
expected = { expected = {
'filter_type': 'allow' 'filters': 'allow'
} }
self.assertTrue( self.assertTrue(
item in patch_switch_filter[0].items() for item in expected.items() all(item in patch_switch_filter.items()
for item in expected.items())
) )
@ -327,21 +339,21 @@ class TestAddSwitchMachine(BaseTest):
def test_add_switch_machine(self): def test_add_switch_machine(self):
add_switch_machine = switch.add_switch_machine( add_switch_machine = switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
expected = '28:6e:d4:46:c4:25' expected = '28:6e:d4:46:c4:25'
self.assertEqual(expected, add_switch_machine['mac']) self.assertEqual(expected, add_switch_machine['mac'])
def test_add_switch_machine_position_args(self): def test_add_switch_machine_position_args(self):
add_switch_machine = switch.add_switch_machine( add_switch_machine = switch.add_switch_machine(
self.user_object,
1, 1,
True, True,
'28:6e:d4:46:c4:25', '28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
expected = '28:6e:d4:46:c4:25' expected = '28:6e:d4:46:c4:25'
self.assertEqual(expected, add_switch_machine['mac']) self.assertEqual(expected, add_switch_machine['mac'])
@ -349,28 +361,15 @@ class TestAddSwitchMachine(BaseTest):
def test_add_switch_machine_session(self): def test_add_switch_machine_session(self):
with database.session() as session: with database.session() as session:
add_switch_machine = switch.add_switch_machine( add_switch_machine = switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
user=self.user_object,
session=session, session=session,
port='1' port='1'
) )
expected = '28:6e:d4:46:c4:25' expected = '28:6e:d4:46:c4:25'
self.assertEqual(expected, add_switch_machine['mac']) self.assertEqual(expected, add_switch_machine['mac'])
def test_add_switch_machine_position_args_session(self):
with database.session() as session:
add_switch_machine = switch.add_switch_machine(
self.user_object,
1,
True,
'28:6e:d4:46:c4:25',
session,
port='1'
)
expected = '28:6e:d4:46:c4:25'
self.assertEqual(expected, add_switch_machine['mac'])
class TestListSwitchMachines(BaseTest): class TestListSwitchMachines(BaseTest):
"""Test get switch machines.""" """Test get switch machines."""
@ -383,20 +382,32 @@ class TestListSwitchMachines(BaseTest):
def test_list_switch_machines(self): def test_list_switch_machines(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_switch_machines = switch.list_switch_machines( list_switch_machines = switch.list_switch_machines(
self.user_object, 2,
2 user=self.user_object,
) )
expected = {
'switch_id': 2,
'id': 1,
'mac': '28:6e:d4:46:c4:25',
'switch_ip': '172.29.8.40',
'machine_id': 1,
'port': '1',
'switch_machine_id': 1
}
self.assertIsNotNone(list_switch_machines) self.assertIsNotNone(list_switch_machines)
self.assertTrue(
all(item in list_switch_machines[0].items()
for item in expected.items()))
class TestListSwitchmachines(BaseTest): class TestListSwitchmachines(BaseTest):
@ -410,54 +421,59 @@ class TestListSwitchmachines(BaseTest):
def test_list_switch_machines_with_ip_int(self): def test_list_switch_machines_with_ip_int(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_switch_machines = switch.list_switchmachines( list_switch_machines = switch.list_switchmachines(
self.user_object, switch_ip_int='2887583784',
switch_ip_int='2887583784' user=self.user_object,
) )
expected = '172.29.8.40' expected = {'switch_ip': '172.29.8.40'}
self.assertTrue(expected for item in list_switch_machines[0].items()) self.assertTrue(
all(item in list_switch_machines[0].items()
for item in expected.items()))
def test_list_switch_machines_ip_invalid(self): def test_list_switch_machines_ip_invalid(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_switch_machines = switch.list_switchmachines( list_switch_machines = switch.list_switchmachines(
self.user_object, switch_ip_int='test',
switch_ip_int='test' user=self.user_object,
) )
self.assertEqual(list_switch_machines, []) self.assertEqual(list_switch_machines, [])
def test_list_switch_machines_without_ip(self): def test_list_switch_machines_without_ip(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_switch_machines = switch.list_switchmachines( list_switch_machines = switch.list_switchmachines(
self.user_object user=self.user_object
) )
self.assertIsNotNone(list_switch_machines) expected = {'switch_ip': '172.29.8.40'}
self.assertTrue(
all(item in list_switch_machines[0].items()
for item in expected.items()))
class TestListSwitchMachinesHosts(BaseTest): class TestListSwitchMachinesHosts(BaseTest):
@ -471,20 +487,31 @@ class TestListSwitchMachinesHosts(BaseTest):
def test_list_hosts(self): def test_list_hosts(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_hosts = switch.list_switch_machines_hosts( list_hosts = switch.list_switch_machines_hosts(
self.user_object, 2,
2 user=self.user_object,
) )
self.assertIsNotNone(list_hosts) expected = {
'switch_id': 2,
'id': 1,
'mac': '28:6e:d4:46:c4:25',
'switch_ip': '172.29.8.40',
'machine_id': 1,
'port': '1',
'switch_machine_id': 1
}
self.assertTrue(
all(item in list_hosts[0].items()
for item in expected.items()))
class TestListSwitchmachinesHosts(BaseTest): class TestListSwitchmachinesHosts(BaseTest):
@ -498,53 +525,59 @@ class TestListSwitchmachinesHosts(BaseTest):
def test_list_hosts_with_ip_int(self): def test_list_hosts_with_ip_int(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_hosts = switch.list_switchmachines_hosts( list_hosts = switch.list_switchmachines_hosts(
self.user_object, switch_ip_int='2887583784',
switch_ip_int='2887583784' user=self.user_object,
) )
expected = '172.29.8.40' expected = {'switch_ip': '172.29.8.40'}
self.assertTrue(expected for item in list_hosts[0].items()) self.assertTrue(
all(item in list_hosts[0].items()
for item in expected.items()))
def test_list_hosts_ip_invalid(self): def test_list_hosts_ip_invalid(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_hosts = switch.list_switchmachines_hosts( list_hosts = switch.list_switchmachines_hosts(
self.user_object, switch_ip_int='test',
switch_ip_int='test' user=self.user_object,
) )
self.assertEqual(list_hosts, []) self.assertEqual(list_hosts, [])
def test_list_hosts_without_ip(self): def test_list_hosts_without_ip(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
list_hosts = switch.list_switchmachines_hosts( list_hosts = switch.list_switchmachines_hosts(
self.user_object user=self.user_object
) )
expected = {'switch_ip': '172.29.8.40'}
self.assertTrue(
all(item in list_hosts[0].items()
for item in expected.items()))
self.assertIsNotNone(list_hosts) self.assertIsNotNone(list_hosts)
@ -559,21 +592,22 @@ class TestGetSwitchMachine(BaseTest):
def test_get_switch_machine(self): def test_get_switch_machine(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
get_switch_machine = switch.get_switch_machine( get_switch_machine = switch.get_switch_machine(
self.user_object,
2, 2,
1 1,
user=self.user_object,
) )
self.assertIsNotNone(get_switch_machine) self.assertIsNotNone(get_switch_machine)
self.assertEqual(get_switch_machine['mac'], '28:6e:d4:46:c4:25')
class TestGetSwitchmachine(BaseTest): class TestGetSwitchmachine(BaseTest):
@ -587,16 +621,17 @@ class TestGetSwitchmachine(BaseTest):
def test_get_switchmachine(self): def test_get_switchmachine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
get_switchmachine = switch.get_switchmachine( get_switchmachine = switch.get_switchmachine(
self.user_object, 1,
1 user=self.user_object,
) )
self.assertIsNotNone(get_switchmachine) self.assertIsNotNone(get_switchmachine)
self.assertEqual(get_switchmachine['mac'], '28:6e:d4:46:c4:25')
class TestUpdateSwitchMachine(BaseTest): class TestUpdateSwitchMachine(BaseTest):
@ -610,24 +645,34 @@ class TestUpdateSwitchMachine(BaseTest):
def test_update_switch_machine(self): def test_update_switch_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.update_switch_machine( switch.update_switch_machine(
self.user_object,
1, 1,
1, 1,
tag='test_tag' tag='test_tag',
user=self.user_object,
) )
update_switch_machine = switch.list_switch_machines( update_switch_machine = switch.list_switch_machines(
self.user_object, 1,
1 user=self.user_object,
) )
expected = {'tag': 'test_tag'} expected = {
'switch_id': 1,
'id': 1,
'mac': '28:6e:d4:46:c4:25',
'tag': 'test_tag',
'switch_ip': '0.0.0.0',
'machine_id': 1,
'port': '1',
'switch_machine_id': 1
}
self.assertTrue( self.assertTrue(
item in update_switch_machine[0].items for item in expected.items() all(item in update_switch_machine[0].items()
for item in expected.items())
) )
@ -642,23 +687,32 @@ class TestUpdateSwitchmachine(BaseTest):
def test_update_switchmachine(self): def test_update_switchmachine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.update_switchmachine( switch.update_switchmachine(
self.user_object,
1, 1,
location='test_location' location='test_location',
user=self.user_object,
) )
update_switchmachine = switch.list_switchmachines( update_switchmachine = switch.list_switchmachines(
self.user_object, user=self.user_object,
) )
expected = {'location': 'test_location'} expected = {
'switch_id': 1,
'id': 1,
'mac': '28:6e:d4:46:c4:25',
'location': 'test_location',
'switch_ip': '0.0.0.0',
'machine_id': 1,
'port': '1',
'switch_machine_id': 1
}
self.assertTrue( self.assertTrue(
item in update_switchmachine[0].items() all(item in update_switchmachine[0].items()
for item in expected.items() for item in expected.items())
) )
@ -673,27 +727,29 @@ class TestPatchSwitchMachine(BaseTest):
def test_pathc_switch_machine(self): def test_pathc_switch_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.patch_switch_machine( switch.patch_switch_machine(
self.user_object,
1, 1,
1, 1,
user=self.user_object,
patched_tag={ patched_tag={
'patched_tag': 'test_patched_tag' 'patched_tag': 'test_patched_tag'
} }
) )
switch_patch_switch_machine = switch.list_switch_machines( switch_patch_switch_machine = switch.list_switch_machines(
self.user_object, 1,
1 user=self.user_object,
) )
expected = {'patched_tag': 'test_patched_tag'} expected = {'tag': {
'patched_tag': 'test_patched_tag'}
}
self.assertTrue( self.assertTrue(
item in switch_patch_switch_machine[0].items() all(item in switch_patch_switch_machine[0].items()
for item in expected.items() for item in expected.items())
) )
@ -708,24 +764,27 @@ class TestPatchSwitchmachine(BaseTest):
def test_patch_switchmachine(self): def test_patch_switchmachine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.patch_switchmachine( switch.patch_switchmachine(
self.user_object,
1, 1,
user=self.user_object,
patched_location={ patched_location={
'patched_location': 'test_location' 'patched_location': 'test_location'
} }
) )
patch_switchmachine = switch.list_switchmachines( patch_switchmachine = switch.list_switchmachines(
self.user_object user=self.user_object
) )
expected = {'patched_location': 'test_location'} expected = {'location': {
'patched_location': 'test_location'}
}
self.assertTrue( self.assertTrue(
item in patch_switchmachine[0].items() for item in expected.items() all(item in patch_switchmachine[0].items()
for item in expected.items())
) )
@ -740,19 +799,19 @@ class TestDelSwitchMachine(BaseTest):
def test_del_switch_machine(self): def test_del_switch_machine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.del_switch_machine( switch.del_switch_machine(
self.user_object,
1, 1,
1 1,
user=self.user_object,
) )
del_switch_machine = switch.list_switch_machines( del_switch_machine = switch.list_switch_machines(
self.user_object, 1,
1 user=self.user_object,
) )
self.assertEqual([], del_switch_machine) self.assertEqual([], del_switch_machine)
@ -768,17 +827,17 @@ class TestDelSwitchmachine(BaseTest):
def test_switchmachine(self): def test_switchmachine(self):
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
1, 1,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.del_switchmachine( switch.del_switchmachine(
self.user_object, 1,
1 user=self.user_object,
) )
del_switchmachine = switch.list_switchmachines( del_switchmachine = switch.list_switchmachines(
self.user_object user=self.user_object
) )
self.assertEqual([], del_switchmachine) self.assertEqual([], del_switchmachine)
@ -794,23 +853,23 @@ class TestUpdateSwitchMachines(BaseTest):
def test_update_switch_machines_remove(self): def test_update_switch_machines_remove(self):
switch.add_switch( switch.add_switch(
self.user_object, ip='2887583784',
ip='2887583784' user=self.user_object,
) )
switch.add_switch_machine( switch.add_switch_machine(
self.user_object,
2, 2,
mac='28:6e:d4:46:c4:25', mac='28:6e:d4:46:c4:25',
port='1' port='1',
user=self.user_object,
) )
switch.update_switch_machines( switch.update_switch_machines(
self.user_object,
2, 2,
remove_machines=1 remove_machines=1,
user=self.user_object,
) )
update_remove = switch.list_switch_machines( update_remove = switch.list_switch_machines(
self.user_object, 2,
2 user=self.user_object,
) )
self.assertEqual([], update_remove) self.assertEqual([], update_remove)

View File

@ -75,22 +75,25 @@ class TestGetRecordCleanToken(BaseTest):
def test_record_user_token(self): def test_record_user_token(self):
token = user_api.record_user_token( token = user_api.record_user_token(
self.user_object,
'test_token', 'test_token',
datetime.datetime.now() + datetime.timedelta(seconds=10000) datetime.datetime.now() + datetime.timedelta(seconds=10000),
user=self.user_object,
) )
self.assertIsNotNone(token) self.assertIsNotNone(token)
self.assertEqual(token['token'], 'test_token') self.assertEqual(token['token'], 'test_token')
def test_clean_user_token(self): def test_clean_user_token(self):
token = user_api.clean_user_token(self.user_object, 'test_token') token = user_api.clean_user_token(
'test_token',
user=self.user_object,
)
self.assertEqual([], token) self.assertEqual([], token)
def test_get_user_object_from_token(self): def test_get_user_object_from_token(self):
token = user_api.record_user_token( token = user_api.record_user_token(
self.user_object,
'test_token', 'test_token',
datetime.datetime.now() + datetime.timedelta(seconds=10000) datetime.datetime.now() + datetime.timedelta(seconds=10000),
user=self.user_object,
) )
self.assertIsNotNone(token) self.assertIsNotNone(token)
@ -112,9 +115,12 @@ class TestGetUser(BaseTest):
super(TestGetUser, self).tearDown() super(TestGetUser, self).tearDown()
def test_get_user(self): def test_get_user(self):
user = user_api.get_user(self.user_object, self.user_object.id) get_user = user_api.get_user(
self.assertIsNotNone(user) self.user_object.id,
self.assertEqual(user['email'], setting.COMPASS_ADMIN_EMAIL) user=self.user_object
)
self.assertIsNotNone(get_user)
self.assertEqual(get_user['email'], setting.COMPASS_ADMIN_EMAIL)
class TestGetCurrentUser(BaseTest): class TestGetCurrentUser(BaseTest):
@ -128,7 +134,7 @@ class TestGetCurrentUser(BaseTest):
def test_get_current_user(self): def test_get_current_user(self):
current_user = user_api.get_current_user( current_user = user_api.get_current_user(
self.user_object user=self.user_object
) )
self.assertIsNotNone(current_user) self.assertIsNotNone(current_user)
self.assertEqual(current_user['email'], setting.COMPASS_ADMIN_EMAIL) self.assertEqual(current_user['email'], setting.COMPASS_ADMIN_EMAIL)
@ -140,7 +146,7 @@ class TestListUsers(BaseTest):
def setUp(self): def setUp(self):
super(TestListUsers, self).setUp() super(TestListUsers, self).setUp()
user_api.add_user( user_api.add_user(
self.user_object, user=self.user_object,
email='test@huawei.com', email='test@huawei.com',
password='test' password='test'
) )
@ -149,11 +155,13 @@ class TestListUsers(BaseTest):
super(TestListUsers, self).tearDown() super(TestListUsers, self).tearDown()
def test_list_users(self): def test_list_users(self):
user = user_api.list_users(self.user_object) list_users = user_api.list_users(
self.assertIsNotNone(user) user=self.user_object
)
self.assertIsNotNone(list_users)
result = [] result = []
for item in user: for list_user in list_users:
result.append(item['email']) result.append(list_user['email'])
expects = ['test@huawei.com', setting.COMPASS_ADMIN_EMAIL] expects = ['test@huawei.com', setting.COMPASS_ADMIN_EMAIL]
for expect in expects: for expect in expects:
self.assertIn(expect, result) self.assertIn(expect, result)
@ -170,18 +178,18 @@ class TestAddUser(BaseTest):
def test_add_user(self): def test_add_user(self):
user_objs = user_api.add_user( user_objs = user_api.add_user(
self.user_object,
email='test@abc.com', email='test@abc.com',
password='password' password='password',
user=self.user_object,
) )
self.assertEqual('test@abc.com', user_objs['email']) self.assertEqual('test@abc.com', user_objs['email'])
def test_add_user_session(self): def test_add_user_session(self):
with database.session() as session: with database.session() as session:
user_objs = user_api.add_user( user_objs = user_api.add_user(
self.user_object,
email='test@abc.com', email='test@abc.com',
password='password', password='password',
user=self.user_object,
session=session session=session
) )
self.assertEqual('test@abc.com', user_objs['email']) self.assertEqual('test@abc.com', user_objs['email'])
@ -197,8 +205,11 @@ class TestDelUser(BaseTest):
super(TestDelUser, self).tearDown() super(TestDelUser, self).tearDown()
def test_del_user(self): def test_del_user(self):
user_api.del_user(self.user_object, self.user_object.id) user_api.del_user(
del_user = user_api.list_users(self.user_object) self.user_object.id,
user=self.user_object,
)
del_user = user_api.list_users(user=self.user_object)
self.assertEqual([], del_user) self.assertEqual([], del_user)
@ -213,8 +224,8 @@ class TestUpdateUser(BaseTest):
def test_update_admin(self): def test_update_admin(self):
user_objs = user_api.update_user( user_objs = user_api.update_user(
self.user_object,
self.user_object.id, self.user_object.id,
user=self.user_object,
email=setting.COMPASS_ADMIN_EMAIL, email=setting.COMPASS_ADMIN_EMAIL,
firstname='a', firstname='a',
lastname='b', lastname='b',
@ -228,7 +239,7 @@ class TestUpdateUser(BaseTest):
def test_not_admin(self): def test_not_admin(self):
user_api.add_user( user_api.add_user(
self.user_object, user=self.user_object,
email='dummy@abc.com', email='dummy@abc.com',
password='dummy', password='dummy',
is_admin=False is_admin=False
@ -237,8 +248,8 @@ class TestUpdateUser(BaseTest):
self.assertRaises( self.assertRaises(
exception.Forbidden, exception.Forbidden,
user_api.update_user, user_api.update_user,
user_object,
2, 2,
user=user_object,
is_admin=False is_admin=False
) )
@ -254,8 +265,8 @@ class TestGetPermissions(BaseTest):
def test_get_permissions(self): def test_get_permissions(self):
user_permissions = user_api.get_permissions( user_permissions = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
) )
self.assertIsNotNone(user_permissions) self.assertIsNotNone(user_permissions)
result = [] result = []
@ -275,9 +286,9 @@ class TestGetPermission(BaseTest):
def test_get_permission(self): def test_get_permission(self):
user_permission = user_api.get_permission( user_permission = user_api.get_permission(
self.user_object,
self.user_object.id, self.user_object.id,
1, 1,
user=self.user_object,
) )
self.assertEqual(user_permission['name'], 'list_permissions') self.assertEqual(user_permission['name'], 'list_permissions')
@ -294,13 +305,13 @@ class TestAddDelUserPermission(BaseTest):
def test_add_permission(self): def test_add_permission(self):
user_api.add_permission( user_api.add_permission(
self.user_object,
self.user_object.id, self.user_object.id,
user=self.user_object,
permission_id=2 permission_id=2
) )
permissions = user_api.get_permissions( permissions = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
) )
result = None result = None
for permission in permissions: for permission in permissions:
@ -310,14 +321,14 @@ class TestAddDelUserPermission(BaseTest):
def test_add_permission_position(self): def test_add_permission_position(self):
user_api.add_permission( user_api.add_permission(
self.user_object,
self.user_object.id, self.user_object.id,
True, True,
2 2,
user=self.user_object,
) )
permissions = user_api.get_permissions( permissions = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
) )
result = None result = None
for permission in permissions: for permission in permissions:
@ -328,33 +339,14 @@ class TestAddDelUserPermission(BaseTest):
def test_add_permission_session(self): def test_add_permission_session(self):
with database.session() as session: with database.session() as session:
user_api.add_permission( user_api.add_permission(
self.user_object,
self.user_object.id, self.user_object.id,
user=self.user_object,
permission_id=2, permission_id=2,
session=session session=session
) )
permissions = user_api.get_permissions( permissions = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
)
result = None
for permission in permissions:
if permission['id'] == 2:
result = permission['name']
self.assertEqual(result, 'list_switches')
def test_add_permission_position_session(self):
with database.session() as session:
user_api.add_permission(
self.user_object,
self.user_object.id,
True,
2,
session
)
permissions = user_api.get_permissions(
self.user_object,
self.user_object.id
) )
result = None result = None
for permission in permissions: for permission in permissions:
@ -364,13 +356,13 @@ class TestAddDelUserPermission(BaseTest):
def test_del_permission(self): def test_del_permission(self):
user_api.del_permission( user_api.del_permission(
self.user_object,
self.user_object.id, self.user_object.id,
1 1,
user=self.user_object,
) )
del_user = user_api.get_permissions( del_user = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
) )
self.assertEqual([], del_user) self.assertEqual([], del_user)
@ -386,25 +378,25 @@ class TestUpdatePermissions(BaseTest):
def test_remove_permissions(self): def test_remove_permissions(self):
user_api.update_permissions( user_api.update_permissions(
self.user_object,
self.user_object.id, self.user_object.id,
user=self.user_object,
remove_permissions=1 remove_permissions=1
) )
del_user_permission = user_api.get_permissions( del_user_permission = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
) )
self.assertEqual([], del_user_permission) self.assertEqual([], del_user_permission)
def test_add_permissions(self): def test_add_permissions(self):
user_api.update_permissions( user_api.update_permissions(
self.user_object,
self.user_object.id, self.user_object.id,
user=self.user_object,
add_permissions=2 add_permissions=2
) )
permissions = user_api.get_permissions( permissions = user_api.get_permissions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object,
) )
result = None result = None
for permission in permissions: for permission in permissions:

View File

@ -46,16 +46,19 @@ class TestListUserActions(BaseTest):
def test_list_user_actions(self): def test_list_user_actions(self):
user_log.log_user_action( user_log.log_user_action(
self.user_object.id, self.user_object.id,
action='/testaction' action='/users/login'
) )
user_action = user_log.list_user_actions( user_action = user_log.list_user_actions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object
)
self.assertEqual(
1,
user_action[0]['user_id']
) )
expected = {
'action': '/users/login',
'user_id': 1
}
self.assertTrue(
all(item in user_action[0].items()
for item in expected.items()))
def test_list_none_user_actions(self): def test_list_none_user_actions(self):
user_log.log_user_action( user_log.log_user_action(
@ -63,8 +66,8 @@ class TestListUserActions(BaseTest):
action='/testaction' action='/testaction'
) )
user_action = user_log.list_user_actions( user_action = user_log.list_user_actions(
self.user_object, 2,
2 user=self.user_object
) )
self.assertEqual([], user_action) self.assertEqual([], user_action)
@ -83,8 +86,16 @@ class TestListActions(BaseTest):
self.user_object.id, self.user_object.id,
action='/testaction' action='/testaction'
) )
action = user_log.list_actions(self.user_object) action = user_log.list_actions(user=self.user_object)
self.assertIsNotNone(action) self.assertIsNotNone(action)
expected = {
'action': '/testaction',
'user_id': 1
}
print action
self.assertTrue(
all(item in action[0].items()
for item in expected.items()))
class TestDelUserActions(BaseTest): class TestDelUserActions(BaseTest):
@ -102,12 +113,12 @@ class TestDelUserActions(BaseTest):
action='/testaction' action='/testaction'
) )
user_log.del_user_actions( user_log.del_user_actions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object
) )
del_user_action = user_log.list_user_actions( del_user_action = user_log.list_user_actions(
self.user_object, self.user_object.id,
self.user_object.id user=self.user_object
) )
self.assertEqual([], del_user_action) self.assertEqual([], del_user_action)
@ -127,10 +138,10 @@ class TestDelActions(BaseTest):
action='/testaction' action='/testaction'
) )
user_log.del_actions( user_log.del_actions(
self.user_object user=self.user_object
) )
del_action = user_log.list_actions( del_action = user_log.list_actions(
self.user_object user=self.user_object
) )
self.assertEqual([], del_action) self.assertEqual([], del_action)