diff --git a/swift/common/wsgi.py b/swift/common/wsgi.py index 6f0f600f54..5e8d0311ae 100644 --- a/swift/common/wsgi.py +++ b/swift/common/wsgi.py @@ -35,7 +35,8 @@ from six import StringIO from swift.common import utils, constraints from swift.common.storage_policy import BindPortsCache -from swift.common.swob import Request, wsgi_unquote +from swift.common.swob import Request, wsgi_quote, wsgi_unquote, \ + wsgi_quote_plus, wsgi_unquote_plus, wsgi_to_bytes, bytes_to_wsgi from swift.common.utils import capture_stdio, disable_fallocate, \ drop_privileges, get_logger, NullLogger, config_true_value, \ validate_configuration, get_hub, config_auto_int_value, \ @@ -433,6 +434,36 @@ class SwiftHttpProtocol(wsgi.HttpProtocol): '''If the client didn't provide a content type, leave it blank.''' return '' + def parse_request(self): + if not six.PY2: + # request lines *should* be ascii per the RFC, but historically + # we've allowed (and even have func tests that use) arbitrary + # bytes. This breaks on py3 (see https://bugs.python.org/issue33973 + # ) but the work-around is simple: munge the request line to be + # properly quoted. py2 will do the right thing without this, but it + # doesn't hurt to re-write the request line like this and it + # simplifies testing. + if self.raw_requestline.count(b' ') >= 2: + parts = self.raw_requestline.split(b' ', 2) + path, q, query = parts[1].partition(b'?') + # unquote first, so we don't over-quote something + # that was *correctly* quoted + path = wsgi_to_bytes(wsgi_quote(wsgi_unquote( + bytes_to_wsgi(path)))) + query = b'&'.join( + sep.join([ + wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus( + bytes_to_wsgi(key)))), + wsgi_to_bytes(wsgi_quote_plus(wsgi_unquote_plus( + bytes_to_wsgi(val)))) + ]) + for part in query.split(b'&') + for key, sep, val in (part.partition(b'='), )) + parts[1] = path + q + query + self.raw_requestline = b' '.join(parts) + # else, mangled protocol, most likely; let base class deal with it + return wsgi.HttpProtocol.parse_request(self) + class SwiftHttpProxiedProtocol(SwiftHttpProtocol): """ diff --git a/test/unit/common/test_wsgi.py b/test/unit/common/test_wsgi.py index 26ce8092c0..d99877cc3b 100644 --- a/test/unit/common/test_wsgi.py +++ b/test/unit/common/test_wsgi.py @@ -27,6 +27,7 @@ import types import eventlet.wsgi +import six from six import BytesIO from six.moves.urllib.parse import quote @@ -984,11 +985,6 @@ class TestWSGI(unittest.TestCase): class TestSwiftHttpProtocol(unittest.TestCase): - def setUp(self): - patcher = mock.patch('swift.common.wsgi.wsgi.HttpProtocol') - self.mock_super = patcher.start() - self.addCleanup(patcher.stop) - def _proto_obj(self): # Make an object we can exercise... note the base class's __init__() # does a bunch of work, so we just new up an object like eventlet.wsgi @@ -1041,12 +1037,38 @@ class TestSwiftHttpProtocol(unittest.TestCase): self.assertEqual(False, proto_obj.parse_request()) - self.assertEqual([], self.mock_super.mock_calls) self.assertEqual([ mock.call(400, "Bad HTTP/0.9 request type ('jimmy')"), ], proto_obj.send_error.mock_calls) self.assertEqual(('a', '123'), proto_obj.client_address) + def test_request_line_cleanup(self): + def do_test(line_from_socket, expected_line=None): + if expected_line is None: + expected_line = line_from_socket + + proto_obj = self._proto_obj() + proto_obj.raw_requestline = line_from_socket + with mock.patch('swift.common.wsgi.wsgi.HttpProtocol') \ + as mock_super: + proto_obj.parse_request() + + self.assertEqual([mock.call.parse_request(proto_obj)], + mock_super.mock_calls) + self.assertEqual(proto_obj.raw_requestline, expected_line) + + do_test(b'GET / HTTP/1.1') + do_test(b'GET /%FF HTTP/1.1') + + if not six.PY2: + do_test(b'GET /\xff HTTP/1.1', b'GET /%FF HTTP/1.1') + do_test(b'PUT /Here%20Is%20A%20SnowMan:\xe2\x98\x83 HTTP/1.0', + b'PUT /Here%20Is%20A%20SnowMan%3A%E2%98%83 HTTP/1.0') + do_test( + b'POST /?and%20it=fixes+params&' + b'PALMTREE=\xf0%9f\x8c%b4 HTTP/1.1', + b'POST /?and+it=fixes+params&PALMTREE=%F0%9F%8C%B4 HTTP/1.1') + class TestProxyProtocol(unittest.TestCase): def _run_bytes_through_protocol(self, bytes_from_client, protocol_class):