diff --git a/swift/cli/manage_shard_ranges.py b/swift/cli/manage_shard_ranges.py index 591b46ac1c..34556b0dfe 100644 --- a/swift/cli/manage_shard_ranges.py +++ b/swift/cli/manage_shard_ranges.py @@ -728,11 +728,12 @@ def _add_enable_args(parser): def _add_prompt_args(parser): - parser.add_argument( + group = parser.add_mutually_exclusive_group() + group.add_argument( '--yes', '-y', action='store_true', default=False, help='Apply shard range changes to broker without prompting. ' 'Cannot be used with --dry-run option.') - parser.add_argument( + group.add_argument( '--dry-run', '-n', action='store_true', default=False, help='Do not apply any shard range changes to broker. ' 'Cannot be used with --yes option.') @@ -890,10 +891,6 @@ def main(args=None): print('\nA sub-command is required.', file=sys.stderr) return EXIT_INVALID_ARGS - if getattr(args, 'yes', False) and getattr(args, 'dry_run', False): - print('--yes and --dry-run cannot both be set.', file=sys.stderr) - return EXIT_INVALID_ARGS - conf = {} rows_per_shard = DEFAULT_ROWS_PER_SHARD shrink_threshold = DEFAULT_SHRINK_THRESHOLD diff --git a/test/unit/cli/test_manage_shard_ranges.py b/test/unit/cli/test_manage_shard_ranges.py index 47a83277ce..9e3fcb01b2 100644 --- a/test/unit/cli/test_manage_shard_ranges.py +++ b/test/unit/cli/test_manage_shard_ranges.py @@ -12,6 +12,7 @@ import json import os +import sys import unittest from argparse import Namespace from textwrap import dedent @@ -1776,8 +1777,15 @@ class TestManageShardRanges(unittest.TestCase): out = StringIO() err = StringIO() with mock.patch('sys.stdout', out), \ - mock.patch('sys.stderr', err): - ret = main(['db file', 'repair', '--dry-run', '--yes']) - self.assertEqual(2, ret) + mock.patch('sys.stderr', err), \ + self.assertRaises(SystemExit) as cm: + main(['db file', 'repair', '--dry-run', '--yes']) + self.assertEqual(2, cm.exception.code) err_lines = err.getvalue().split('\n') - self.assertIn('--yes and --dry-run cannot both be set.', err_lines) + runner = os.path.basename(sys.argv[0]) + self.assertEqual( + 'usage: %s path_to_file repair [-h] [--yes | --dry-run]' % runner, + err_lines[0]) + self.assertIn( + "argument --yes/-y: not allowed with argument --dry-run/-n", + err_lines[1])