From c112203e0ef8f69cdd5a78c260029839a8763d26 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Fri, 2 Nov 2018 21:38:49 +0000 Subject: [PATCH] py3: Monkey-patch json.loads to accept bytes on py35 I'm tired of creating code churn where I just slap .decode("nearly arbitrary choice of encoding") in a bunch of places. Note that the patched loads() assumes UTF-8 (no byte-order mark detection), so call sites that used to decode as ASCII now accept non-ASCII UTF-8 input as well. Change-Id: I79b2bc59fed130ca537e96c1074212861d7db6b8 --- swift/__init__.py | 35 ++++++++++++++++++ swift/common/direct_client.py | 2 +- swift/common/internal_client.py | 4 +-- swift/common/memcached.py | 4 +-- swift/common/middleware/listing_formats.py | 2 +- swift/common/middleware/symlink.py | 2 +- swift/common/ring/ring.py | 2 +- swift/common/utils.py | 2 +- .../common/middleware/s3api/test_s3api.py | 2 +- .../common/middleware/test_list_endpoints.py | 36 +++++++++---------- test/unit/common/test_direct_client.py | 4 +-- test/unit/proxy/controllers/test_info.py | 10 +++--- 12 files changed, 70 insertions(+), 35 deletions(-) diff --git a/swift/__init__.py b/swift/__init__.py index 9d0e8896f7..f9f0931324 100644 --- a/swift/__init__.py +++ b/swift/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import sys import gettext import pkg_resources @@ -39,3 +40,37 @@ _t = gettext.translation('swift', localedir=_localedir, fallback=True) def gettext_(msg): return _t.gettext(msg) + + +if (3, 0) <= sys.version_info[:2] <= (3, 5): + # In the development of py3, json.loads() stopped accepting byte strings + # for a while. https://bugs.python.org/issue17909 got fixed for py36, but + # since it was termed an enhancement and not a regression, we don't expect + # any backports. At the same time, it'd be better if we could avoid + # leaving a whole bunch of json.loads(resp.body.decode(...)) scars in the + # code that'd probably persist even *after* we drop support for 3.5 and + # earlier. So, monkey patch stdlib. 
+ import json + if not getattr(json.loads, 'patched_to_decode', False): + class JsonLoadsPatcher(object): + def __init__(self, orig): + self._orig = orig + + def __call__(self, s, **kw): + if isinstance(s, bytes): + # No fancy byte-order mark detection for us; just assume + # UTF-8 and raise a UnicodeDecodeError if appropriate. + s = s.decode('utf8') + return self._orig(s, **kw) + + def __getattribute__(self, attr): + if attr == 'patched_to_decode': + return True + if attr == '_orig': + return super().__getattribute__(attr) + # Pass through all other attrs to the original; among other + # things, this preserves doc strings, etc. + return getattr(self._orig, attr) + + json.loads = JsonLoadsPatcher(json.loads) + del JsonLoadsPatcher diff --git a/swift/common/direct_client.py b/swift/common/direct_client.py index 2d4e2ce6a7..174b61601f 100644 --- a/swift/common/direct_client.py +++ b/swift/common/direct_client.py @@ -174,7 +174,7 @@ def _get_direct_account_container(path, stype, node, part, if resp.status == HTTP_NO_CONTENT: resp.read() return resp_headers, [] - return resp_headers, json.loads(resp.read().decode('ascii')) + return resp_headers, json.loads(resp.read()) def gen_headers(hdrs_in=None, add_ts=True): diff --git a/swift/common/internal_client.py b/swift/common/internal_client.py index d4cc2d19e7..95b52ae02d 100644 --- a/swift/common/internal_client.py +++ b/swift/common/internal_client.py @@ -298,7 +298,7 @@ class InternalClient(object): if resp.status_int >= HTTP_MULTIPLE_CHOICES: b''.join(resp.app_iter) break - data = json.loads(resp.body.decode('ascii')) + data = json.loads(resp.body) if not data: break for item in data: @@ -844,7 +844,7 @@ class SimpleClient(object): body = conn.read() info = conn.info() try: - body_data = json.loads(body.decode('ascii')) + body_data = json.loads(body) except ValueError: body_data = None trans_stop = time() diff --git a/swift/common/memcached.py b/swift/common/memcached.py index 4b1a879a41..8583cafb4a 100644 --- 
a/swift/common/memcached.py +++ b/swift/common/memcached.py @@ -315,7 +315,7 @@ class MemcacheRing(object): else: value = None elif int(line[2]) & JSON_FLAG: - value = json.loads(value.decode('ascii')) + value = json.loads(value) fp.readline() line = fp.readline().strip().split() self._return_conn(server, fp, sock) @@ -484,7 +484,7 @@ class MemcacheRing(object): else: value = None elif int(line[2]) & JSON_FLAG: - value = json.loads(value.decode('ascii')) + value = json.loads(value) responses[line[1]] = value fp.readline() line = fp.readline().strip().split() diff --git a/swift/common/middleware/listing_formats.py b/swift/common/middleware/listing_formats.py index d70ac484e5..7fcd7ff339 100644 --- a/swift/common/middleware/listing_formats.py +++ b/swift/common/middleware/listing_formats.py @@ -185,7 +185,7 @@ class ListingFilter(object): body = b''.join(resp_iter) try: - listing = json.loads(body.decode('ascii')) + listing = json.loads(body) # Do a couple sanity checks if not isinstance(listing, list): raise ValueError diff --git a/swift/common/middleware/symlink.py b/swift/common/middleware/symlink.py index a63eaac77e..94e6c8edfe 100644 --- a/swift/common/middleware/symlink.py +++ b/swift/common/middleware/symlink.py @@ -295,7 +295,7 @@ class SymlinkContainerContext(WSGIContext): """ with closing_if_possible(resp_iter): resp_body = b''.join(resp_iter) - body_json = json.loads(resp_body.decode('ascii')) + body_json = json.loads(resp_body) swift_version, account, _junk = split_path(req.path, 2, 3, True) new_body = json.dumps( [self._extract_symlink_path_json(obj_dict, swift_version, account) diff --git a/swift/common/ring/ring.py b/swift/common/ring/ring.py index 3f1a484e4c..a28d97d2fa 100644 --- a/swift/common/ring/ring.py +++ b/swift/common/ring/ring.py @@ -78,7 +78,7 @@ class RingData(object): """ json_len, = struct.unpack('!I', gz_file.read(4)) - ring_dict = json.loads(gz_file.read(json_len).decode('ascii')) + ring_dict = json.loads(gz_file.read(json_len)) 
ring_dict['replica2part2dev_id'] = [] if metadata_only: diff --git a/swift/common/utils.py b/swift/common/utils.py index 6289511e10..9e45ceb36d 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -3477,7 +3477,7 @@ def dump_recon_cache(cache_dict, cache_file, logger, lock_timeout=2, try: existing_entry = cf.readline() if existing_entry: - cache_entry = json.loads(existing_entry.decode('utf8')) + cache_entry = json.loads(existing_entry) except ValueError: # file doesn't have a valid entry, we'll recreate it pass diff --git a/test/unit/common/middleware/s3api/test_s3api.py b/test/unit/common/middleware/s3api/test_s3api.py index 460f8bbdec..332fd0dce9 100644 --- a/test/unit/common/middleware/s3api/test_s3api.py +++ b/test/unit/common/middleware/s3api/test_s3api.py @@ -59,7 +59,7 @@ class TestListingMiddleware(S3ApiTestCase): req = Request.blank('/v1/a/c') status, headers, body = self.call_s3api(req) - self.assertEqual(json.loads(body.decode('ascii')), [ + self.assertEqual(json.loads(body), [ {'name': 'obj1', 'hash': '0123456789abcdef0123456789abcdef'}, {'name': 'obj2', 'hash': 'swiftetag', 's3_etag': '"mu-etag"'}, {'name': 'obj2', 'hash': 'swiftetag; something=else'}, diff --git a/test/unit/common/middleware/test_list_endpoints.py b/test/unit/common/middleware/test_list_endpoints.py index 651f9422db..f7dfbc4298 100644 --- a/test/unit/common/middleware/test_list_endpoints.py +++ b/test/unit/common/middleware/test_list_endpoints.py @@ -240,7 +240,7 @@ class TestListEndpoints(unittest.TestCase): self.list_endpoints) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.content_type, 'application/json') - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.1.1:6200/sdb1/1/a/c/o1", "http://10.1.2.2:6200/sdd1/1/a/c/o1" ]) @@ -260,14 +260,14 @@ class TestListEndpoints(unittest.TestCase): self.list_endpoints) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.content_type, 
'application/json') - self.assertEqual(json.loads(resp.body.decode("utf-8")), + self.assertEqual(json.loads(resp.body), expected[pol.idx]) # Here, 'o1/' is the object name. resp = Request.blank('/endpoints/a/c/o1/').get_response( self.list_endpoints) self.assertEqual(resp.status_int, 200) - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.1.1:6200/sdb1/3/a/c/o1/", "http://10.1.2.2:6200/sdd1/3/a/c/o1/" ]) @@ -275,7 +275,7 @@ class TestListEndpoints(unittest.TestCase): resp = Request.blank('/endpoints/a/c2').get_response( self.list_endpoints) self.assertEqual(resp.status_int, 200) - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.1.1:6200/sda1/2/a/c2", "http://10.1.2.1:6200/sdc1/2/a/c2" ]) @@ -283,7 +283,7 @@ class TestListEndpoints(unittest.TestCase): resp = Request.blank('/endpoints/a1').get_response( self.list_endpoints) self.assertEqual(resp.status_int, 200) - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.2.1:6200/sdc1/0/a1", "http://10.1.1.1:6200/sda1/0/a1", "http://10.1.1.1:6200/sdb1/0/a1" @@ -296,7 +296,7 @@ class TestListEndpoints(unittest.TestCase): resp = Request.blank('/endpoints/a/c 2').get_response( self.list_endpoints) self.assertEqual(resp.status_int, 200) - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.1.1:6200/sdb1/3/a/c%202", "http://10.1.2.2:6200/sdd1/3/a/c%202" ]) @@ -304,7 +304,7 @@ class TestListEndpoints(unittest.TestCase): resp = Request.blank('/endpoints/a/c%202').get_response( self.list_endpoints) self.assertEqual(resp.status_int, 200) - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.1.1:6200/sdb1/3/a/c%202", "http://10.1.2.2:6200/sdd1/3/a/c%202" ]) @@ -312,7 +312,7 @@ class 
TestListEndpoints(unittest.TestCase): resp = Request.blank('/endpoints/ac%20count/con%20tainer/ob%20ject') \ .get_response(self.list_endpoints) self.assertEqual(resp.status_int, 200) - self.assertEqual(json.loads(resp.body.decode("utf-8")), [ + self.assertEqual(json.loads(resp.body), [ "http://10.1.1.1:6200/sdb1/3/ac%20count/con%20tainer/ob%20ject", "http://10.1.2.2:6200/sdd1/3/ac%20count/con%20tainer/ob%20ject" ]) @@ -342,7 +342,7 @@ class TestListEndpoints(unittest.TestCase): .get_response(custom_path_le) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.content_type, 'application/json') - self.assertEqual(json.loads(resp.body.decode("utf-8")), + self.assertEqual(json.loads(resp.body), expected[pol.idx]) # test custom path without trailing slash @@ -356,7 +356,7 @@ class TestListEndpoints(unittest.TestCase): .get_response(custom_path_le) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.content_type, 'application/json') - self.assertEqual(json.loads(resp.body.decode("utf-8")), + self.assertEqual(json.loads(resp.body), expected[pol.idx]) def test_v1_response(self): @@ -364,7 +364,7 @@ class TestListEndpoints(unittest.TestCase): resp = req.get_response(self.list_endpoints) expected = ["http://10.1.1.1:6200/sdb1/1/a/c/o1", "http://10.1.2.2:6200/sdd1/1/a/c/o1"] - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) def test_v2_obj_response(self): req = Request.blank('/endpoints/v2/a/c/o1') @@ -374,7 +374,7 @@ class TestListEndpoints(unittest.TestCase): "http://10.1.2.2:6200/sdd1/1/a/c/o1"], 'headers': {'X-Backend-Storage-Policy-Index': "0"}, } - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) for policy in POLICIES: patch_path = 'swift.common.middleware.list_endpoints' \ '.get_container_info' @@ -390,7 +390,7 @@ class TestListEndpoints(unittest.TestCase): 'X-Backend-Storage-Policy-Index': str(int(policy))}, 
'endpoints': [path % node for node in nodes], } - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) def test_v2_non_obj_response(self): # account @@ -403,7 +403,7 @@ class TestListEndpoints(unittest.TestCase): 'headers': {}, } # container - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) req = Request.blank('/endpoints/v2/a/c') resp = req.get_response(self.list_endpoints) expected = { @@ -412,7 +412,7 @@ class TestListEndpoints(unittest.TestCase): "http://10.1.2.1:6200/sdc1/0/a/c"], 'headers': {}, } - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) def test_version_account_response(self): req = Request.blank('/endpoints/a') @@ -420,10 +420,10 @@ class TestListEndpoints(unittest.TestCase): expected = ["http://10.1.2.1:6200/sdc1/0/a", "http://10.1.1.1:6200/sda1/0/a", "http://10.1.1.1:6200/sdb1/0/a"] - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) req = Request.blank('/endpoints/v1.0/a') resp = req.get_response(self.list_endpoints) - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) req = Request.blank('/endpoints/v2/a') resp = req.get_response(self.list_endpoints) @@ -433,7 +433,7 @@ class TestListEndpoints(unittest.TestCase): "http://10.1.1.1:6200/sdb1/0/a"], 'headers': {}, } - self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) + self.assertEqual(json.loads(resp.body), expected) if __name__ == '__main__': diff --git a/test/unit/common/test_direct_client.py b/test/unit/common/test_direct_client.py index 36d2d56c5d..f6555f18c0 100644 --- a/test/unit/common/test_direct_client.py +++ b/test/unit/common/test_direct_client.py @@ -213,7 +213,7 @@ class TestDirectClient(unittest.TestCase): 
self.assertEqual(conn.req_headers['user-agent'], self.user_agent) self.assertEqual(resp_headers, stub_headers) - self.assertEqual(json.loads(body.decode('ascii')), resp) + self.assertEqual(json.loads(body), resp) self.assertIn('format=json', conn.query_string) for k, v in req_params.items(): if v is None: @@ -389,7 +389,7 @@ class TestDirectClient(unittest.TestCase): self.assertEqual(conn.req_headers['user-agent'], self.user_agent) self.assertEqual(headers, resp_headers) - self.assertEqual(json.loads(body.decode('ascii')), resp) + self.assertEqual(json.loads(body), resp) self.assertIn('format=json', conn.query_string) for k, v in req_params.items(): if v is None: diff --git a/test/unit/proxy/controllers/test_info.py b/test/unit/proxy/controllers/test_info.py index e9ba2a76aa..2317acfbe1 100644 --- a/test/unit/proxy/controllers/test_info.py +++ b/test/unit/proxy/controllers/test_info.py @@ -62,7 +62,7 @@ class TestInfoController(unittest.TestCase): resp = controller.GET(req) self.assertIsInstance(resp, HTTPException) self.assertEqual('200 OK', str(resp)) - info = json.loads(resp.body.decode('ascii')) + info = json.loads(resp.body) self.assertNotIn('admin', info) self.assertIn('foo', info) self.assertIn('bar', info['foo']) @@ -89,7 +89,7 @@ class TestInfoController(unittest.TestCase): resp = controller.GET(req) self.assertIsInstance(resp, HTTPException) self.assertEqual('200 OK', str(resp)) - info = json.loads(resp.body.decode('ascii')) + info = json.loads(resp.body) self.assertNotIn('admin', info) self.assertIn('foo', info) self.assertIn('bar', info['foo']) @@ -120,7 +120,7 @@ class TestInfoController(unittest.TestCase): resp = controller.GET(req) self.assertIsInstance(resp, HTTPException) self.assertEqual('200 OK', str(resp)) - info = json.loads(resp.body.decode('ascii')) + info = json.loads(resp.body) self.assertIn('foo', info) self.assertIn('bar', info['foo']) self.assertEqual(info['foo']['bar'], 'baz') @@ -156,7 +156,7 @@ class 
TestInfoController(unittest.TestCase): resp = controller.GET(req) self.assertIsInstance(resp, HTTPException) self.assertEqual('200 OK', str(resp)) - info = json.loads(resp.body.decode('ascii')) + info = json.loads(resp.body) self.assertIn('admin', info) self.assertIn('qux', info['admin']) self.assertIn('quux', info['admin']['qux']) @@ -279,7 +279,7 @@ class TestInfoController(unittest.TestCase): resp = controller.GET(req) self.assertIsInstance(resp, HTTPException) self.assertEqual('200 OK', str(resp)) - info = json.loads(resp.body.decode('ascii')) + info = json.loads(resp.body) self.assertNotIn('foo2', info) self.assertIn('admin', info) self.assertIn('disallowed_sections', info['admin'])