diff --git a/ec2driver.py b/ec2driver.py index 078349f..e2635d3 100644 --- a/ec2driver.py +++ b/ec2driver.py @@ -1,7 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright (c) 2014 Thoughtworks. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may +# Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # @@ -40,6 +40,7 @@ from nova.virt import driver from nova.virt import virtapi from credentials import get_nova_creds +import rule_comparator LOG = logging.getLogger(__name__) @@ -79,13 +80,14 @@ EC2_STATE_MAP = { "running": power_state.RUNNING, "shutting-down": power_state.NOSTATE, "terminated": power_state.SHUTDOWN, - "stopping":power_state.NOSTATE, + "stopping": power_state.NOSTATE, "stopped": power_state.SHUTDOWN } DIAGNOSTIC_KEYS_TO_FILTER = ['group', 'block_device_mapping'] + def set_nodes(nodes): """Sets EC2Driver's node.list. @@ -108,6 +110,7 @@ def restore_nodes(): global _EC2_NODES _EC2_NODES = [CONF.host] + class EC2Driver(driver.ComputeDriver): capabilities = { "has_imagecache": True, @@ -131,6 +134,7 @@ class EC2Driver(driver.ComputeDriver): 'cpu_info': {}, 'disk_available_least': 500000000000, } + self._mounts = {} self._interfaces = {} @@ -139,16 +143,17 @@ class EC2Driver(driver.ComputeDriver): region = RegionInfo(name=aws_region, endpoint=aws_endpoint) self.ec2_conn = ec2.EC2Connection(aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - host=host, - port=port, - region=region, - is_secure=secure) + aws_secret_access_key=aws_secret_access_key, + host=host, + port=port, + region=region, + is_secure=secure) self.cloudwatch_conn = ec2.cloudwatch.connect_to_region( aws_region, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key) self.security_group_lock = Lock() + self.rule_comparator = rule_comparator.RuleComparator(self.ec2_conn) if not '_EC2_NODES' in globals(): set_nodes([CONF.host]) @@ -214,12 +219,13 @@ class EC2Driver(driver.ComputeDriver): if user_data: user_data = base64.b64decode(user_data) - reservation = self.ec2_conn.run_instances(aws_ami, instance_type=flavor_type, placement="us-east-1d", user_data=user_data) + reservation = self.ec2_conn.run_instances(aws_ami, instance_type=flavor_type, placement="us-east-1d", + user_data=user_data) ec2_instance = reservation.instances ec2_id = ec2_instance[0].id self._wait_for_state(instance, ec2_id, "running", power_state.RUNNING) - instance['metadata'].update({'ec2_id':ec2_id, 'public_ip_address':elastic_ip_address.public_ip}) + instance['metadata'].update({'ec2_id': ec2_id, 'public_ip_address': elastic_ip_address.public_ip}) LOG.info("****** Associating the elastic IP to the instance *********") self.ec2_conn.associate_address(instance_id=ec2_id, allocation_id=elastic_ip_address.allocation_id) @@ -258,13 +264,17 @@ class EC2Driver(driver.ComputeDriver): metadata = {'is_public': False, 'location': image_ref, 'properties': { - 'kernel_id': instance['kernel_id'], - 'image_state': 'available', - 'owner_id': instance['project_id'], - 'ramdisk_id': instance['ramdisk_id'], - 'ec2_image_id': ec2_image_id - } + 'kernel_id': instance['kernel_id'], + 'image_state': 'available', + 'owner_id': instance['project_id'], + 'ramdisk_id': instance['ramdisk_id'], + 'ec2_image_id': ec2_image_id } +<<<<<<< HEAD +======= + } + +>>>>>>> Ed & Cameron | Add and remove rules to/from security groups associated with instances image_api.update(context, image_id, metadata) def reboot(self, context, instance, network_info, reboot_type, @@ -442,7 +452,8 @@ class EC2Driver(driver.ComputeDriver): else: # get the elastic ip associated with the instance & disassociate # it, and release it - elastic_ip_address = self.ec2_conn.get_all_addresses(addresses=instance['metadata']['public_ip_address'])[0] + elastic_ip_address = \ + self.ec2_conn.get_all_addresses(addresses=instance['metadata']['public_ip_address'])[0] LOG.info("****** Disassociating the elastic IP *********") self.ec2_conn.disassociate_address(elastic_ip_address.public_ip) @@ -492,8 +503,8 @@ class EC2Driver(driver.ComputeDriver): # wait for the old volume to detach successfully to make sure # /dev/sdn is available for the new volume to be attached time.sleep(60) - self.ec2_conn.attach_volume(volume_map[new_volume_id], - instance['metadata']['ec2_id'], "/dev/sdn", dry_run=False) + self.ec2_conn.attach_volume(volume_map[new_volume_id], instance['metadata']['ec2_id'], "/dev/sdn", + dry_run=False) return True def attach_interface(self, instance, image_meta, vif): @@ -522,7 +533,8 @@ class EC2Driver(driver.ComputeDriver): raise exception.InstanceNotFound(instance_id=instance['name']) ec2_id = instance['metadata']['ec2_id'] - ec2_instances = self.ec2_conn.get_only_instances(instance_ids=[ec2_id], filters=None, dry_run=False, max_results=None) + ec2_instances = self.ec2_conn.get_only_instances(instance_ids=[ec2_id], filters=None, dry_run=False, + max_results=None) if ec2_instances.__len__() == 0: LOG.warning(_("EC2 instance with ID %s not found") % ec2_id, instance=instance) raise exception.InstanceNotFound(instance_id=instance['name']) @@ -547,20 +559,21 @@ class EC2Driver(driver.ComputeDriver): instance = self.nova.servers.get(instance_name) ec2_id = instance.metadata['ec2_id'] - ec2_instances = self.ec2_conn.get_only_instances(instance_ids=[ec2_id], filters=None, dry_run=False, max_results=None) + ec2_instances = self.ec2_conn.get_only_instances(instance_ids=[ec2_id], filters=None, dry_run=False, + max_results=None) if ec2_instances.__len__() == 0: LOG.warning(_("EC2 instance with ID %s not found") % ec2_id, instance=instance) raise exception.InstanceNotFound(instance_id=instance['name']) ec2_instance = ec2_instances[0] diagnostics = {} - for key, value in ec2_instance.__dict__.items() : + for key, value in ec2_instance.__dict__.items(): if self.allow_key(key): diagnostics['instance.' + key] = str(value) - metrics = self.cloudwatch_conn.list_metrics(dimensions={'InstanceId': ec2_id}) import datetime + for metric in metrics: end = datetime.datetime.utcnow() start = end - datetime.timedelta(hours=1) @@ -616,10 +629,13 @@ class EC2Driver(driver.ComputeDriver): return [instance for instance in (self.nova.servers.list()) if openstack_security_group.name in [group['name'] for group in instance.security_groups]] - def _get_id_of_ec2_instance_to_update_security_group(self, ec2_instance_ids_for_security_group, ec2_ids_for_openstack_instances_for_security_group): - return (set(ec2_ids_for_openstack_instances_for_security_group).symmetric_difference(set(ec2_instance_ids_for_security_group))).pop() + def _get_id_of_ec2_instance_to_update_security_group(self, ec2_instance_ids_for_security_group, + ec2_ids_for_openstack_instances_for_security_group): + return (set(ec2_ids_for_openstack_instances_for_security_group).symmetric_difference( + set(ec2_instance_ids_for_security_group))).pop() - def _should_add_security_group_to_instance(self, ec2_instance_ids_for_security_group, ec2_ids_for_openstack_instances_for_security_group): + def _should_add_security_group_to_instance(self, ec2_instance_ids_for_security_group, + ec2_ids_for_openstack_instances_for_security_group): return len(ec2_instance_ids_for_security_group) < len(ec2_ids_for_openstack_instances_for_security_group) def _add_security_group_to_instance(self, ec2_instance_id, ec2_security_group): @@ -642,15 +658,17 @@ class EC2Driver(driver.ComputeDriver): return self.ec2_conn.get_all_security_groups(openstack_security_group.name)[0] except (EC2ResponseError, IndexError) as e: LOG.warning(e) - return self.ec2_conn.create_security_group(openstack_security_group.name, openstack_security_group.description) + return self.ec2_conn.create_security_group(openstack_security_group.name, + openstack_security_group.description) def refresh_security_group_rules(self, security_group_id): LOG.info("************** REFRESH SECURITY GROUP RULES ******************") - openstack_security_group = self.nova.security_groups.get(security_group_id) + openstack_security_group = self.nova. security_groups.get(security_group_id) ec2_security_group = self._get_or_create_ec2_security_group(openstack_security_group) - ec2_ids_for_ec2_instances_with_security_group = self._get_ec2_instance_ids_with_security_group(ec2_security_group) + ec2_ids_for_ec2_instances_with_security_group = self._get_ec2_instance_ids_with_security_group( + ec2_security_group) ec2_ids_for_openstack_instances_with_security_group = [ instance.metadata['ec2_id'] for instance @@ -683,10 +701,62 @@ class EC2Driver(driver.ComputeDriver): LOG.info(security_group_id) return True + def _get_allowed_group_name_from_openstack_rule_if_present(self, openstack_rule): + return openstack_rule['group']['name'] if 'name' in openstack_rule['group'] else None + + def _get_allowed_ip_range_from_openstack_rule_if_present(self, openstack_rule): + return openstack_rule['ip_range']['cidr'] if 'cidr' in openstack_rule['ip_range'] else None + def refresh_instance_security_rules(self, instance): LOG.info("************** REFRESH INSTANCE SECURITY RULES ******************") LOG.info(instance) - return True + + # TODO: lock for case when group is associated with multiple instances [Cameron & Ed] + + openstack_instance = self.nova.servers.get(instance['id']) + + for group_dict in openstack_instance.security_groups: + + openstack_group =\ + [group for group in self.nova.security_groups.list() if group.name == group_dict['name']][0] + + ec2_group = self.ec2_conn.get_all_security_groups(groupnames=group_dict['name'])[0] + + for openstack_rule in openstack_group.rules: + equivalent_rule_found_in_ec2 = False + for ec2_rule in ec2_group.rules: + if self.rule_comparator.rules_are_equal(openstack_rule, ec2_rule): + equivalent_rule_found_in_ec2 = True + break + + if not equivalent_rule_found_in_ec2: + self.ec2_conn.authorize_security_group( + group_name=ec2_group.name, + ip_protocol=openstack_rule['ip_protocol'], + from_port=openstack_rule['from_port'], + to_port=openstack_rule['to_port'], + src_security_group_name=self._get_allowed_group_name_from_openstack_rule_if_present(openstack_rule), + cidr_ip=self._get_allowed_ip_range_from_openstack_rule_if_present(openstack_rule) + ) + + for ec2_rule in ec2_group.rules: + equivalent_rule_found_in_openstack = False + for openstack_rule in openstack_group.rules: + if self.rule_comparator.rules_are_equal(openstack_rule, ec2_rule): + equivalent_rule_found_in_openstack = True + break + + if not equivalent_rule_found_in_openstack: + self.ec2_conn.revoke_security_group( + group_name=ec2_group.name, + ip_protocol=ec2_rule.ip_protocol, + from_port=ec2_rule.from_port, + to_port=ec2_rule.to_port, + cidr_ip=ec2_rule.grants[0].cidr_ip, + src_security_group_group_id=ec2_rule.grants[0].group_id + ) + + return def refresh_provider_fw_rules(self): pass @@ -874,9 +944,7 @@ class EC2Driver(driver.ComputeDriver): timer.start(interval=0.5).wait() - class EC2VirtAPI(virtapi.VirtAPI): - def instance_update(self, context, instance_uuid, updates): return db.instance_update_and_get_original(context, instance_uuid, diff --git a/rule_comparator.py b/rule_comparator.py index 74d3ebe..1370f68 100644 --- a/rule_comparator.py +++ b/rule_comparator.py @@ -3,23 +3,22 @@ class RuleComparator: self.ec2_connection = ec2_connection def rules_are_equal(self, openstack_rule, ec2_rule): - if self._ip_protocols_are_different(ec2_rule, openstack_rule)\ - or self._from_ports_are_different(ec2_rule, openstack_rule)\ - or self._to_ports_are_different(ec2_rule, openstack_rule)\ - or self._ip_ranges_are_present_and_different(ec2_rule, openstack_rule)\ + if self._ip_protocols_are_different(ec2_rule, openstack_rule) \ + or self._from_ports_are_different(ec2_rule, openstack_rule) \ + or self._to_ports_are_different(ec2_rule, openstack_rule) \ + or self._ip_ranges_are_present_and_different(ec2_rule, openstack_rule) \ or self._group_names_are_present_and_different(openstack_rule, ec2_rule): return False - return True def _ip_protocols_are_different(self, ec2_rule, openstack_rule): return openstack_rule['ip_protocol'] != ec2_rule.ip_protocol def _from_ports_are_different(self, ec2_rule, openstack_rule): - return openstack_rule['from_port'] != ec2_rule.from_port + return str(openstack_rule['from_port']) != ec2_rule.from_port def _to_ports_are_different(self, ec2_rule, openstack_rule): - return openstack_rule['to_port'] != ec2_rule.to_port + return str(openstack_rule['to_port']) != ec2_rule.to_port def _ip_ranges_are_present_and_different(self, ec2_rule, openstack_rule): return ('cidr' in openstack_rule['ip_range'] and openstack_rule['ip_range']['cidr'] != ec2_rule.grants[0].cidr_ip) @@ -28,6 +27,5 @@ class RuleComparator: if 'name' not in openstack_rule['group']: return False else: - openstack_group = openstack_rule['group'] - ec2_group_name = self.ec2_connection.get_all_security_groups(ec2_rule.grants[0].group_id)[0].name - return openstack_group['name'] == ec2_group_name + ec2_group_name = self.ec2_connection.get_all_security_groups(group_ids=ec2_rule.grants[0].group_id)[0].name + return openstack_rule['group']['name'] != ec2_group_name diff --git a/tests/fake_ec2_rule_builder.py b/tests/fake_ec2_rule_builder.py index d4b2390..1dd43d6 100644 --- a/tests/fake_ec2_rule_builder.py +++ b/tests/fake_ec2_rule_builder.py @@ -7,8 +7,8 @@ class FakeEC2RuleBuilder(): def __init__(self): self.ip_protocol = 'udp' - self.from_port = 1111 - self.to_port = 3333 + self.from_port = '1111' + self.to_port = '3333' self.ip_range = '0.0.0.0/0' self.allowed_security_group_id = None diff --git a/tests/test_rule_comparator.py b/tests/test_rule_comparator.py index 288997b..2cd0a45 100644 --- a/tests/test_rule_comparator.py +++ b/tests/test_rule_comparator.py @@ -61,7 +61,7 @@ class TestRuleComparator(unittest.TestCase): self.openstack_rule['ip_range'] = {} self.openstack_rule['group'] = {'name': 'secGroup'} - self.ec2_connection.get_all_security_groups.return_value = [self.FakeSecurityGroup('secGroup')] + self.ec2_connection.get_all_security_groups.return_value = [self.FakeSecurityGroup('secGroup2')] ec2_rule = FakeEC2RuleBuilder.an_ec2_rule()\ .with_allowed_security_group_id(5)\ diff --git a/tests/test_security_groups.py b/tests/test_security_groups.py index 70ca191..04517a4 100644 --- a/tests/test_security_groups.py +++ b/tests/test_security_groups.py @@ -25,7 +25,7 @@ class TestSecurityGroups(EC2TestBase): @unittest.skipIf(os.environ.get('MOCK_EC2'), 'Not supported by moto') def test_should_add_security_group_to_ec2_instance(self): - self.assertEqual(self.instance.metadata['ec2_id'], self.matching_ec2_security_groups[0].instances()[0].id) + self.assertEqual(self.matching_ec2_security_groups[0].instances()[0].id, self.instance.metadata['ec2_id']) @unittest.skipIf(os.environ.get('MOCK_EC2'), 'Not supported by moto') def test_should_remove_security_group_from_ec2_instance(self): @@ -36,8 +36,23 @@ class TestSecurityGroups(EC2TestBase): updated_matching_ec2_security_group = self._wait_for_ec2_group_to_have_no_instances(self.security_group) self.assertEqual(updated_matching_ec2_security_group.instances(), []) - def test_should_add_rule_to_ec2_security_group_when_group_has_an_instance(self): - pass + def test_should_add_rule_to_ec2_security_group_when_group_is_added_to_an_instance(self): + + security_group_rule = self.nova.security_group_rules.create( + parent_group_id=self.security_group.id, + ip_protocol='tcp', + from_port='1234', + to_port='4321', + cidr='0.0.0.0/0' + ) + + updated_security_group = self.nova.security_groups.get(self.security_group.id) + + ec2_security_group = self.ec2_conn.get_all_security_groups(groupnames=self.security_group.name)[0] + ec2_rule = ec2_security_group.rules[0] + + self.assertEqual(ec2_rule.ip_protocol, security_group_rule.ip_protocol) + #etc def _destroy_security_group(self): print "Cleanup: Destroying security group"