diff --git a/swift/account/server.py b/swift/account/server.py index cf160e9e21..1d892fba4c 100644 --- a/swift/account/server.py +++ b/swift/account/server.py @@ -32,6 +32,7 @@ from swift.common.utils import get_logger, hash_path, public, \ from swift.common.constraints import check_mount, valid_timestamp, check_utf8 from swift.common import constraints from swift.common.db_replicator import ReplicatorRpc +from swift.common.base_storage_server import BaseStorageServer from swift.common.swob import HTTPAccepted, HTTPBadRequest, \ HTTPCreated, HTTPForbidden, HTTPInternalServerError, \ HTTPMethodNotAllowed, HTTPNoContent, HTTPNotFound, \ @@ -40,18 +41,15 @@ from swift.common.swob import HTTPAccepted, HTTPBadRequest, \ from swift.common.request_helpers import is_sys_or_user_meta -class AccountController(object): +class AccountController(BaseStorageServer): """WSGI controller for the account server.""" def __init__(self, conf, logger=None): + super(AccountController, self).__init__(conf) self.logger = logger or get_logger(conf, log_route='account-server') self.log_requests = config_true_value(conf.get('log_requests', 'true')) self.root = conf.get('devices', '/srv/node') self.mount_check = config_true_value(conf.get('mount_check', 'true')) - replication_server = conf.get('replication_server', None) - if replication_server is not None: - replication_server = config_true_value(replication_server) - self.replication_server = replication_server self.replicator_rpc = ReplicatorRpc(self.root, DATADIR, AccountBroker, self.mount_check, logger=self.logger) @@ -262,15 +260,12 @@ class AccountController(object): try: # disallow methods which are not publicly accessible try: - method = getattr(self, req.method) - getattr(method, 'publicly_accessible') - replication_method = getattr(method, 'replication', False) - if (self.replication_server is not None and - self.replication_server != replication_method): + if req.method not in self.allowed_methods: raise AttributeError('Not allowed method.') except AttributeError: res = HTTPMethodNotAllowed() else: + method = getattr(self, req.method) res = method(req) except HTTPException as error_response: res = error_response diff --git a/swift/common/base_storage_server.py b/swift/common/base_storage_server.py new file mode 100644 index 0000000000..4eb024b270 --- /dev/null +++ b/swift/common/base_storage_server.py @@ -0,0 +1,70 @@ +# Copyright (c) 2010-2014 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from swift.common.utils import public, timing_stats, config_true_value +from swift.common.swob import Response + + +class BaseStorageServer(object): + """ + Implements common OPTIONS method for object, account, container servers. + """ + + def __init__(self, conf, **kwargs): + self._allowed_methods = None + replication_server = conf.get('replication_server', None) + if replication_server is not None: + replication_server = config_true_value(replication_server) + self.replication_server = replication_server + + @property + def allowed_methods(self): + if self._allowed_methods is None: + self._allowed_methods = [] + all_methods = inspect.getmembers(self, predicate=callable) + + if self.replication_server is True: + for name, m in all_methods: + if (getattr(m, 'publicly_accessible', False) and + getattr(m, 'replication', False)): + self._allowed_methods.append(name) + elif self.replication_server is False: + for name, m in all_methods: + if (getattr(m, 'publicly_accessible', False) and not + getattr(m, 'replication', False)): + self._allowed_methods.append(name) + elif self.replication_server is None: + for name, m in all_methods: + if getattr(m, 'publicly_accessible', False): + self._allowed_methods.append(name) + + self._allowed_methods.sort() + return self._allowed_methods + + @public + @timing_stats() + def OPTIONS(self, req): + """ + Base handler for OPTIONS requests + + :param req: swob.Request object + :returns: swob.Response object + """ + # Prepare the default response + headers = {'Allow': ', '.join(self.allowed_methods)} + resp = Response(status=200, request=req, headers=headers) + + return resp diff --git a/swift/container/server.py b/swift/container/server.py index 6ef713188d..a174342147 100644 --- a/swift/container/server.py +++ b/swift/container/server.py @@ -38,6 +38,7 @@ from swift.common.bufferedhttp import http_connect from swift.common.exceptions import ConnectionTimeout from swift.common.http import HTTP_NOT_FOUND, is_success from swift.common.storage_policy import POLICIES +from swift.common.base_storage_server import BaseStorageServer from swift.common.swob import HTTPAccepted, HTTPBadRequest, HTTPConflict, \ HTTPCreated, HTTPInternalServerError, HTTPNoContent, HTTPNotFound, \ HTTPPreconditionFailed, HTTPMethodNotAllowed, Request, Response, \ @@ -71,7 +72,7 @@ def gen_resp_headers(info, is_deleted=False): return headers -class ContainerController(object): +class ContainerController(BaseStorageServer): """WSGI Controller for the container server.""" # Ensure these are all lowercase @@ -79,16 +80,13 @@ class ContainerController(object): 'x-container-sync-key', 'x-container-sync-to'] def __init__(self, conf, logger=None): + super(ContainerController, self).__init__(conf) self.logger = logger or get_logger(conf, log_route='container-server') self.log_requests = config_true_value(conf.get('log_requests', 'true')) self.root = conf.get('devices', '/srv/node') self.mount_check = config_true_value(conf.get('mount_check', 'true')) self.node_timeout = int(conf.get('node_timeout', 3)) self.conn_timeout = float(conf.get('conn_timeout', 0.5)) - replication_server = conf.get('replication_server', None) - if replication_server is not None: - replication_server = config_true_value(replication_server) - self.replication_server = replication_server #: ContainerSyncCluster instance for validating sync-to values. self.realms_conf = ContainerSyncRealms( os.path.join( @@ -564,15 +562,12 @@ class ContainerController(object): try: # disallow methods which have not been marked 'public' try: - method = getattr(self, req.method) - getattr(method, 'publicly_accessible') - replication_method = getattr(method, 'replication', False) - if (self.replication_server is not None and - self.replication_server != replication_method): + if req.method not in self.allowed_methods: raise AttributeError('Not allowed method.') except AttributeError: res = HTTPMethodNotAllowed() else: + method = getattr(self, req.method) res = method(req) except HTTPException as error_response: res = error_response diff --git a/swift/obj/server.py b/swift/obj/server.py index 3ee7e06887..72672609ba 100644 --- a/swift/obj/server.py +++ b/swift/obj/server.py @@ -40,6 +40,7 @@ from swift.common.exceptions import ConnectionTimeout, DiskFileQuarantined, \ DiskFileXattrNotSupported from swift.obj import ssync_receiver from swift.common.http import is_success +from swift.common.base_storage_server import BaseStorageServer from swift.common.request_helpers import get_name_and_placement, \ is_user_meta, is_sys_or_user_meta from swift.common.swob import HTTPAccepted, HTTPBadRequest, HTTPCreated, \ @@ -64,7 +65,7 @@ class EventletPlungerString(str): return wsgi.MINIMUM_CHUNK_SIZE + 1 -class ObjectController(object): +class ObjectController(BaseStorageServer): """Implements the WSGI application for the Swift Object Server.""" def __init__(self, conf, logger=None): @@ -74,6 +75,7 @@ class ObjectController(object): /etc/object-server.conf-sample or /etc/swift/object-server.conf-sample. """ + super(ObjectController, self).__init__(conf) self.logger = logger or get_logger(conf, log_route='object-server') self.node_timeout = int(conf.get('node_timeout', 3)) self.conn_timeout = float(conf.get('conn_timeout', 0.5)) @@ -85,10 +87,6 @@ class ObjectController(object): self.slow = int(conf.get('slow', 0)) self.keep_cache_private = \ config_true_value(conf.get('keep_cache_private', 'false')) - replication_server = conf.get('replication_server', None) - if replication_server is not None: - replication_server = config_true_value(replication_server) - self.replication_server = replication_server default_allowed_headers = ''' content-disposition, @@ -708,15 +706,12 @@ class ObjectController(object): try: # disallow methods which have not been marked 'public' try: - method = getattr(self, req.method) - getattr(method, 'publicly_accessible') - replication_method = getattr(method, 'replication', False) - if (self.replication_server is not None and - self.replication_server != replication_method): + if req.method not in self.allowed_methods: raise AttributeError('Not allowed method.') except AttributeError: res = HTTPMethodNotAllowed() else: + method = getattr(self, req.method) res = method(req) except DiskFileCollision: res = HTTPForbidden(request=req) diff --git a/test/unit/account/test_server.py b/test/unit/account/test_server.py index c18c57edb1..d4c514d273 100644 --- a/test/unit/account/test_server.py +++ b/test/unit/account/test_server.py @@ -55,6 +55,18 @@ class TestAccountController(unittest.TestCase): if err.errno != errno.ENOENT: raise + def test_OPTIONS(self): + server_handler = AccountController( + {'devices': self.testdir, 'mount_check': 'false'}) + req = Request.blank('/sda1/p/a/c/o', {'REQUEST_METHOD': 'OPTIONS'}) + req.content_length = 0 + resp = server_handler.OPTIONS(req) + self.assertEquals(200, resp.status_int) + for verb in 'OPTIONS GET POST PUT DELETE HEAD REPLICATE'.split(): + self.assertTrue( + verb in resp.headers['Allow'].split(', ')) + self.assertEquals(len(resp.headers['Allow'].split(', ')), 7) + def test_DELETE_not_found(self): req = Request.blank('/sda1/p/a', environ={'REQUEST_METHOD': 'DELETE', 'HTTP_X_TIMESTAMP': '0'}) @@ -1600,7 +1612,7 @@ class TestAccountController(unittest.TestCase): with mock.patch.object(self.controller, method, new=mock_method): mock_method.replication = False - response = self.controller.__call__(env, start_response) + response = self.controller(env, start_response) self.assertEqual(response, method_res) def test_not_allowed_method(self): @@ -1642,6 +1654,38 @@ class TestAccountController(unittest.TestCase): response = self.controller.__call__(env, start_response) self.assertEqual(response, answer) + def test_call_incorrect_replication_method(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + self.controller = AccountController( + {'devices': self.testdir, 'mount_check': 'false', + 'replication_server': 'true'}) + + def start_response(*args): + """Sends args to outbuf""" + outbuf.writelines(args) + + obj_methods = ['DELETE', 'PUT', 'HEAD', 'GET', 'POST', 'OPTIONS'] + for method in obj_methods: + env = {'REQUEST_METHOD': method, + 'SCRIPT_NAME': '', + 'PATH_INFO': '/sda1/p/a/c', + 'SERVER_NAME': '127.0.0.1', + 'SERVER_PORT': '8080', + 'SERVER_PROTOCOL': 'HTTP/1.0', + 'CONTENT_LENGTH': '0', + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': 'http', + 'wsgi.input': inbuf, + 'wsgi.errors': errbuf, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False} + self.controller(env, start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_GET_log_requests_true(self): self.controller.logger = FakeLogger() self.controller.log_requests = True diff --git a/test/unit/container/test_server.py b/test/unit/container/test_server.py index bbab114913..44f19de9af 100644 --- a/test/unit/container/test_server.py +++ b/test/unit/container/test_server.py @@ -301,6 +301,18 @@ class TestContainerController(unittest.TestCase): resp = req.get_response(self.controller) self.assertEquals(resp.status_int, 400) + def test_OPTIONS(self): + server_handler = container_server.ContainerController( + {'devices': self.testdir, 'mount_check': 'false'}) + req = Request.blank('/sda1/p/a/c/o', {'REQUEST_METHOD': 'OPTIONS'}) + req.content_length = 0 + resp = server_handler.OPTIONS(req) + self.assertEquals(200, resp.status_int) + for verb in 'OPTIONS GET POST PUT DELETE HEAD REPLICATE'.split(): + self.assertTrue( + verb in resp.headers['Allow'].split(', ')) + self.assertEquals(len(resp.headers['Allow'].split(', ')), 7) + def test_PUT(self): req = Request.blank( '/sda1/p/a/c', environ={'REQUEST_METHOD': 'PUT', @@ -2474,7 +2486,7 @@ class TestContainerController(unittest.TestCase): method_res = mock.MagicMock() mock_method = public(lambda x: mock.MagicMock(return_value=method_res)) with mock.patch.object(self.controller, method, new=mock_method): - response = self.controller.__call__(env, start_response) + response = self.controller(env, start_response) self.assertEqual(response, method_res) def test_not_allowed_method(self): @@ -2515,6 +2527,38 @@ class TestContainerController(unittest.TestCase): response = self.controller.__call__(env, start_response) self.assertEqual(response, answer) + def test_call_incorrect_replication_method(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + self.controller = container_server.ContainerController( + {'devices': self.testdir, 'mount_check': 'false', + 'replication_server': 'true'}) + + def start_response(*args): + """Sends args to outbuf""" + outbuf.writelines(args) + + obj_methods = ['DELETE', 'PUT', 'HEAD', 'GET', 'POST', 'OPTIONS'] + for method in obj_methods: + env = {'REQUEST_METHOD': method, + 'SCRIPT_NAME': '', + 'PATH_INFO': '/sda1/p/a/c', + 'SERVER_NAME': '127.0.0.1', + 'SERVER_PORT': '8080', + 'SERVER_PROTOCOL': 'HTTP/1.0', + 'CONTENT_LENGTH': '0', + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': 'http', + 'wsgi.input': inbuf, + 'wsgi.errors': errbuf, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False} + self.controller(env, start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_GET_log_requests_true(self): self.controller.logger = FakeLogger() self.controller.log_requests = True diff --git a/test/unit/obj/test_server.py b/test/unit/obj/test_server.py index 27bc9637ff..bc0ac4274a 100755 --- a/test/unit/obj/test_server.py +++ b/test/unit/obj/test_server.py @@ -1122,6 +1122,20 @@ class TestObjectController(unittest.TestCase): os.path.basename(os.path.dirname(disk_file._data_file))) self.assertEquals(os.listdir(quar_dir)[0], file_name) + def test_OPTIONS(self): + conf = {'devices': self.testdir, 'mount_check': 'false'} + server_handler = object_server.ObjectController( + conf, logger=debug_logger()) + req = Request.blank('/sda1/p/a/c/o', {'REQUEST_METHOD': 'OPTIONS'}) + req.content_length = 0 + resp = server_handler.OPTIONS(req) + self.assertEquals(200, resp.status_int) + for verb in 'OPTIONS GET POST PUT DELETE HEAD REPLICATE \ + REPLICATION'.split(): + self.assertTrue( + verb in resp.headers['Allow'].split(', ')) + self.assertEquals(len(resp.headers['Allow'].split(', ')), 8) + def test_GET(self): # Test swift.obj.server.ObjectController.GET req = Request.blank('/sda1/p/a/c', environ={'REQUEST_METHOD': 'GET'}) @@ -4089,7 +4103,7 @@ class TestObjectController(unittest.TestCase): mock.MagicMock(return_value=method_res)) with mock.patch.object(self.object_controller, method, new=mock_method): - response = self.object_controller.__call__(env, start_response) + response = self.object_controller(env, start_response) self.assertEqual(response, method_res) def test_not_allowed_method(self): @@ -4146,6 +4160,38 @@ class TestObjectController(unittest.TestCase): ' 1234',), {})]) + def test_call_incorrect_replication_method(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + self.object_controller = object_server.ObjectController( + {'devices': self.testdir, 'mount_check': 'false', + 'replication_server': 'true'}, logger=FakeLogger()) + + def start_response(*args): + """Sends args to outbuf""" + outbuf.writelines(args) + + obj_methods = ['DELETE', 'PUT', 'HEAD', 'GET', 'POST', 'OPTIONS'] + for method in obj_methods: + env = {'REQUEST_METHOD': method, + 'SCRIPT_NAME': '', + 'PATH_INFO': '/sda1/p/a/c', + 'SERVER_NAME': '127.0.0.1', + 'SERVER_PORT': '8080', + 'SERVER_PROTOCOL': 'HTTP/1.0', + 'CONTENT_LENGTH': '0', + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': 'http', + 'wsgi.input': inbuf, + 'wsgi.errors': errbuf, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False} + self.object_controller(env, start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_not_utf8_and_not_logging_requests(self): inbuf = StringIO() errbuf = StringIO()