diff --git a/swift/common/ring/ring.py b/swift/common/ring/ring.py index 68635acd67..48be83c627 100644 --- a/swift/common/ring/ring.py +++ b/swift/common/ring/ring.py @@ -14,6 +14,8 @@ # limitations under the License. import array +import contextlib + import six.moves.cPickle as pickle import json from collections import defaultdict @@ -173,22 +175,21 @@ class RingData(object): :param bool metadata_only: If True, only load `devs` and `part_shift`. :returns: A RingData instance containing the loaded data. """ - gz_file = RingReader(filename) - - # See if the file is in the new format - magic = gz_file.read(4) - if magic == b'R1NG': - format_version, = struct.unpack('!H', gz_file.read(2)) - if format_version == 1: - ring_data = cls.deserialize_v1( - gz_file, metadata_only=metadata_only) + with contextlib.closing(RingReader(filename)) as gz_file: + # See if the file is in the new format + magic = gz_file.read(4) + if magic == b'R1NG': + format_version, = struct.unpack('!H', gz_file.read(2)) + if format_version == 1: + ring_data = cls.deserialize_v1( + gz_file, metadata_only=metadata_only) + else: + raise Exception('Unknown ring format version %d' % + format_version) else: - raise Exception('Unknown ring format version %d' % - format_version) - else: - # Assume old-style pickled ring - gz_file.seek(0) - ring_data = pickle.load(gz_file) + # Assume old-style pickled ring + gz_file.seek(0) + ring_data = pickle.load(gz_file) if not hasattr(ring_data, 'devs'): ring_data = RingData(ring_data['replica2part2dev_id'], diff --git a/test/unit/common/ring/test_ring.py b/test/unit/common/ring/test_ring.py index feffc4a6bb..4cd4f5f847 100644 --- a/test/unit/common/ring/test_ring.py +++ b/test/unit/common/ring/test_ring.py @@ -113,6 +113,27 @@ class TestRingData(unittest.TestCase): rd2 = ring.RingData.load(ring_fname) self.assert_ring_data_equal(rd, rd2) + def test_load_closes_file(self): + ring_fname = os.path.join(self.testdir, 'foo.ring.gz') + rd = ring.RingData( + [array.array('H', [0, 1, 0, 1]), array.array('H', [0, 1, 0, 1])], + [{'id': 0, 'zone': 0}, {'id': 1, 'zone': 1}], 30) + rd.save(ring_fname) + + class MockReader(ring.ring.RingReader): + calls = [] + + def close(self): + self.calls.append(('close', self.fp)) + return super(MockReader, self).close() + + with mock.patch('swift.common.ring.ring.RingReader', + MockReader) as mock_reader: + ring.RingData.load(ring_fname) + + self.assertEqual([('close', mock.ANY)], mock_reader.calls) + self.assertTrue(mock_reader.calls[0][1].closed) + def test_byteswapped_serialization(self): # Manually byte swap a ring and write it out, claiming it was written # on a different endian machine. Then read it back in and see if it's