diff --git a/doc/source/cors.rst b/doc/source/cors.rst index 15bdf10ce5..c17ccd7f82 100644 --- a/doc/source/cors.rst +++ b/doc/source/cors.rst @@ -134,7 +134,10 @@ Test CORS Page } request.open(method, url); - request.setRequestHeader('X-Auth-Token', token); + if (token != '') { + // custom headers always trigger a pre-flight request + request.setRequestHeader('X-Auth-Token', token); + } request.send(null); } diff --git a/swift/common/wsgi.py b/swift/common/wsgi.py index eab0988b17..417b161b03 100644 --- a/swift/common/wsgi.py +++ b/swift/common/wsgi.py @@ -573,7 +573,8 @@ def make_env(env, method=None, path=None, agent='Swift', query_string=None, newenv = {} for name in ('eventlet.posthooks', 'HTTP_USER_AGENT', 'HTTP_HOST', 'PATH_INFO', 'QUERY_STRING', 'REMOTE_USER', 'REQUEST_METHOD', - 'SCRIPT_NAME', 'SERVER_NAME', 'SERVER_PORT', 'HTTP_ORIGIN', + 'SCRIPT_NAME', 'SERVER_NAME', 'SERVER_PORT', + 'HTTP_ORIGIN', 'HTTP_ACCESS_CONTROL_REQUEST_METHOD', 'SERVER_PROTOCOL', 'swift.cache', 'swift.source', 'swift.trans_id', 'swift.authorize_override', 'swift.authorize'): diff --git a/swift/proxy/controllers/base.py b/swift/proxy/controllers/base.py index 7aaf2d5f8d..139d6609f0 100644 --- a/swift/proxy/controllers/base.py +++ b/swift/proxy/controllers/base.py @@ -212,6 +212,10 @@ def cors_validation(func): # Call through to the decorated method resp = func(*a, **kw) + if controller.app.strict_cors_mode and \ + not controller.is_origin_allowed(cors_info, req_origin): + return resp + # Expose, # - simple response headers, # http://www.w3.org/TR/cors/#simple-response-header @@ -219,24 +223,32 @@ def cors_validation(func): # - user metadata headers # - headers provided by the user in # x-container-meta-access-control-expose-headers - expose_headers = ['cache-control', 'content-language', - 'content-type', 'expires', 'last-modified', - 'pragma', 'etag', 'x-timestamp', 'x-trans-id'] - for header in resp.headers: - if header.startswith('X-Container-Meta') or \ - header.startswith('X-Object-Meta'): - expose_headers.append(header.lower()) - if cors_info.get('expose_headers'): - expose_headers.extend( - [header_line.strip() - for header_line in cors_info['expose_headers'].split(' ') - if header_line.strip()]) - resp.headers['Access-Control-Expose-Headers'] = \ - ', '.join(expose_headers) + if 'Access-Control-Expose-Headers' not in resp.headers: + expose_headers = [ + 'cache-control', 'content-language', 'content-type', + 'expires', 'last-modified', 'pragma', 'etag', + 'x-timestamp', 'x-trans-id'] + for header in resp.headers: + if header.startswith('X-Container-Meta') or \ + header.startswith('X-Object-Meta'): + expose_headers.append(header.lower()) + if cors_info.get('expose_headers'): + expose_headers.extend( + [header_line.strip() + for header_line in + cors_info['expose_headers'].split(' ') + if header_line.strip()]) + resp.headers['Access-Control-Expose-Headers'] = \ + ', '.join(expose_headers) # The user agent won't process the response if the Allow-Origin # header isn't included - resp.headers['Access-Control-Allow-Origin'] = req_origin + if 'Access-Control-Allow-Origin' not in resp.headers: + if cors_info['allow_origin'] and \ + cors_info['allow_origin'].strip() == '*': + resp.headers['Access-Control-Allow-Origin'] = '*' + else: + resp.headers['Access-Control-Allow-Origin'] = req_origin return resp else: @@ -1256,7 +1268,10 @@ class Controller(object): list_from_csv(req.headers['Access-Control-Request-Headers'])) # Populate the response with the CORS preflight headers - headers['access-control-allow-origin'] = req_origin_value + if cors.get('allow_origin', '').strip() == '*': + headers['access-control-allow-origin'] = '*' + else: + headers['access-control-allow-origin'] = req_origin_value if cors.get('max_age') is not None: headers['access-control-max-age'] = cors.get('max_age') headers['access-control-allow-methods'] = \ diff --git a/swift/proxy/server.py b/swift/proxy/server.py index d42688b423..5b4a5b7b2d 100644 --- a/swift/proxy/server.py +++ b/swift/proxy/server.py @@ -130,6 +130,8 @@ class Application(object): a.strip() for a in conf.get('cors_allow_origin', '').split(',') if a.strip()] + self.strict_cors_mode = config_true_value( + conf.get('strict_cors_mode', 't')) self.node_timings = {} self.timing_expiry = int(conf.get('timing_expiry', 300)) self.sorting_method = conf.get('sorting_method', 'shuffle').lower() @@ -210,7 +212,8 @@ class Application(object): container_listing_limit=constraints.CONTAINER_LISTING_LIMIT, max_account_name_length=constraints.MAX_ACCOUNT_NAME_LENGTH, max_container_name_length=constraints.MAX_CONTAINER_NAME_LENGTH, - max_object_name_length=constraints.MAX_OBJECT_NAME_LENGTH) + max_object_name_length=constraints.MAX_OBJECT_NAME_LENGTH, + strict_cors_mode=self.strict_cors_mode) def check_config(self): """ diff --git a/test/functional/test_object.py b/test/functional/test_object.py index dad8635da9..9b999b6b90 100755 --- a/test/functional/test_object.py +++ b/test/functional/test_object.py @@ -21,6 +21,7 @@ from uuid import uuid4 from swift_testing import check_response, retry, skip, skip3, \ swift_test_perm, web_front_end +from swift.common.utils import json class TestObject(unittest.TestCase): @@ -619,6 +620,117 @@ class TestObject(unittest.TestCase): self.assertEquals(resp.read(), 'Invalid UTF8 or contains NULL') self.assertEquals(resp.status, 412) + def test_cors(self): + if skip: + raise SkipTest + + def is_strict_mode(url, token, parsed, conn): + conn.request('GET', '/info') + resp = conn.getresponse() + if resp.status // 100 == 2: + info = json.loads(resp.read()) + return info.get('swift', {}).get('strict_cors_mode', False) + return False + + def put_cors_cont(url, token, parsed, conn, orig): + conn.request( + 'PUT', '%s/%s' % (parsed.path, self.container), + '', {'X-Auth-Token': token, + 'X-Container-Meta-Access-Control-Allow-Origin': orig}) + return check_response(conn) + + def put_obj(url, token, parsed, conn, obj): + conn.request( + 'PUT', '%s/%s/%s' % (parsed.path, self.container, obj), + 'test', {'X-Auth-Token': token}) + return check_response(conn) + + def check_cors(url, token, parsed, conn, + method, obj, headers): + if method != 'OPTIONS': + headers['X-Auth-Token'] = token + conn.request( + method, '%s/%s/%s' % (parsed.path, self.container, obj), + '', headers) + return conn.getresponse() + + strict_cors = retry(is_strict_mode) + + resp = retry(put_cors_cont, '*') + resp.read() + self.assertEquals(resp.status // 100, 2) + + resp = retry(put_obj, 'cat') + resp.read() + self.assertEquals(resp.status // 100, 2) + + resp = retry(check_cors, + 'OPTIONS', 'cat', {'Origin': 'http://m.com'}) + self.assertEquals(resp.status, 401) + + resp = retry(check_cors, + 'OPTIONS', 'cat', + {'Origin': 'http://m.com', + 'Access-Control-Request-Method': 'GET'}) + + self.assertEquals(resp.status, 200) + resp.read() + headers = dict((k.lower(), v) for k, v in resp.getheaders()) + self.assertEquals(headers.get('access-control-allow-origin'), + '*') + + resp = retry(check_cors, + 'GET', 'cat', {'Origin': 'http://m.com'}) + self.assertEquals(resp.status, 200) + headers = dict((k.lower(), v) for k, v in resp.getheaders()) + self.assertEquals(headers.get('access-control-allow-origin'), + '*') + + resp = retry(check_cors, + 'GET', 'cat', {'Origin': 'http://m.com', + 'X-Web-Mode': 'True'}) + self.assertEquals(resp.status, 200) + headers = dict((k.lower(), v) for k, v in resp.getheaders()) + self.assertEquals(headers.get('access-control-allow-origin'), + '*') + + #################### + + resp = retry(put_cors_cont, 'http://secret.com') + resp.read() + self.assertEquals(resp.status // 100, 2) + + resp = retry(check_cors, + 'OPTIONS', 'cat', + {'Origin': 'http://m.com', + 'Access-Control-Request-Method': 'GET'}) + resp.read() + self.assertEquals(resp.status, 401) + + if strict_cors: + resp = retry(check_cors, + 'GET', 'cat', {'Origin': 'http://m.com'}) + resp.read() + self.assertEquals(resp.status, 200) + headers = dict((k.lower(), v) for k, v in resp.getheaders()) + self.assertTrue('access-control-allow-origin' not in headers) + + resp = retry(check_cors, + 'GET', 'cat', {'Origin': 'http://secret.com'}) + resp.read() + self.assertEquals(resp.status, 200) + headers = dict((k.lower(), v) for k, v in resp.getheaders()) + self.assertEquals(headers.get('access-control-allow-origin'), + 'http://secret.com') + else: + resp = retry(check_cors, + 'GET', 'cat', {'Origin': 'http://m.com'}) + resp.read() + self.assertEquals(resp.status, 200) + headers = dict((k.lower(), v) for k, v in resp.getheaders()) + self.assertEquals(headers.get('access-control-allow-origin'), + 'http://m.com') + if __name__ == '__main__': unittest.main() diff --git a/test/unit/proxy/test_server.py b/test/unit/proxy/test_server.py index a361c069f6..24807a5f78 100644 --- a/test/unit/proxy/test_server.py +++ b/test/unit/proxy/test_server.py @@ -3822,9 +3822,7 @@ class TestObjectController(unittest.TestCase): req.content_length = 0 resp = controller.OPTIONS(req) self.assertEquals(200, resp.status_int) - self.assertEquals( - 'https://bar.baz', - resp.headers['access-control-allow-origin']) + self.assertEquals('*', resp.headers['access-control-allow-origin']) for verb in 'OPTIONS COPY GET POST PUT DELETE HEAD'.split(): self.assertTrue( verb in resp.headers['access-control-allow-methods']) @@ -3840,10 +3838,11 @@ class TestObjectController(unittest.TestCase): def stubContainerInfo(*args): return { 'cors': { - 'allow_origin': 'http://foo.bar' + 'allow_origin': 'http://not.foo.bar' } } controller.container_info = stubContainerInfo + controller.app.strict_cors_mode = False def objectGET(controller, req): return Response(headers={ @@ -3874,6 +3873,50 @@ class TestObjectController(unittest.TestCase): 'x-trans-id', 'x-object-meta-color']) self.assertEquals(expected_exposed, exposed) + controller.app.strict_cors_mode = True + req = Request.blank( + '/v1/a/c/o.jpg', + {'REQUEST_METHOD': 'GET'}, + headers={'Origin': 'http://foo.bar'}) + + resp = cors_validation(objectGET)(controller, req) + + self.assertEquals(200, resp.status_int) + self.assertTrue('access-control-allow-origin' not in resp.headers) + + def test_CORS_valid_with_obj_headers(self): + with save_globals(): + controller = proxy_server.ObjectController(self.app, 'a', 'c', 'o') + + def stubContainerInfo(*args): + return { + 'cors': { + 'allow_origin': 'http://foo.bar' + } + } + controller.container_info = stubContainerInfo + + def objectGET(controller, req): + return Response(headers={ + 'X-Object-Meta-Color': 'red', + 'X-Super-Secret': 'hush', + 'Access-Control-Allow-Origin': 'http://obj.origin', + 'Access-Control-Expose-Headers': 'x-trans-id' + }) + + req = Request.blank( + '/v1/a/c/o.jpg', + {'REQUEST_METHOD': 'GET'}, + headers={'Origin': 'http://foo.bar'}) + + resp = cors_validation(objectGET)(controller, req) + + self.assertEquals(200, resp.status_int) + self.assertEquals('http://obj.origin', + resp.headers['access-control-allow-origin']) + self.assertEquals('x-trans-id', + resp.headers['access-control-expose-headers']) + def _gather_x_container_headers(self, controller_call, req, *connect_args, **kwargs): header_list = kwargs.pop('header_list', ['X-Container-Device', @@ -4841,9 +4884,7 @@ class TestContainerController(unittest.TestCase): req.content_length = 0 resp = controller.OPTIONS(req) self.assertEquals(200, resp.status_int) - self.assertEquals( - 'https://bar.baz', - resp.headers['access-control-allow-origin']) + self.assertEquals('*', resp.headers['access-control-allow-origin']) for verb in 'OPTIONS GET POST PUT DELETE HEAD'.split(): self.assertTrue( verb in resp.headers['access-control-allow-methods'])