diff --git a/trove/common/exception.py b/trove/common/exception.py index d7e20f64c7..9413e65ced 100644 --- a/trove/common/exception.py +++ b/trove/common/exception.py @@ -606,6 +606,10 @@ class PublicNetworkNotFound(TroveError): message = _("Public network cannot be found.") +class NetworkConflict(BadRequest): + message = _("User network conflicts with the management network.") + + class ClusterVolumeSizeRequired(TroveError): message = _("A volume size is required for each instance in the cluster.") diff --git a/trove/common/neutron.py b/trove/common/neutron.py index 5f3330b117..d11ecf22d5 100644 --- a/trove/common/neutron.py +++ b/trove/common/neutron.py @@ -21,6 +21,7 @@ from trove.common import exception CONF = cfg.CONF LOG = logging.getLogger(__name__) MGMT_NETWORKS = None +MGMT_CIDRS = None def get_management_networks(context): @@ -147,3 +148,27 @@ def create_security_group_rule(client, sg_id, protocol, ports, remote_ips): } client.create_security_group_rule(body) + + +def get_subnet_cidrs(client, network_id): + cidrs = [] + + subnets = client.list_subnets(network_id=network_id)['subnets'] + for subnet in subnets: + cidrs.append(subnet.get('cidr')) + + return cidrs + + +def get_mamangement_subnet_cidrs(client): + """Cache the management subnet CIDRS.""" + global MGMT_CIDRS + + if MGMT_CIDRS is not None: + return MGMT_CIDRS + + MGMT_CIDRS = [] + if len(CONF.management_networks) > 0: + MGMT_CIDRS = get_subnet_cidrs(client, CONF.management_networks[0]) + + return MGMT_CIDRS diff --git a/trove/instance/service.py b/trove/instance/service.py index fbf1b1aa64..8f66de445c 100644 --- a/trove/instance/service.py +++ b/trove/instance/service.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import ipaddress + from oslo_log import log as logging from oslo_utils import strutils @@ -20,9 +22,10 @@ from trove.backup.models import Backup as backup_model from trove.backup import views as backup_views import trove.common.apischema as apischema from trove.common import cfg -from trove.common.clients import create_guest_client +from trove.common import clients from trove.common import exception from trove.common.i18n import _ +from trove.common import neutron from trove.common import notification from trove.common.notification import StartNotification from trove.common import pagination @@ -279,6 +282,19 @@ class InstanceController(wsgi.Controller): instance.delete() return wsgi.Result(None, 202) + def _check_network_overlap(self, context, user_network): + neutron_client = clients.create_neutron_client(context) + user_cidrs = neutron.get_subnet_cidrs(neutron_client, user_network) + mgmt_cidrs = neutron.get_mamangement_subnet_cidrs(neutron_client) + LOG.debug("Cidrs of the user network: %s, cidrs of the management " + "network: %s", user_cidrs, mgmt_cidrs) + for user_cidr in user_cidrs: + user_net = ipaddress.ip_network(user_cidr) + for mgmt_cidr in mgmt_cidrs: + mgmt_net = ipaddress.ip_network(mgmt_cidr) + if user_net.overlaps(mgmt_net): + raise exception.NetworkConflict() + def create(self, req, body, tenant_id): # TODO(hub-cap): turn this into middleware LOG.info("Creating a database instance for tenant '%s'", @@ -342,7 +358,10 @@ class InstanceController(wsgi.Controller): backup_id = None availability_zone = body['instance'].get('availability_zone') + nics = body['instance'].get('nics', []) + if len(nics) > 0: + self._check_network_overlap(context, nics[0].get('net-id')) slave_of_id = body['instance'].get('replica_of', # also check for older name @@ -499,7 +518,7 @@ class InstanceController(wsgi.Controller): if not instance: raise exception.NotFound(uuid=id) self.authorize_instance_action(context, 'guest_log_list', instance) - client = create_guest_client(context, id) + client = clients.create_guest_client(context, id) guest_log_list = client.guest_log_list() return wsgi.Result({'logs': guest_log_list}, 200) @@ -523,7 +542,7 @@ class InstanceController(wsgi.Controller): discard = body.get('discard', None) if enable and disable: raise exception.BadRequest(_("Cannot enable and disable log.")) - client = create_guest_client(context, id) + client = clients.create_guest_client(context, id) guest_log = client.guest_log_action(log_name, enable, disable, publish, discard) return wsgi.Result({'log': guest_log}, 200) @@ -546,13 +565,13 @@ class InstanceController(wsgi.Controller): def _module_list_guest(self, context, id, include_contents): """Return information about modules on an instance.""" - client = create_guest_client(context, id) + client = clients.create_guest_client(context, id) result_list = client.module_list(include_contents) return wsgi.Result({'modules': result_list}, 200) def _module_list(self, context, id, include_contents): """Return information about instance modules.""" - client = create_guest_client(context, id) + client = clients.create_guest_client(context, id) result_list = client.module_list(include_contents) return wsgi.Result({'modules': result_list}, 200) @@ -568,7 +587,7 @@ class InstanceController(wsgi.Controller): module_models.Modules.validate( modules, instance.datastore.id, instance.datastore_version.id) module_list = module_views.convert_modules_to_list(modules) - client = create_guest_client(context, id) + client = clients.create_guest_client(context, id) result_list = client.module_apply(module_list) models.Instance.add_instance_modules(context, id, modules) return wsgi.Result({'modules': result_list}, 200) @@ -582,7 +601,7 @@ class InstanceController(wsgi.Controller): self.authorize_instance_action(context, 'module_remove', instance) module = module_models.Module.load(context, module_id) module_info = module_views.DetailedModuleView(module).data() - client = create_guest_client(context, id) + client = clients.create_guest_client(context, id) client.module_remove(module_info) instance_modules = module_models.InstanceModules.load_all( context, instance_id=id, module_id=module_id)