designate/designate/tests/unit/test_mdns/test_handler.py

208 lines
7.6 KiB
Python

# Copyright 2014 Rackspace Inc.
#
# Author: Eric Larson <eric.larson@rackspace.com>
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
from mock import Mock
from oslo_log import log as logging
import dns
from designate import exceptions
from designate import objects
from designate.mdns import handler
LOG = logging.getLogger(__name__)
class TestRequestHandlerCall(unittest.TestCase):
"""
Unit test to assert the dispatching based on the request
operation.
"""
def setUp(self):
self.storage = Mock()
self.tg = Mock()
self.handler = handler.RequestHandler(self.storage, self.tg)
# Use a simple handlers that doesn't require a real request
self.handler._handle_query_error = Mock(return_value='Error')
self.handler._handle_axfr = Mock(return_value=['AXFR'])
self.handler._handle_record_query = Mock(return_value=['Record Query'])
self.handler._handle_notify = Mock(return_value=['Notify'])
def assert_error(self, request, error_type):
self.handler._handle_query_error.assert_called_with(
request, error_type
)
def test_central_api_property(self):
self.handler._central_api = 'foo'
assert self.handler.central_api == 'foo'
def test___call___unhandled_opcodes(self):
unhandled_codes = [
dns.opcode.STATUS,
dns.opcode.IQUERY,
dns.opcode.UPDATE,
]
request = Mock()
for code in unhandled_codes:
request.opcode.return_value = code # return an error
assert list(self.handler(request)) == ['Error']
self.assert_error(request, dns.rcode.REFUSED)
def test___call__query_error_with_more_than_one_question(self):
request = Mock()
request.opcode.return_value = dns.opcode.QUERY
request.question = [Mock(), Mock()]
assert list(self.handler(request)) == ['Error']
self.assert_error(request, dns.rcode.REFUSED)
def test___call__query_error_with_data_claas_not_in(self):
request = Mock()
request.opcode.return_value = dns.opcode.QUERY
request.question = [Mock(rdclass=dns.rdataclass.ANY)]
assert list(self.handler(request)) == ['Error']
self.assert_error(request, dns.rcode.REFUSED)
def test___call__axfr(self):
request = Mock()
request.opcode.return_value = dns.opcode.QUERY
request.question = [
Mock(rdclass=dns.rdataclass.IN, rdtype=dns.rdatatype.AXFR)
]
assert list(self.handler(request)) == ['AXFR']
def test___call__ixfr(self):
request = Mock()
request.opcode.return_value = dns.opcode.QUERY
request.question = [
Mock(rdclass=dns.rdataclass.IN, rdtype=dns.rdatatype.IXFR)
]
assert list(self.handler(request)) == ['AXFR']
def test___call__record_query(self):
request = Mock()
request.opcode.return_value = dns.opcode.QUERY
request.question = [
Mock(rdclass=dns.rdataclass.IN, rdtype=dns.rdatatype.A)
]
assert list(self.handler(request)) == ['Record Query']
def test___call__notify(self):
request = Mock()
request.opcode.return_value = dns.opcode.NOTIFY
assert list(self.handler(request)) == ['Notify']
def test__convert_to_rrset_no_records(self):
zone = objects.Zone.from_dict({'ttl': 1234})
recordset = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
])
)
r_rrset = self.handler._convert_to_rrset(zone, recordset)
self.assertIsNone(r_rrset)
def test__convert_to_rrset(self):
zone = objects.Zone.from_dict({'ttl': 1234})
recordset = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.1'),
objects.Record(data='192.0.2.2'),
])
)
r_rrset = self.handler._convert_to_rrset(zone, recordset)
self.assertEqual(2, len(r_rrset))
class HandleRecordQueryTest(unittest.TestCase):
def setUp(self):
self.storage = Mock()
self.tg = Mock()
self.handler = handler.RequestHandler(self.storage, self.tg)
def test__handle_record_query_empty_recordlist(self):
# bug #1550441
self.storage.find_recordset.return_value = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
])
)
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response_gen = self.handler._handle_record_query(request)
for r in response_gen:
# This was raising an exception due to bug #1550441
out = r.to_wire(max_size=65535)
self.assertEqual(33, len(out))
def test__handle_record_query_zone_not_found(self):
self.storage.find_recordset.return_value = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.2'),
])
)
self.storage.find_zone.side_effect = exceptions.ZoneNotFound
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
def test__handle_record_query_forbidden(self):
self.storage.find_recordset.return_value = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.2'),
])
)
self.storage.find_zone.side_effect = exceptions.Forbidden
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
def test__handle_record_query_find_recordsed_forbidden(self):
self.storage.find_recordset.side_effect = exceptions.Forbidden
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
def test__handle_record_query_find_recordsed_not_found(self):
self.storage.find_recordset.side_effect = exceptions.NotFound
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())