Merge "Call policy.init() once per API request"

This commit is contained in:
Jenkins 2014-04-07 11:03:53 +00:00 committed by Gerrit Code Review
commit 92470a45d3
3 changed files with 14 additions and 7 deletions

View File

@ -175,6 +175,8 @@ class Controller(object):
if name in self._member_actions: if name in self._member_actions:
def _handle_action(request, id, **kwargs): def _handle_action(request, id, **kwargs):
arg_list = [request.context, id] arg_list = [request.context, id]
# Ensure policy engine is initialized
policy.init()
# Fetch the resource and verify if the user can access it # Fetch the resource and verify if the user can access it
try: try:
resource = self._item(request, id, True) resource = self._item(request, id, True)
@ -185,9 +187,6 @@ class Controller(object):
# Explicit comparison with None to distinguish from {} # Explicit comparison with None to distinguish from {}
if body is not None: if body is not None:
arg_list.append(body) arg_list.append(body)
# TODO(salvatore-orlando): bp/make-authz-ortogonal
# The body of the action request should be included
# in the info passed to the policy engine
# It is ok to raise a 403 because accessibility to the # It is ok to raise a 403 because accessibility to the
# object was checked earlier in this method # object was checked earlier in this method
policy.enforce(request.context, name, resource) policy.enforce(request.context, name, resource)
@ -253,7 +252,6 @@ class Controller(object):
pagination_links = pagination_helper.get_links(obj_list) pagination_links = pagination_helper.get_links(obj_list)
if pagination_links: if pagination_links:
collection[self._collection + "_links"] = pagination_links collection[self._collection + "_links"] = pagination_links
return collection return collection
def _item(self, request, id, do_authz=False, field_list=None, def _item(self, request, id, do_authz=False, field_list=None,
@ -284,6 +282,8 @@ class Controller(object):
def index(self, request, **kwargs): def index(self, request, **kwargs):
"""Returns a list of the requested entity.""" """Returns a list of the requested entity."""
parent_id = kwargs.get(self._parent_id_name) parent_id = kwargs.get(self._parent_id_name)
# Ensure policy engine is initialized
policy.init()
return self._items(request, True, parent_id) return self._items(request, True, parent_id)
def show(self, request, id, **kwargs): def show(self, request, id, **kwargs):
@ -295,6 +295,8 @@ class Controller(object):
field_list, added_fields = self._do_field_list( field_list, added_fields = self._do_field_list(
api_common.list_args(request, "fields")) api_common.list_args(request, "fields"))
parent_id = kwargs.get(self._parent_id_name) parent_id = kwargs.get(self._parent_id_name)
# Ensure policy engine is initialized
policy.init()
return {self._resource: return {self._resource:
self._view(request.context, self._view(request.context,
self._item(request, self._item(request,
@ -363,6 +365,8 @@ class Controller(object):
else: else:
items = [body] items = [body]
bulk = False bulk = False
# Ensure policy engine is initialized
policy.init()
for item in items: for item in items:
self._validate_network_tenant_ownership(request, self._validate_network_tenant_ownership(request,
item[self._resource]) item[self._resource])
@ -433,6 +437,7 @@ class Controller(object):
action = self._plugin_handlers[self.DELETE] action = self._plugin_handlers[self.DELETE]
# Check authz # Check authz
policy.init()
parent_id = kwargs.get(self._parent_id_name) parent_id = kwargs.get(self._parent_id_name)
obj = self._item(request, id, parent_id=parent_id) obj = self._item(request, id, parent_id=parent_id)
try: try:
@ -484,6 +489,8 @@ class Controller(object):
if (value.get('required_by_policy') or if (value.get('required_by_policy') or
value.get('primary_key') or value.get('primary_key') or
'default' not in value)] 'default' not in value)]
# Ensure policy engine is initialized
policy.init()
orig_obj = self._item(request, id, field_list=field_list, orig_obj = self._item(request, id, field_list=field_list,
parent_id=parent_id) parent_id=parent_id)
orig_object_copy = copy.copy(orig_obj) orig_object_copy = copy.copy(orig_obj)

View File

@ -164,7 +164,6 @@ def _build_match_rule(action, target):
action is being executed action is being executed
(e.g.: create_router:external_gateway_info:network_id) (e.g.: create_router:external_gateway_info:network_id)
""" """
match_rule = policy.RuleCheck('rule', action) match_rule = policy.RuleCheck('rule', action)
resource, is_write = get_resource_and_action(action) resource, is_write = get_resource_and_action(action)
# Attribute-based checks shall not be enforced on GETs # Attribute-based checks shall not be enforced on GETs
@ -317,7 +316,6 @@ class FieldCheck(policy.Check):
def _prepare_check(context, action, target): def _prepare_check(context, action, target):
"""Prepare rule, target, and credentials for the policy engine.""" """Prepare rule, target, and credentials for the policy engine."""
init()
# Compare with None to distinguish case in which target is {} # Compare with None to distinguish case in which target is {}
if target is None: if target is None:
target = {} target = {}
@ -374,7 +372,6 @@ def enforce(context, action, target, plugin=None):
:raises neutron.exceptions.PolicyNotAuthorized: if verification fails. :raises neutron.exceptions.PolicyNotAuthorized: if verification fails.
""" """
init()
rule, target, credentials = _prepare_check(context, action, target) rule, target, credentials = _prepare_check(context, action, target)
result = policy.check(rule, target, credentials, action=action) result = policy.check(rule, target, credentials, action=action)
if not result: if not result:

View File

@ -53,12 +53,14 @@ class PolicyFileTestCase(base.BaseTestCase):
action = "example:test" action = "example:test"
with open(tmpfilename, "w") as policyfile: with open(tmpfilename, "w") as policyfile:
policyfile.write("""{"example:test": ""}""") policyfile.write("""{"example:test": ""}""")
policy.init()
policy.enforce(self.context, action, self.target) policy.enforce(self.context, action, self.target)
with open(tmpfilename, "w") as policyfile: with open(tmpfilename, "w") as policyfile:
policyfile.write("""{"example:test": "!"}""") policyfile.write("""{"example:test": "!"}""")
# NOTE(vish): reset stored policy cache so we don't have to # NOTE(vish): reset stored policy cache so we don't have to
# sleep(1) # sleep(1)
policy._POLICY_CACHE = {} policy._POLICY_CACHE = {}
policy.init()
self.assertRaises(exceptions.PolicyNotAuthorized, self.assertRaises(exceptions.PolicyNotAuthorized,
policy.enforce, policy.enforce,
self.context, self.context,
@ -471,6 +473,7 @@ class NeutronPolicyTestCase(base.BaseTestCase):
# Trigger a policy with rule admin_or_owner # Trigger a policy with rule admin_or_owner
action = "create_network" action = "create_network"
target = {'tenant_id': 'fake'} target = {'tenant_id': 'fake'}
policy.init()
self.assertRaises(exceptions.PolicyCheckError, self.assertRaises(exceptions.PolicyCheckError,
policy.enforce, policy.enforce,
self.context, action, target) self.context, action, target)