diff --git a/troveclient/tests/fakes.py b/troveclient/tests/fakes.py index 90345312..76cb2225 100644 --- a/troveclient/tests/fakes.py +++ b/troveclient/tests/fakes.py @@ -505,9 +505,15 @@ class FakeHTTPClient(base_client.HTTPClient): def post_instances_1234_root(self, **kw): return (202, {}, {"user": {"password": "password", "name": "root"}}) + def post_clusters_cls_1234_root(self, **kw): + return (202, {}, {"user": {"password": "password", "name": "root"}}) + def get_instances_1234_root(self, **kw): return (200, {}, {"rootEnabled": 'True'}) + def get_clusters_cls_1234_root(self, **kw): + return (200, {}, {"rootEnabled": 'True'}) + def get_security_groups(self, **kw): return (200, {}, {"security_groups": [ { diff --git a/troveclient/tests/test_v1_shell.py b/troveclient/tests/test_v1_shell.py index 38fea223..ff9b7f41 100644 --- a/troveclient/tests/test_v1_shell.py +++ b/troveclient/tests/test_v1_shell.py @@ -57,7 +57,23 @@ class ShellTest(utils.TestCase): @mock.patch('sys.stdout', new_callable=six.StringIO) @mock.patch('troveclient.client.get_version_map', return_value=fakes.get_version_map()) - def run_command(self, cmd, mock_stdout, mock_get_version_map): + @mock.patch('troveclient.v1.shell._find_instance_or_cluster', + return_value=('1234', 'instance')) + def run_command(self, cmd, mock_find_instance_or_cluster, + mock_get_version_map, mock_stdout): + if isinstance(cmd, list): + self.shell.main(cmd) + else: + self.shell.main(cmd.split()) + return mock_stdout.getvalue() + + @mock.patch('sys.stdout', new_callable=six.StringIO) + @mock.patch('troveclient.client.get_version_map', + return_value=fakes.get_version_map()) + @mock.patch('troveclient.v1.shell._find_instance_or_cluster', + return_value=('cls-1234', 'cluster')) + def run_command_clusters(self, cmd, mock_find_instance_or_cluster, + mock_get_version_map, mock_stdout): if isinstance(cmd, list): self.shell.main(cmd) else: @@ -461,14 +477,22 @@ class ShellTest(utils.TestCase): self.assert_called('DELETE', '/instances/1234/users/jacob/databases/db1') - def test_root_enable(self): + def test_root_enable_instance(self): self.run_command('root-enable 1234') self.assert_called_anytime('POST', '/instances/1234/root') - def test_root_show(self): + def test_root_enable_cluster(self): + self.run_command_clusters('root-enable cls-1234') + self.assert_called_anytime('POST', '/clusters/cls-1234/root') + + def test_root_show_instance(self): self.run_command('root-show 1234') self.assert_called('GET', '/instances/1234/root') + def test_root_show_cluster(self): + self.run_command_clusters('root-show cls-1234') + self.assert_called('GET', '/clusters/cls-1234/root') + def test_secgroup_list(self): self.run_command('secgroup-list') self.assert_called('GET', '/security-groups') @@ -493,3 +517,29 @@ class ShellTest(utils.TestCase): 'cidr': '15.0.0.0/24', 'group_id': '2', }}) + + @mock.patch('sys.stdout', new_callable=six.StringIO) + @mock.patch('troveclient.client.get_version_map', + return_value=fakes.get_version_map()) + @mock.patch('troveclient.v1.shell._find_instance', + side_effect=exceptions.CommandError) + @mock.patch('troveclient.v1.shell._find_cluster', + return_value='cls-1234') + def test_find_instance_or_cluster_find_cluster(self, mock_find_cluster, + mock_find_instance, + mock_get_version_map, + mock_stdout): + cmd = 'root-show cls-1234' + self.shell.main(cmd.split()) + self.assert_called('GET', '/clusters/cls-1234/root') + + @mock.patch('sys.stdout', new_callable=six.StringIO) + @mock.patch('troveclient.client.get_version_map', + return_value=fakes.get_version_map()) + @mock.patch('troveclient.v1.shell._find_instance', + return_value='1234') + def test_find_instance_or_cluster(self, mock_find_instance, + mock_get_version_map, mock_stdout): + cmd = 'root-show 1234' + self.shell.main(cmd.split()) + self.assert_called('GET', '/instances/1234/root') diff --git a/troveclient/v1/root.py b/troveclient/v1/root.py index 62462193..99c6672b 100644 --- a/troveclient/v1/root.py +++ b/troveclient/v1/root.py @@ -22,22 +22,55 @@ from troveclient.v1 import users class Root(base.ManagerWithFind): """Manager class for Root resource.""" resource_class = users.User - url = "/instances/%s/root" + instances_url = "/instances/%s/root" + clusters_url = "/clusters/%s/root" def create(self, instance): """Implements root-enable API. - Enable the root user and return the root password for the specified db instance. """ - resp, body = self.api.client.post(self.url % base.getid(instance)) - common.check_for_exceptions(resp, body, self.url) + return self.create_instance_root(instance) + + def create_instance_root(self, instance, root_password=None): + """Implements root-enable for instances.""" + return self._enable_root(self.instances_url % base.getid(instance), + root_password) + + def create_cluster_root(self, cluster, root_password=None): + """Implements root-enable for clusters.""" + return self._enable_root(self.clusters_url % base.getid(cluster), + root_password) + + def _enable_root(self, uri, root_password=None): + """Implements root-enable API. + Enable the root user and return the root password for the + specified db instance or cluster. + """ + if root_password: + resp, body = self.api.client.post(uri, + body={"password": root_password}) + else: + resp, body = self.api.client.post(uri) + common.check_for_exceptions(resp, body, uri) return body['user']['name'], body['user']['password'] def is_root_enabled(self, instance): """Return whether root is enabled for the instance.""" - resp, body = self.api.client.get(self.url % base.getid(instance)) - common.check_for_exceptions(resp, body, self.url) + return self.is_instance_root_enabled(instance) + + def is_instance_root_enabled(self, instance): + """Returns whether root is enabled for the instance.""" + return self._is_root_enabled(self.instances_url % base.getid(instance)) + + def is_cluster_root_enabled(self, cluster): + """Returns whether root is enabled for the cluster.""" + return self._is_root_enabled(self.clusters_url % base.getid(cluster)) + + def _is_root_enabled(self, uri): + """Return whether root is enabled for the instance or the cluster.""" + resp, body = self.api.client.get(uri) + common.check_for_exceptions(resp, body, uri) return self.resource_class(self, body, loaded=True) # Appease the abc gods diff --git a/troveclient/v1/shell.py b/troveclient/v1/shell.py index 27a97eeb..46d83dd1 100644 --- a/troveclient/v1/shell.py +++ b/troveclient/v1/shell.py @@ -99,6 +99,22 @@ def _print_object(obj): utils.print_dict(obj._info) +def _find_instance_or_cluster(cs, instance_or_cluster): + """Returns an instance or cluster, found by id, along with the type of + resource, instance or cluster, that was found. + Raises CommandError if none is found. + """ + try: + return _find_instance(cs, instance_or_cluster), 'instance' + except exceptions.CommandError: + try: + return _find_cluster(cs, instance_or_cluster), 'cluster' + except Exception: + raise exceptions.CommandError( + "No instance or cluster with a name or ID of '%s' exists." + % instance_or_cluster) + + def _find_instance(cs, instance): """Get an instance by ID.""" return utils.find_resource(cs.instances, instance) @@ -890,23 +906,37 @@ def do_limit_list(cs, args): # Root related commands -@utils.arg('instance', metavar='', - help='ID or name of the instance.') +@utils.arg('instance_or_cluster', metavar='', + help='ID or name of the instance or cluster.') +@utils.arg('--root_password', + metavar='', + default=None, + help='Root password to set.') @utils.service_type('database') def do_root_enable(cs, args): """Enables root for an instance and resets if already exists.""" - instance = _find_instance(cs, args.instance) - root = cs.root.create(instance) + instance_or_cluster, resource_type = _find_instance_or_cluster( + cs, args.instance_or_cluster) + if resource_type == 'instance': + root = cs.root.create_instance_root(instance_or_cluster, + args.root_password) + else: + root = cs.root.create_cluster_root(instance_or_cluster, + args.root_password) utils.print_dict({'name': root[0], 'password': root[1]}) -@utils.arg('instance', metavar='', - help='ID or name of the instance.') +@utils.arg('instance_or_cluster', metavar='', + help='ID or name of the instance or cluster.') @utils.service_type('database') def do_root_show(cs, args): - """Gets status if root was ever enabled for an instance.""" - instance = _find_instance(cs, args.instance) - root = cs.root.is_root_enabled(instance) + """Gets status if root was ever enabled for an instance or cluster.""" + instance_or_cluster, resource_type = _find_instance_or_cluster( + cs, args.instance_or_cluster) + if resource_type == 'instance': + root = cs.root.is_instance_root_enabled(instance_or_cluster) + else: + root = cs.root.is_cluster_root_enabled(instance_or_cluster) utils.print_dict({'is_root_enabled': root.rootEnabled})