Move some methods from mdns to dnsutils
Agent needs some methods in mdns. These methods are moved to dnsutils to prevent code duplication. Change-Id: I9794341b2bfa06b34b994ed028e41a03e27f196d Closes-Bug: 1413387
This commit is contained in:
parent
45ffc1ab29
commit
d31411dd78
|
@ -13,10 +13,20 @@
|
|||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import socket
|
||||
import struct
|
||||
|
||||
import dns
|
||||
from dns import rdatatype
|
||||
from oslo_log import log as logging
|
||||
|
||||
from designate import exceptions
|
||||
from designate import objects
|
||||
from designate.i18n import _LE
|
||||
from designate.i18n import _LI
|
||||
from designate.i18n import _LW
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def from_dnspython_zone(dnspython_zone):
|
||||
|
@ -73,3 +83,142 @@ def dnspythonrecord_to_recordset(rname, rdataset):
|
|||
rr = objects.Record(data=rdata.to_text())
|
||||
rrset.records.append(rr)
|
||||
return rrset
|
||||
|
||||
|
||||
def _deserialize_request(payload, addr):
|
||||
"""
|
||||
Deserialize a DNS Request Packet
|
||||
|
||||
:param payload: Raw DNS query payload
|
||||
:param addr: Tuple of the client's (IP, Port)
|
||||
"""
|
||||
try:
|
||||
request = dns.message.from_wire(payload)
|
||||
except dns.exception.DNSException:
|
||||
LOG.error(_LE("Failed to deserialize packet from %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
return None
|
||||
else:
|
||||
# Create + Attach the initial "environ" dict. This is similar to
|
||||
# the environ dict used in typical WSGI middleware.
|
||||
request.environ = {'addr': addr}
|
||||
return request
|
||||
|
||||
|
||||
def _serialize_response(response):
|
||||
"""
|
||||
Serialize a DNS Response Packet
|
||||
|
||||
:param response: DNS Response Message
|
||||
"""
|
||||
return response.to_wire()
|
||||
|
||||
|
||||
def bind_tcp(host, port, tcp_backlog):
|
||||
# Bind to the TCP port
|
||||
LOG.info(_LI('Opening TCP Listening Socket on %(host)s:%(port)d') %
|
||||
{'host': host, 'port': port})
|
||||
sock_tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock_tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock_tcp.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
sock_tcp.bind((host, port))
|
||||
sock_tcp.listen(tcp_backlog)
|
||||
|
||||
return sock_tcp
|
||||
|
||||
|
||||
def bind_udp(host, port):
|
||||
# Bind to the UDP port
|
||||
LOG.info(_LI('Opening UDP Listening Socket on %(host)s:%(port)d') %
|
||||
{'host': host, 'port': port})
|
||||
sock_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock_udp.bind((host, port))
|
||||
|
||||
return sock_udp
|
||||
|
||||
|
||||
def handle_tcp(sock_tcp, tg, handle, application, timeout=None):
|
||||
LOG.info(_LI("_handle_tcp thread started"))
|
||||
while True:
|
||||
client, addr = sock_tcp.accept()
|
||||
if timeout:
|
||||
client.settimeout(timeout)
|
||||
|
||||
LOG.info(_LI("Handling TCP Request from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
# Prepare a variable for the payload to be buffered
|
||||
payload = ""
|
||||
|
||||
try:
|
||||
# Receive the first 2 bytes containing the payload length
|
||||
expected_length_raw = client.recv(2)
|
||||
(expected_length, ) = struct.unpack('!H', expected_length_raw)
|
||||
|
||||
# Keep receiving data until we've got all the data we expect
|
||||
while len(payload) < expected_length:
|
||||
data = client.recv(65535)
|
||||
if not data:
|
||||
break
|
||||
payload += data
|
||||
|
||||
except socket.timeout:
|
||||
client.close()
|
||||
LOG.warn(_LW("TCP Timeout from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
# Dispatch a thread to handle the query
|
||||
tg.add_thread(handle, addr, payload, application, client=client)
|
||||
|
||||
|
||||
def handle_udp(sock_udp, tg, handle, application):
|
||||
LOG.info(_LI("_handle_udp thread started"))
|
||||
while True:
|
||||
# TODO(kiall): Determine the appropriate default value for
|
||||
# UDP recvfrom.
|
||||
payload, addr = sock_udp.recvfrom(8192)
|
||||
LOG.info(_LI("Handling UDP Request from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
tg.add_thread(handle, addr, payload, application, sock_udp=sock_udp)
|
||||
|
||||
|
||||
def handle(addr, payload, application, sock_udp=None, client=None):
|
||||
"""
|
||||
Handle a DNS Query
|
||||
|
||||
:param addr: Tuple of the client's (IP, Port)
|
||||
:param payload: Raw DNS query payload
|
||||
:param client: Client socket (for TCP only)
|
||||
"""
|
||||
try:
|
||||
request = _deserialize_request(payload, addr)
|
||||
|
||||
if request is None:
|
||||
# We failed to deserialize the request, generate a failure
|
||||
# response using a made up request.
|
||||
response = dns.message.make_response(
|
||||
dns.message.make_query('unknown', dns.rdatatype.A))
|
||||
response.set_rcode(dns.rcode.FORMERR)
|
||||
else:
|
||||
response = application(request)
|
||||
|
||||
# send back a response only if present
|
||||
if response:
|
||||
response = _serialize_response(response)
|
||||
|
||||
if client:
|
||||
# Handle TCP Responses
|
||||
msg_length = len(response)
|
||||
tcp_response = struct.pack("!H", msg_length) + response
|
||||
client.send(tcp_response)
|
||||
client.close()
|
||||
elif sock_udp:
|
||||
# Handle UDP Responses
|
||||
sock_udp.sendto(response, addr)
|
||||
else:
|
||||
LOG.warn(_LW("Both sock_udp and client are None"))
|
||||
except Exception:
|
||||
LOG.exception(_LE("Unhandled exception while processing request "
|
||||
"from %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
|
|
@ -13,20 +13,15 @@
|
|||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import socket
|
||||
import struct
|
||||
|
||||
import dns
|
||||
from oslo.config import cfg
|
||||
from oslo_log import log as logging
|
||||
|
||||
from designate import dnsutils
|
||||
from designate import service
|
||||
from designate.mdns import handler
|
||||
from designate.mdns import middleware
|
||||
from designate.mdns import notify
|
||||
from designate.i18n import _LE
|
||||
from designate.i18n import _LI
|
||||
from designate.i18n import _LW
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
CONF = cfg.CONF
|
||||
|
@ -48,142 +43,26 @@ class Service(service.RPCService):
|
|||
# in the API.
|
||||
self.application = middleware.ContextMiddleware(self.application)
|
||||
|
||||
# Bind to the TCP port
|
||||
LOG.info(_LI('Opening TCP Listening Socket on %(host)s:%(port)d') %
|
||||
{'host': CONF['service:mdns'].host,
|
||||
'port': CONF['service:mdns'].port})
|
||||
self._sock_tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._sock_tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self._sock_tcp.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
self._sock_tcp.bind((CONF['service:mdns'].host,
|
||||
CONF['service:mdns'].port))
|
||||
self._sock_tcp.listen(CONF['service:mdns'].tcp_backlog)
|
||||
self._sock_tcp = dnsutils.bind_tcp(
|
||||
CONF['service:mdns'].host, CONF['service:mdns'].port,
|
||||
CONF['service:mdns'].tcp_backlog)
|
||||
|
||||
# Bind to the UDP port
|
||||
LOG.info(_LI('Opening UDP Listening Socket on %(host)s:%(port)d') %
|
||||
{'host': CONF['service:mdns'].host,
|
||||
'port': CONF['service:mdns'].port})
|
||||
self._sock_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._sock_udp.bind((CONF['service:mdns'].host,
|
||||
CONF['service:mdns'].port))
|
||||
self._sock_udp = dnsutils.bind_udp(
|
||||
CONF['service:mdns'].host, CONF['service:mdns'].port)
|
||||
|
||||
def start(self):
|
||||
super(Service, self).start()
|
||||
|
||||
self.tg.add_thread(self._handle_tcp)
|
||||
self.tg.add_thread(self._handle_udp)
|
||||
self.tg.add_thread(
|
||||
dnsutils.handle_tcp, self._sock_tcp, self.tg, dnsutils.handle,
|
||||
self.application, timeout=CONF['service:mdns'].tcp_recv_timeout)
|
||||
self.tg.add_thread(
|
||||
dnsutils.handle_udp, self._sock_udp, self.tg, dnsutils.handle,
|
||||
self.application)
|
||||
LOG.info(_LI("started mdns service"))
|
||||
|
||||
def stop(self):
|
||||
# When the service is stopped, the threads for _handle_tcp and
|
||||
# _handle_udp are stopped too.
|
||||
super(Service, self).stop()
|
||||
|
||||
def _deserialize_request(self, payload, addr):
|
||||
"""
|
||||
Deserialize a DNS Request Packet
|
||||
|
||||
:param payload: Raw DNS query payload
|
||||
:param addr: Tuple of the client's (IP, Port)
|
||||
"""
|
||||
try:
|
||||
request = dns.message.from_wire(payload)
|
||||
except dns.exception.DNSException:
|
||||
LOG.error(_LE("Failed to deserialize packet from "
|
||||
"%(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
return None
|
||||
else:
|
||||
# Create + Attach the initial "environ" dict. This is similar to
|
||||
# the environ dict used in typical WSGI middleware.
|
||||
request.environ = {'addr': addr}
|
||||
return request
|
||||
|
||||
def _serialize_response(self, response):
|
||||
"""
|
||||
Serialize a DNS Response Packet
|
||||
|
||||
:param response: DNS Response Message
|
||||
"""
|
||||
return response.to_wire()
|
||||
|
||||
def _handle_tcp(self):
|
||||
LOG.info(_LI("_handle_tcp thread started"))
|
||||
while True:
|
||||
client, addr = self._sock_tcp.accept()
|
||||
client.settimeout(CONF['service:mdns'].tcp_recv_timeout)
|
||||
|
||||
LOG.warn(_LW("Handling TCP Request from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
# Prepare a variable for the payload to be buffered
|
||||
payload = ""
|
||||
|
||||
try:
|
||||
# Receive the first 2 bytes containing the payload length
|
||||
expected_length_raw = client.recv(2)
|
||||
(expected_length, ) = struct.unpack('!H', expected_length_raw)
|
||||
|
||||
# Keep receiving data until we've got all the data we expect
|
||||
while len(payload) < expected_length:
|
||||
data = client.recv(65535)
|
||||
if not data:
|
||||
break
|
||||
payload += data
|
||||
|
||||
except socket.timeout:
|
||||
client.close()
|
||||
LOG.warn(_LW("TCP Timeout from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
# Dispatch a thread to handle the query
|
||||
self.tg.add_thread(self._handle, addr, payload, client)
|
||||
|
||||
def _handle_udp(self):
|
||||
LOG.info(_LI("_handle_udp thread started"))
|
||||
while True:
|
||||
# TODO(kiall): Determine the appropriate default value for
|
||||
# UDP recvfrom.
|
||||
payload, addr = self._sock_udp.recvfrom(8192)
|
||||
LOG.warn(_LW("Handling UDP Request from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
self.tg.add_thread(self._handle, addr, payload)
|
||||
|
||||
def _handle(self, addr, payload, client=None):
|
||||
"""
|
||||
Handle a DNS Query
|
||||
|
||||
:param addr: Tuple of the client's (IP, Port)
|
||||
:param payload: Raw DNS query payload
|
||||
:param client: Client socket (for TCP only)
|
||||
"""
|
||||
try:
|
||||
request = self._deserialize_request(payload, addr)
|
||||
|
||||
if request is None:
|
||||
# We failed to deserialize the request, generate a failure
|
||||
# response using a made up request.
|
||||
response = dns.message.make_response(
|
||||
dns.message.make_query('unknown', dns.rdatatype.A))
|
||||
response.set_rcode(dns.rcode.FORMERR)
|
||||
else:
|
||||
response = self.application(request)
|
||||
|
||||
# send back a response only if present
|
||||
if response:
|
||||
response = self._serialize_response(response)
|
||||
|
||||
if client is not None:
|
||||
# Handle TCP Responses
|
||||
msg_length = len(response)
|
||||
tcp_response = struct.pack("!H", msg_length) + response
|
||||
client.send(tcp_response)
|
||||
client.close()
|
||||
else:
|
||||
# Handle UDP Responses
|
||||
self._sock_udp.sendto(response, addr)
|
||||
except Exception:
|
||||
LOG.exception(_LE("Unhandled exception while processing request "
|
||||
"from %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
LOG.info(_LI("stopped mdns service"))
|
||||
|
|
|
@ -19,6 +19,7 @@ import socket
|
|||
import dns
|
||||
import mock
|
||||
|
||||
from designate import dnsutils
|
||||
from designate.tests.test_mdns import MdnsTestCase
|
||||
|
||||
|
||||
|
@ -38,7 +39,8 @@ class MdnsServiceTest(MdnsTestCase):
|
|||
|
||||
@mock.patch.object(dns.message, 'make_query')
|
||||
def test_handle_empty_payload(self, query_mock):
|
||||
self.service._handle(self.addr, None)
|
||||
dnsutils.handle(self.addr, None, self.service.application,
|
||||
sock_udp=self.service._sock_udp)
|
||||
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
|
||||
|
||||
@mock.patch.object(socket.socket, 'sendto', new_callable=mock.MagicMock)
|
||||
|
@ -59,6 +61,8 @@ class MdnsServiceTest(MdnsTestCase):
|
|||
expected_response = ("271289050001000000000000076578616d706c6503636f6d"
|
||||
"0000010001")
|
||||
|
||||
self.service._handle(self.addr, binascii.a2b_hex(payload))
|
||||
dnsutils.handle(
|
||||
self.addr, binascii.a2b_hex(payload), self.service.application,
|
||||
sock_udp=self.service._sock_udp)
|
||||
sendto_mock.assert_called_once_with(
|
||||
binascii.a2b_hex(expected_response), self.addr)
|
||||
|
|
Loading…
Reference in New Issue