BIND 10 trac816, updated. 560e7210fa3006c98c711e96cdbac6f9a0b391a0 [trac816] update xfrout to use tsig
BIND 10 source code commits
bind10-changes at lists.isc.org
Wed Jun 1 07:46:51 UTC 2011
The branch, trac816 has been updated
via 560e7210fa3006c98c711e96cdbac6f9a0b391a0 (commit)
from c37ebedf94c5dbbed3c685272a0cdc4bea67fb04 (commit)
Those revisions listed above that are new to this repository have
not appeared on any other notification email; so we list those
revisions in full, below.
- Log -----------------------------------------------------------------
commit 560e7210fa3006c98c711e96cdbac6f9a0b391a0
Author: chenzhengzhang <jerry.zzpku at gmail.com>
Date: Wed Jun 1 15:46:27 2011 +0800
[trac816] update xfrout to use tsig
-----------------------------------------------------------------------
Summary of changes:
src/bin/xfrout/tests/xfrout_test.py.in | 143 +++++++++++++++++++++++++++++++-
src/bin/xfrout/xfrout.py.in | 96 ++++++++++++++++++---
src/bin/xfrout/xfrout.spec.pre.in | 12 +++
3 files changed, 236 insertions(+), 15 deletions(-)
-----------------------------------------------------------------------
diff --git a/src/bin/xfrout/tests/xfrout_test.py.in b/src/bin/xfrout/tests/xfrout_test.py.in
index 472ef3c..9075543 100644
--- a/src/bin/xfrout/tests/xfrout_test.py.in
+++ b/src/bin/xfrout/tests/xfrout_test.py.in
@@ -18,11 +18,14 @@
import unittest
import os
+from isc.testutils.tsigctx_mock import MockTSIGContext
from isc.cc.session import *
from pydnspp import *
from xfrout import *
import xfrout
+# TSIG key (in "name:secret" form) shared by the TSIG tests below.
+TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
+
# our fake socket, where we can read and insert messages
class MySocket():
def __init__(self, family, type):
@@ -85,10 +88,36 @@ class TestXfroutSession(unittest.TestCase):
msg.from_wire(self.mdata)
return msg
+ def create_mock_tsig_ctx(self, error):
+ # Return a MockTSIGContext for TSIG_KEY whose 'error' attribute is
+ # set to the given TSIG error; the mock reports it as the (faked)
+ # verification result.
+ mock_ctx = MockTSIGContext(TSIG_KEY)
+ mock_ctx.error = error
+ return mock_ctx
+
+ def message_has_tsig(self, msg):
+ return msg.get_tsig_record() is not None
+
+ def create_request_data_with_tsig(self):
+ '''Return the wire data of an AXFR query for example.com.
+ rendered with a TSIG record made from TSIG_KEY.'''
+ msg = Message(Message.RENDER)
+ query_id = 0x1035
+ msg.set_qid(query_id)
+ msg.set_opcode(Opcode.QUERY())
+ msg.set_rcode(Rcode.NOERROR())
+ query_question = Question(Name("example.com."), RRClass.IN(), RRType.AXFR())
+ msg.add_question(query_question)
+
+ renderer = MessageRenderer()
+ tsig_ctx = MockTSIGContext(TSIG_KEY)
+ msg.to_wire(renderer, tsig_ctx)
+ request_data = renderer.get_data()
+ return request_data
+
def setUp(self):
self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
self.log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
- self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), self.log)
+ # Start with an empty key ring: a TSIG-signed request fails
+ # verification until a key is added to it.
+ self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), self.log, TSIGKeyRing())
self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
@@ -96,6 +125,18 @@ class TestXfroutSession(unittest.TestCase):
[get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
self.assertEqual(get_rcode.to_text(), "NOERROR")
+ # tsig signed query message
+ request_data = self.create_request_data_with_tsig()
+ # BADKEY: the session's key ring (empty after setUp) lacks TSIG_KEY,
+ # so verification fails, but a TSIG context is still created.
+ [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+ self.assertEqual(rcode.to_text(), "NOTAUTH")
+ self.assertIsNotNone(self.xfrsess._tsig_ctx)
+ # NOERROR: once the key is in the ring, verification succeeds.
+ self.xfrsess._tsig_key_ring.add(TSIG_KEY)
+ [rcode, msg] = self.xfrsess._parse_query_message(request_data)
+ self.assertEqual(rcode.to_text(), "NOERROR")
+ self.assertIsNotNone(self.xfrsess._tsig_ctx)
+
def test_get_query_zone_name(self):
msg = self.getmsg()
self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
@@ -111,6 +152,14 @@ class TestXfroutSession(unittest.TestCase):
get_msg = self.sock.read_msg()
self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
+ # tsig signed message
+ msg = self.getmsg()
+ self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+ self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
+ get_msg = self.sock.read_msg()
+ self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
+ self.assertTrue(self.message_has_tsig(get_msg))
+
def test_send_message(self):
msg = self.getmsg()
msg.make_response()
@@ -152,6 +201,14 @@ class TestXfroutSession(unittest.TestCase):
get_msg = self.sock.read_msg()
self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
+ # tsig signed message
+ msg = self.getmsg()
+ self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+ self.xfrsess._reply_query_with_format_error(msg, self.sock)
+ get_msg = self.sock.read_msg()
+ self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
+ self.assertTrue(self.message_has_tsig(get_msg))
+
def test_create_rrset_from_db_record(self):
rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
self.assertEqual(rrset.get_name().to_text(), "example.com.")
@@ -180,6 +237,15 @@ class TestXfroutSession(unittest.TestCase):
rdata = answer.get_rdata()
self.assertEqual(rdata[0].to_text(), self.soa_record[7])
+ def test_send_message_with_last_soa_with_tsig(self):
+ self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+ rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+ msg = self.getmsg()
+ msg.make_response()
+ self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 0)
+ get_msg = self.sock.read_msg()
+ self.assertTrue(self.message_has_tsig(get_msg))
+
def test_trigger_send_message_with_last_soa(self):
rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
@@ -223,6 +289,21 @@ class TestXfroutSession(unittest.TestCase):
# and it should not have sent anything else
self.assertEqual(0, len(self.sock.sendqueue))
+ def test_trigger_send_message_with_last_soa_with_tsig(self):
+ self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+ rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+ msg = self.getmsg()
+ msg.make_response()
+ msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
+ # 65520 leaves no room for the SOA, so the pending message is sent
+ # first without TSIG and the SOA follows in a separate signed message.
+ self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 65520)
+ get_msg = self.sock.read_msg()
+ self.assertFalse(self.message_has_tsig(get_msg))
+ # the last packet
+ get_msg = self.sock.read_msg()
+ self.assertTrue(self.message_has_tsig(get_msg))
+
+ # and it should not have sent anything else
+ self.assertEqual(0, len(self.sock.sendqueue))
+
def test_get_rrset_len(self):
rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
self.assertEqual(82, get_rrset_len(rrset_soa))
@@ -313,6 +394,51 @@ class TestXfroutSession(unittest.TestCase):
reply_msg = self.sock.read_msg()
self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
+ def test_reply_xfrout_query_noerror_with_tsig(self):
+ num_rrsets = 100
+ rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None, '192.168.1.1')
+ def get_zone_soa(zonename, file):
+ return self.soa_record
+
+ def get_zone_datas(zone, file):
+ return [rrset_data] * num_rrsets
+
+ def get_rrset_len(rrset):
+ # Pretend every RRset is this large so each one is sent in its
+ # own message.
+ return 65520
+
+ # Restore the module functions patched below when the test ends so
+ # the fakes do not leak into other tests.
+ self.addCleanup(setattr, sqlite3_ds, 'get_zone_soa', sqlite3_ds.get_zone_soa)
+ self.addCleanup(setattr, sqlite3_ds, 'get_zone_datas', sqlite3_ds.get_zone_datas)
+ self.addCleanup(setattr, xfrout, 'get_rrset_len', xfrout.get_rrset_len)
+ sqlite3_ds.get_zone_soa = get_zone_soa
+ sqlite3_ds.get_zone_datas = get_zone_datas
+ xfrout.get_rrset_len = get_rrset_len
+
+ self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
+ self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
+
+ # tsig signed first packet
+ reply_msg = self.sock.read_msg()
+ self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
+ self.assertTrue(self.message_has_tsig(reply_msg))
+ # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
+ for i in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
+ reply_msg = self.sock.read_msg()
+ self.assertFalse(self.message_has_tsig(reply_msg))
+ # TSIG_SIGN_EVERY_NTH packet has tsig
+ reply_msg = self.sock.read_msg()
+ self.assertTrue(self.message_has_tsig(reply_msg))
+
+ for i in range(0, num_rrsets - xfrout.TSIG_SIGN_EVERY_NTH):
+ reply_msg = self.sock.read_msg()
+ self.assertFalse(self.message_has_tsig(reply_msg))
+ # tsig signed last packet
+ reply_msg = self.sock.read_msg()
+ self.assertTrue(self.message_has_tsig(reply_msg))
+
+ # and it should not have sent anything else
+ self.assertEqual(0, len(self.sock.sendqueue))
+
class MyCCSession():
def __init__(self):
pass
@@ -347,8 +473,23 @@ class TestUnixSockServer(unittest.TestCase):
self.assertEqual(recv_msg, send_msg)
def test_updata_config_data(self):
+ tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
+ tsig_key_list = [tsig_key_str]
+ bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
self.unix.update_config_data({'transfers_out':10 })
self.assertEqual(self.unix._max_transfers_out, 10)
+ self.assertIsNotNone(self.unix.tsig_key_ring)
+
+ self.unix.update_config_data({'transfers_out':9, 'tsig_key_ring':tsig_key_list})
+ self.assertEqual(self.unix._max_transfers_out, 9)
+ self.assertEqual(self.unix.tsig_key_ring.size(), 1)
+ self.unix.tsig_key_ring.remove(Name("example.com."))
+ self.assertEqual(self.unix.tsig_key_ring.size(), 0)
+
+ # bad tsig key: it is logged and skipped rather than raised, so the
+ # call must simply succeed and leave the key ring empty.
+ config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
+ self.unix.update_config_data(config_data)
+ self.assertEqual(self.unix.tsig_key_ring.size(), 0)
def test_get_db_file(self):
self.assertEqual(self.unix.get_db_file(), "initdb.file")
diff --git a/src/bin/xfrout/xfrout.py.in b/src/bin/xfrout/xfrout.py.in
index 17ca3eb..c76a5ae 100755
--- a/src/bin/xfrout/xfrout.py.in
+++ b/src/bin/xfrout/xfrout.py.in
@@ -74,10 +74,15 @@ SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
MAX_TRANSFERS_OUT = 10
VERBOSE_MODE = False
+# When a TSIG context exists, the first AXFR response message is signed,
+# then every TSIG_SIGN_EVERY_NTH-th message, and always the last one.
+# NOTE(review): RFC 2845 section 4.4 allows at most 99 unsigned messages
+# between signed ones, so keep this value at or below 100.
+TSIG_SIGN_EVERY_NTH = 96
XFROUT_MAX_MESSAGE_SIZE = 65535
+class XfroutException(Exception):
+ '''Base exception for xfrout errors.'''
+ # NOTE(review): not raised anywhere in this change; presumably
+ # reserved for later use.
+ pass
+
def get_rrset_len(rrset):
"""Returns the wire length of the given RRset"""
bytes = bytearray()
@@ -86,15 +91,22 @@ def get_rrset_len(rrset):
class XfroutSession():
- def __init__(self, sock_fd, request_data, server, log):
+ def __init__(self, sock_fd, request_data, server, log, tsig_key_ring):
# The initializer for the superclass may call functions
# that need _log to be set, so we set it first
self._sock_fd = sock_fd
self._request_data = request_data
self._server = server
self._log = log
+ # Key ring used to build the TSIG context for signed requests.
+ self._tsig_key_ring = tsig_key_ring
+ # TSIG context of the current request; stays None for unsigned
+ # requests (set by _check_request_tsig).
+ self._tsig_ctx = None
+ # Wire length of the request's TSIG record, reserved in response
+ # messages that may be signed (0 for unsigned requests).
+ self._tsig_len = 0
self.handle()
+ def create_tsig_ctx(self, tsig_record, tsig_key_ring):
+ '''Create a TSIGContext from the key name and algorithm carried in
+ tsig_record together with the given key ring.'''
+ return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
+ tsig_key_ring)
+
def handle(self):
''' Handle a xfrout query, send xfrout response '''
try:
@@ -105,17 +117,33 @@ class XfroutSession():
os.close(self._sock_fd)
+ def _check_request_tsig(self, msg, request_data):
+ ''' If request has a tsig record, perform tsig related checks.
+
+ Returns Rcode.NOTAUTH() if the TSIG record does not verify against
+ request_data, and Rcode.NOERROR() otherwise (including when the
+ request carries no TSIG record). For a signed request this also
+ sets self._tsig_len and self._tsig_ctx, even when verification
+ fails, so the error reply is rendered with that context. '''
+ tsig_record = msg.get_tsig_record()
+ if tsig_record is not None:
+ self._tsig_len = tsig_record.get_length()
+ self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
+ tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
+ if tsig_error != TSIGError.NOERROR:
+ return Rcode.NOTAUTH()
+
+ return Rcode.NOERROR()
+
def _parse_query_message(self, mdata):
''' parse query message to [socket,message]'''
+ # NOTE(review): the docstring above is stale; this returns
+ # (rcode, msg): (FORMERR, None) if parsing or the TSIG check raises,
+ # otherwise the TSIG check result and the parsed message.
#TODO, need to add parseHeader() in case the message header is invalid
try:
msg = Message(Message.PARSE)
Message.from_wire(msg, mdata)
+
+ # TSIG related checks
+ rcode = self._check_request_tsig(msg, mdata)
+
except Exception as err:
self._log.log_message("error", str(err))
return Rcode.FORMERR(), None
- return Rcode.NOERROR(), msg
+ return rcode, msg
def _get_query_zone_name(self, msg):
question = msg.get_question()[0]
@@ -130,13 +158,20 @@ class XfroutSession():
total_count += count
- def _send_message(self, sock_fd, msg):
+ def _send_message(self, sock_fd, msg, tsig_ctx=None):
+ '''Render msg and write it to sock_fd, preceded by its length as a
+ 2-byte value in network byte order. If tsig_ctx is not None it is
+ passed to to_wire() so the rendered message carries a TSIG record.'''
render = MessageRenderer()
# As defined in RFC5936 section3.4, perform case-preserving name
# compression for AXFR message.
render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
- msg.to_wire(render)
+
+ # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
+ # we should remove the if statement and use a universal interface later.
+ if tsig_ctx is not None:
+ msg.to_wire(render, tsig_ctx)
+ else:
+ msg.to_wire(render)
+
header_len = struct.pack('H', socket.htons(render.get_length()))
self._send_data(sock_fd, header_len)
self._send_data(sock_fd, render.get_data())
@@ -145,7 +180,7 @@ class XfroutSession():
def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
msg.make_response()
msg.set_rcode(rcode_)
- self._send_message(sock_fd, msg)
+ # Signed with the request's TSIG context, if one was created.
+ self._send_message(sock_fd, msg, self._tsig_ctx)
def _reply_query_with_format_error(self, msg, sock_fd):
@@ -155,7 +190,7 @@ class XfroutSession():
msg.make_response()
msg.set_rcode(Rcode.FORMERR())
- self._send_message(sock_fd, msg)
+ # Signed with the request's TSIG context, if one was created.
+ self._send_message(sock_fd, msg, self._tsig_ctx)
def _zone_has_soa(self, zone):
'''Judge if the zone has an SOA record.'''
@@ -204,8 +239,10 @@ class XfroutSession():
def dns_xfrout_start(self, sock_fd, msg_query):
rcode_, msg = self._parse_query_message(msg_query)
#TODO. create query message and parse header
+ # _parse_query_message only returns NOERROR, FORMERR or NOTAUTH, so
+ # anything not handled below continues with the transfer.
- if rcode_ != Rcode.NOERROR():
+ if rcode_ == Rcode.FORMERR():
return self._reply_query_with_format_error(msg, sock_fd)
+ elif rcode_ == Rcode.NOTAUTH():
+ # The request's TSIG failed verification.
+ return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
zone_name = self._get_query_zone_name(msg)
rcode_ = self._check_xfrout_available(zone_name)
@@ -254,31 +291,32 @@ class XfroutSession():
'''
rrset_len = get_rrset_len(rrset_soa)
+ # NOTE(review): _reply_xfrout_query may already have added
+ # self._tsig_len to message_upper_len, so TSIG space can be reserved
+ # twice here; this only splits messages earlier, but confirm it is
+ # intended.
- if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
+ if message_upper_len + rrset_len + self._tsig_len < XFROUT_MAX_MESSAGE_SIZE:
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
else:
self._send_message(sock_fd, msg)
msg = self._clear_message(msg)
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
- self._send_message(sock_fd, msg)
+ # If a TSIG context exists, sign the last message.
+ self._send_message(sock_fd, msg, self._tsig_ctx)
def _reply_xfrout_query(self, msg, sock_fd, zone_name):
#TODO, there should be a better way to insert rrset.
+ # Start the counter at TSIG_SIGN_EVERY_NTH so the first message sent
+ # from the loop below is signed.
+ count_since_last_tsig_sign = TSIG_SIGN_EVERY_NTH
msg.make_response()
msg.set_header_flag(Message.HEADERFLAG_AA)
soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
rrset_soa = self._create_rrset_from_db_record(soa_record)
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
- message_upper_len = get_rrset_len(rrset_soa)
+ # The first message is signed when a TSIG context exists, so reserve
+ # room for the TSIG record (self._tsig_len is 0 otherwise).
+ message_upper_len = get_rrset_len(rrset_soa) + self._tsig_len
for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
self._log.log_message("info", "xfrout process is being shutdown")
return
-
# TODO: RRType.SOA() ?
if RRType(rr_data[5]) == RRType("SOA"): #ignore soa record
continue
@@ -294,10 +332,22 @@ class XfroutSession():
message_upper_len += rrset_len
continue
- self._send_message(sock_fd, msg)
+ # If a TSIG context exists, sign the first message and every
+ # TSIG_SIGN_EVERY_NTH-th message after it.
+ if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
+ count_since_last_tsig_sign = 0
+ self._send_message(sock_fd, msg, self._tsig_ctx)
+ else:
+ self._send_message(sock_fd, msg)
+
+ count_since_last_tsig_sign += 1
msg = self._clear_message(msg)
msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message
- message_upper_len = rrset_len
+
+ # Reserve TSIG space if the next message sent from this loop will
+ # be signed.
+ if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
+ message_upper_len = rrset_len + self._tsig_len
+ else:
+ message_upper_len = rrset_len
self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len)
@@ -403,7 +453,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
def finish_request(self, sock_fd, request_data):
'''Finish one request by instantiating RequestHandlerClass.'''
- self.RequestHandlerClass(sock_fd, request_data, self, self._log)
+ self.RequestHandlerClass(sock_fd, request_data, self, self._log, self.tsig_key_ring)
def _remove_unused_sock_file(self, sock_file):
'''Try to remove the socket file. If the file is being used
@@ -449,10 +499,28 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
self._log.log_message('info', 'update config data start.')
self._lock.acquire()
self._max_transfers_out = new_config.get('transfers_out')
+ self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
self._log.log_message('info', 'max transfer out : %d', self._max_transfers_out)
self._lock.release()
self._log.log_message('info', 'update config data complete.')
+ def set_tsig_key_ring(self, key_list):
+ """Replace the server's TSIG key ring with one built from key_list.
+
+ key_list is a list of TSIG key strings (the 'tsig_key_ring' config
+ item); None or an empty list yields an empty key ring. A string that
+ cannot be parsed is logged as an error and skipped; nothing is
+ raised. The new ring is fully populated before it is assigned to
+ self.tsig_key_ring, so finish_request() never passes a partially
+ built ring to a new session."""
+ key_ring = TSIGKeyRing()
+ for key_item in key_list or []:
+ try:
+ key_ring.add(TSIGKey(key_item))
+ except InvalidParameter:
+ self._log.log_message('error', 'bad TSIG key string: %s', str(key_item))
+ self.tsig_key_ring = key_ring
+
def get_db_file(self):
file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
# this too should be unnecessary, but currently the
diff --git a/src/bin/xfrout/xfrout.spec.pre.in b/src/bin/xfrout/xfrout.spec.pre.in
index 941db72..2efa3d7 100644
--- a/src/bin/xfrout/xfrout.spec.pre.in
+++ b/src/bin/xfrout/xfrout.spec.pre.in
@@ -37,6 +37,18 @@
"item_type": "integer",
"item_optional": false,
"item_default": 1048576
+ },
+ {
+ "item_name": "tsig_key_ring",
+ "item_type": "list",
+ "item_optional": true,
+ "item_default": [],
+ "list_item_spec" :
+ {
+ "item_name": "tsig_key",
+ "item_type": "string",
+ "item_optional": true
+ }
}
],
"commands": [
More information about the bind10-changes
mailing list