diff --git a/swift/common/middleware/s3api/acl_handlers.py b/swift/common/middleware/s3api/acl_handlers.py index 17d5101f7b..7f43e664ec 100644 --- a/swift/common/middleware/s3api/acl_handlers.py +++ b/swift/common/middleware/s3api/acl_handlers.py @@ -76,40 +76,28 @@ class BaseAclHandler(object): """ BaseAclHandler: Handling ACL for basic requests mapped on ACL_MAP """ - def __init__(self, req, logger): + def __init__(self, req, logger, container=None, obj=None, headers=None): self.req = req - self.container = self.req.container_name - self.obj = self.req.object_name + self.container = req.container_name if container is None else container + self.obj = req.object_name if obj is None else obj self.method = req.environ['REQUEST_METHOD'] self.user_id = self.req.user_id - self.headers = self.req.headers + self.headers = req.headers if headers is None else headers self.logger = logger @contextmanager - def request_with(self, container=None, obj=None, headers=None): - try: - org_cont = self.container - org_obj = self.obj - org_headers = self.headers - - self.container = container or org_cont - self.obj = obj or org_obj - self.headers = headers or org_headers - yield - - finally: - self.container = org_cont - self.obj = org_obj - self.headers = org_headers + def request_with(self, container, obj, headers): + yield type(self)(self.req, self.logger, + container=container, obj=obj, headers=headers) def handle_acl(self, app, method, container=None, obj=None, headers=None): method = method or self.method - with self.request_with(container, obj, headers): - if hasattr(self, method): - return getattr(self, method)(app) + with self.request_with(container, obj, headers) as ah: + if hasattr(ah, method): + return getattr(ah, method)(app) else: - return self._handle_acl(app, method) + return ah._handle_acl(app, method) def _handle_acl(self, app, sw_method, container=None, obj=None, permission=None, headers=None): @@ -338,16 +326,16 @@ class MultiUploadAclHandler(BaseAclHandler): ========== ====== ============= ========== """ - def __init__(self, req, logger): - super(MultiUploadAclHandler, self).__init__(req, logger) + def __init__(self, req, logger, **kwargs): + super(MultiUploadAclHandler, self).__init__(req, logger, **kwargs) self.acl_checked = False def handle_acl(self, app, method, container=None, obj=None, headers=None): method = method or self.method - with self.request_with(container, obj, headers): + with self.request_with(container, obj, headers) as ah: # MultiUpload stuffs don't need acl check basically. - if hasattr(self, method): - return getattr(self, method)(app) + if hasattr(ah, method): + return getattr(ah, method)(app) else: pass @@ -360,9 +348,9 @@ class PartAclHandler(MultiUploadAclHandler): """ PartAclHandler: Handler for PartController """ - def __init__(self, req, logger): + def __init__(self, req, logger, **kwargs): # pylint: disable-msg=E1003 - super(MultiUploadAclHandler, self).__init__(req, logger) + super(MultiUploadAclHandler, self).__init__(req, logger, **kwargs) def HEAD(self, app): if self.container.endswith(MULTIUPLOAD_SUFFIX):