Merge "Fix create call for security group rules"
This commit is contained in:
commit
358240964f
|
@ -142,3 +142,24 @@ control_exchange = trove
|
|||
[mysql]
|
||||
|
||||
root_on_create = False
|
||||
|
||||
# ================= Security groups related ========================
|
||||
# Each future datastore implementation should implement
|
||||
# its own oslo group with defined in it:
|
||||
# - tcp_ports; upd_ports;
|
||||
|
||||
[mysql]
|
||||
# Format (single port or port range): A, B-C
|
||||
# where C greater than B
|
||||
tcp_ports = 3306
|
||||
|
||||
[redis]
|
||||
# Format (single port or port range): A, B-C
|
||||
# where C greater than B
|
||||
tcp_ports = 6379
|
||||
|
||||
[cassandra]
|
||||
tcp_ports = 7000, 7001, 9042, 9160
|
||||
|
||||
[couchbase]
|
||||
tcp_ports = 8091, 8092, 4369, 11209-11211, 21100-21199
|
|
@ -334,3 +334,13 @@ def try_recover(func):
|
|||
'func': func.__name__})
|
||||
raise
|
||||
return _decorator
|
||||
|
||||
|
||||
def gen_ports(portstr):
|
||||
from_port, sep, to_port = portstr.partition('-')
|
||||
if not (to_port and from_port):
|
||||
if not sep:
|
||||
to_port = from_port
|
||||
if int(from_port) > int(to_port):
|
||||
raise ValueError
|
||||
return from_port, to_port
|
||||
|
|
|
@ -13,18 +13,19 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
|
||||
from trove.common import cfg
|
||||
from trove.common import exception
|
||||
from trove.common import wsgi
|
||||
from trove.common import cfg
|
||||
from trove.common import utils
|
||||
from trove.datastore.models import DatastoreVersion
|
||||
from trove.extensions.security_group import models
|
||||
from trove.extensions.security_group import views
|
||||
from trove.instance import models as instance_models
|
||||
from trove.openstack.common import log as logging
|
||||
from trove.openstack.common.gettextutils import _
|
||||
|
||||
CONF = cfg.CONF
|
||||
LOG = logging.getLogger(__name__)
|
||||
CONF = cfg.CONF
|
||||
|
||||
|
||||
class SecurityGroupController(wsgi.Controller):
|
||||
|
@ -87,24 +88,38 @@ class SecurityGroupRuleController(wsgi.Controller):
|
|||
sec_group = models.SecurityGroup.find_by(id=sec_group_id,
|
||||
tenant_id=tenant_id,
|
||||
deleted=False)
|
||||
instance_id = (models.SecurityGroupInstanceAssociation.
|
||||
get_instance_id_by_security_group_id(sec_group_id))
|
||||
db_info = instance_models.get_db_info(context, id=instance_id)
|
||||
manager = (DatastoreVersion.load_by_uuid(
|
||||
db_info.datastore_version_id).manager)
|
||||
tcp_ports = CONF.get(manager).tcp_ports
|
||||
udp_ports = CONF.get(manager).udp_ports
|
||||
|
||||
sec_group_rule = models.SecurityGroupRule.create_sec_group_rule(
|
||||
sec_group,
|
||||
CONF.trove_security_group_rule_protocol,
|
||||
CONF.trove_security_group_rule_port,
|
||||
CONF.trove_security_group_rule_port,
|
||||
body['security_group_rule']['cidr'],
|
||||
context)
|
||||
def _create_rules(sec_group, ports, protocol):
|
||||
rules = []
|
||||
try:
|
||||
for port_or_range in set(ports):
|
||||
from_, to_ = utils.gen_ports(port_or_range)
|
||||
rule = models.SecurityGroupRule.create_sec_group_rule(
|
||||
sec_group, protocol, int(from_), int(to_),
|
||||
body['security_group_rule']['cidr'], context)
|
||||
rules.append(rule)
|
||||
except (ValueError, AttributeError) as e:
|
||||
raise exception.BadRequest(msg=str(e))
|
||||
return rules
|
||||
|
||||
resultView = views.SecurityGroupRulesView(sec_group_rule,
|
||||
req,
|
||||
tenant_id).create()
|
||||
return wsgi.Result(resultView, 201)
|
||||
tcp_rules = _create_rules(sec_group, tcp_ports, 'tcp')
|
||||
udp_rules = _create_rules(sec_group, udp_ports, 'udp')
|
||||
|
||||
all_rules = tcp_rules + udp_rules
|
||||
|
||||
view = views.SecurityGroupRulesView(
|
||||
all_rules, req, tenant_id).create()
|
||||
return wsgi.Result(view, 201)
|
||||
|
||||
def _validate_create_body(self, body):
|
||||
try:
|
||||
# TODO(slicknik): Add some better validation here around ports,
|
||||
# protocol, and cidr values.
|
||||
body['security_group_rule']
|
||||
body['security_group_rule']['group_id']
|
||||
body['security_group_rule']['cidr']
|
||||
|
@ -134,16 +149,6 @@ class SecurityGroupRuleController(wsgi.Controller):
|
|||
"required": True,
|
||||
"maxLength": 255
|
||||
},
|
||||
"from_port": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 65535
|
||||
},
|
||||
"to_port": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 65535
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,8 +90,8 @@ class SecurityGroupsView(object):
|
|||
groups_data = []
|
||||
|
||||
for secgroup in self.secgroups:
|
||||
rules = \
|
||||
self.rules[secgroup['id']] if self.rules is not None else None
|
||||
rules = (self.rules[secgroup['id']]
|
||||
if self.rules is not None else None)
|
||||
groups_data.append(SecurityGroupView(secgroup,
|
||||
rules,
|
||||
self.request,
|
||||
|
@ -102,22 +102,25 @@ class SecurityGroupsView(object):
|
|||
|
||||
class SecurityGroupRulesView(object):
|
||||
|
||||
def __init__(self, rule, req, tenant_id):
|
||||
self.rule = rule
|
||||
def __init__(self, rules, req, tenant_id):
|
||||
self.rules = rules
|
||||
self.request = req
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def _build_create(self):
|
||||
return {"security_group_rule":
|
||||
{"id": str(self.rule['id']),
|
||||
"security_group_id": self.rule['group_id'],
|
||||
"protocol": self.rule['protocol'],
|
||||
"from_port": self.rule['from_port'],
|
||||
"to_port": self.rule['to_port'],
|
||||
"cidr": self.rule['cidr'],
|
||||
"created": self.rule['created']
|
||||
}
|
||||
}
|
||||
views = []
|
||||
for rule in self.rules:
|
||||
to_append = {
|
||||
"id": rule.id,
|
||||
"security_group_id": rule.group_id,
|
||||
"protocol": rule.protocol,
|
||||
"from_port": rule.from_port,
|
||||
"to_port": rule.to_port,
|
||||
"cidr": rule.cidr,
|
||||
"created": rule.created
|
||||
}
|
||||
views.append(to_append)
|
||||
return {"security_group_rule": views}
|
||||
|
||||
def create(self):
|
||||
return self._build_create()
|
||||
|
|
|
@ -635,8 +635,8 @@ class FreshInstanceTasks(FreshInstance, NotifyMixin, ConfigurationMixin):
|
|||
{'gt': greenthread.getcurrent(), 'id': self.id})
|
||||
|
||||
def _create_secgroup(self, datastore_manager):
|
||||
security_group = SecurityGroup.create_for_instance(self.id,
|
||||
self.context)
|
||||
security_group = SecurityGroup.create_for_instance(
|
||||
self.id, self.context)
|
||||
tcp_ports = CONF.get(datastore_manager).tcp_ports
|
||||
udp_ports = CONF.get(datastore_manager).udp_ports
|
||||
self._create_rules(security_group, tcp_ports, 'tcp')
|
||||
|
@ -655,27 +655,15 @@ class FreshInstanceTasks(FreshInstance, NotifyMixin, ConfigurationMixin):
|
|||
msg = err_msg % {'from': from_port, 'to': to_port}
|
||||
raise MalformedSecurityGroupRuleError(message=msg)
|
||||
|
||||
def _gen_ports(portstr):
|
||||
from_port, sep, to_port = portstr.partition('-')
|
||||
if not (to_port and from_port):
|
||||
if not sep:
|
||||
to_port = from_port
|
||||
try:
|
||||
if int(from_port) > int(to_port):
|
||||
set_error_and_raise([from_port, to_port])
|
||||
except ValueError:
|
||||
set_error_and_raise([from_port, to_port])
|
||||
return from_port, to_port
|
||||
|
||||
for port_or_range in set(ports):
|
||||
|
||||
from_, to_ = _gen_ports(port_or_range)
|
||||
try:
|
||||
from_, to_ = (None, None)
|
||||
from_, to_ = utils.gen_ports(port_or_range)
|
||||
cidr = CONF.trove_security_group_rule_cidr
|
||||
SecurityGroupRule.create_sec_group_rule(
|
||||
s_group, protocol, int(from_), int(to_),
|
||||
CONF.trove_security_group_rule_cidr,
|
||||
self.context)
|
||||
except TroveError:
|
||||
cidr, self.context)
|
||||
except (ValueError, TroveError):
|
||||
set_error_and_raise([from_, to_])
|
||||
|
||||
def _build_heat_nics(self, nics):
|
||||
|
|
|
@ -823,32 +823,22 @@ class SecurityGroupsRulesTest(object):
|
|||
|
||||
@test
|
||||
def test_create_security_group_rule(self):
|
||||
if len(self.testSecurityGroup.rules) == 0:
|
||||
self.testSecurityGroupRule = \
|
||||
dbaas.security_group_rules.create(
|
||||
group_id=self.testSecurityGroup.id,
|
||||
protocol="tcp",
|
||||
from_port=3306,
|
||||
to_port=3306,
|
||||
cidr="0.0.0.0/0")
|
||||
assert_is_not_none(self.testSecurityGroupRule)
|
||||
with TypeCheck('SecurityGroupRule',
|
||||
self.testSecurityGroupRule) as secGrpRule:
|
||||
secGrpRule.has_field('id', basestring)
|
||||
secGrpRule.has_field('security_group_id', basestring)
|
||||
secGrpRule.has_field('protocol', basestring)
|
||||
secGrpRule.has_field('cidr', basestring)
|
||||
secGrpRule.has_field('from_port', int)
|
||||
secGrpRule.has_field('to_port', int)
|
||||
secGrpRule.has_field('created', basestring)
|
||||
assert_equal(self.testSecurityGroupRule.security_group_id,
|
||||
cidr = "1.2.3.4/16"
|
||||
self.testSecurityGroupRules = (
|
||||
dbaas.security_group_rules.create(
|
||||
group_id=self.testSecurityGroup.id,
|
||||
cidr=cidr))
|
||||
assert_not_equal(len(self.testSecurityGroupRules), 0)
|
||||
assert_is_not_none(self.testSecurityGroupRules)
|
||||
for rule in self.testSecurityGroupRules:
|
||||
assert_is_not_none(rule)
|
||||
assert_equal(rule['security_group_id'],
|
||||
self.testSecurityGroup.id)
|
||||
assert_equal(self.testSecurityGroupRule.protocol, "tcp")
|
||||
assert_equal(int(self.testSecurityGroupRule.from_port), 3306)
|
||||
assert_equal(int(self.testSecurityGroupRule.to_port), 3306)
|
||||
assert_equal(self.testSecurityGroupRule.cidr, "0.0.0.0/0")
|
||||
else:
|
||||
assert_not_equal(len(self.testSecurityGroup.rules), 0)
|
||||
assert_is_not_none(rule['id'])
|
||||
assert_equal(rule['cidr'], cidr)
|
||||
assert_equal(rule['from_port'], 3306)
|
||||
assert_equal(rule['to_port'], 3306)
|
||||
assert_is_not_none(rule['created'])
|
||||
|
||||
|
||||
@test(depends_on_classes=[WaitForGuestInstallationToFinish],
|
||||
|
|
Loading…
Reference in New Issue