[svn] commit: r2573 - in /branches/trac216/src/bin/xfrin: tests/xfrin_test.py xfrin.py.in

BIND 10 source code commits bind10-changes at lists.isc.org
Thu Jul 22 02:48:56 UTC 2010


Author: shentingting
Date: Thu Jul 22 02:48:56 2010
New Revision: 2573

Log:
Replace the asyncore library with a self-designed select()-based mechanism, and amend the test cases accordingly.

Modified:
    branches/trac216/src/bin/xfrin/tests/xfrin_test.py
    branches/trac216/src/bin/xfrin/xfrin.py.in

Modified: branches/trac216/src/bin/xfrin/tests/xfrin_test.py
==============================================================================
--- branches/trac216/src/bin/xfrin/tests/xfrin_test.py (original)
+++ branches/trac216/src/bin/xfrin/tests/xfrin_test.py Thu Jul 22 02:48:56 2010
@@ -62,17 +62,17 @@
     check_command_hook = None
 
     def _cc_setup(self):
-        pass
+        self._max_transfers_in = 10
     
     def _cc_check_command(self):
-        self._shutdown_event.set()
+        self._shutdown_flag = 1
         if MockXfrin.check_command_hook:
             MockXfrin.check_command_hook()
 
 class MockXfrinConnection(XfrinConnection):
-    def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event,
+    def __init__(self, conn_socket, zone_name, rrclass, db_file, shutdown_flag,
                  master_addr):
-        super().__init__(sock_map, zone_name, rrclass, db_file, shutdown_event,
+        super().__init__(conn_socket, zone_name, rrclass, db_file, shutdown_flag,
                          master_addr)
         self.query_data = b''
         self.reply_data = b''
@@ -82,14 +82,15 @@
         self.qid = None
         self.response_generator = None
 
-    def _asyncore_loop(self):
+    def connect_to_master(self):
+        return True
+
+    def _loop(self):
         if self.force_close:
-            self.handle_close()
+            self.close()
+            self._conn_socket.close()
         elif not self.force_time_out:
             self.handle_read()
-
-    def connect_to_master(self):
-        return True
 
     def recv(self, size):
         data = self.reply_data[:size]
@@ -145,13 +146,12 @@
 
 class TestXfrinConnection(unittest.TestCase):
     def setUp(self):
+        conn_socket = socket.socketpair()
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
-        self.sock_map = {}
-        self.conn = MockXfrinConnection(self.sock_map, 'example.com.',
+        self.conn = MockXfrinConnection(conn_socket[1], 'example.com.',
                                         TEST_RRCLASS, TEST_DB_FILE,
-                                        threading.Event(),
-                                        TEST_MASTER_IPV4_ADDRINFO)
+                                        0, TEST_MASTER_IPV4_ADDRINFO)
         self.axfr_after_soa = False
         self.soa_response_params = {
             'questions': [example_soa_question],
@@ -166,14 +166,9 @@
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
 
-    def test_close(self):
-        # we shouldn't be using the global asyncore map.
-        self.assertEqual(len(asyncore.socket_map), 0)
-        # there should be exactly one entry in our local map
-        self.assertEqual(len(self.sock_map), 1)
-        # once closing the dispatch the map should become empty
-        self.conn.close()
-        self.assertEqual(len(self.sock_map), 0)
+    def test_connect(self):
+        #self.assertEqual(, "")
+        self.assertRaises(Exception, self.conn.connect, (TEST_MASTER_IPV4_ADDRESS,53))
 
     def test_init_ip6(self):
         # This test simply creates a new XfrinConnection object with an
@@ -182,14 +177,13 @@
         # tends to assume it's IPv4 only and hardcode AF_INET.  This test
         # uncovers such a bug.
         c = MockXfrinConnection({}, 'example.com.', TEST_RRCLASS, TEST_DB_FILE,
-                                threading.Event(),
-                                TEST_MASTER_IPV6_ADDRINFO)
-        c.bind(('::', 0))
+                                0, TEST_MASTER_IPV6_ADDRINFO)
+        c._socket.bind(('::', 0))
         c.close()
 
     def test_init_chclass(self):
         c = XfrinConnection({}, 'example.com.', RRClass.CH(), TEST_DB_FILE,
-                            threading.Event(), TEST_MASTER_IPV4_ADDRINFO)
+                            0, TEST_MASTER_IPV4_ADDRINFO)
         axfrmsg = c._create_query(RRType.AXFR())
         self.assertEqual(axfrmsg.get_question()[0].get_class(),
                          RRClass.CH())
@@ -265,7 +259,7 @@
 
     def test_response_shutdown(self):
         self.conn.response_generator = self._create_normal_response_data
-        self.conn._shutdown_event.set()
+        self.conn._shutdown_flag = 1
         self.conn._send_query(RRType.AXFR())
         self.assertRaises(XfrinException, self._handle_xfrin_response)
 
@@ -363,38 +357,9 @@
         bogus_data = b'xxxx'
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data += bogus_data
-
-class TestXfrinRecorder(unittest.TestCase):
-    def setUp(self):
-        self.recorder = XfrinRecorder()
-
-    def test_increment(self):
-        self.assertEqual(self.recorder.count(), 0)
-        self.recorder.increment(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.count(), 1)
-        # duplicate "increment" should probably be rejected.  but it's not
-        # checked at this moment
-        self.recorder.increment(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.count(), 2)
-
-    def test_decrement(self):
-        self.assertEqual(self.recorder.count(), 0)
-        self.recorder.increment(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.count(), 1)
-        self.recorder.decrement(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.count(), 0)
-
-    def test_decrement_from_empty(self):
-        self.assertEqual(self.recorder.count(), 0)
-        self.recorder.decrement(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.count(), 0)
-
-    def test_inprogress(self):
-        self.assertEqual(self.recorder.count(), 0)
-        self.recorder.increment(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), True)
-        self.recorder.decrement(TEST_ZONE_NAME)
-        self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), False)
+class MockThread:
+    def is_alive(self):
+        return True
 
 class TestXfrin(unittest.TestCase):
     def setUp(self):
@@ -474,18 +439,18 @@
 
     def test_command_handler_retransfer_quota(self):
         for i in range(self.xfr._max_transfers_in - 1):
-            self.xfr.recorder.increment(str(i) + TEST_ZONE_NAME)
+            self.xfr._zones[str(i) + TEST_ZONE_NAME] = MockThread()
         # there can be one more outstanding transfer.
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 0)
         # make sure the # xfrs would excceed the quota
-        self.xfr.recorder.increment(str(self.xfr._max_transfers_in) + TEST_ZONE_NAME)
+        self.xfr._zones[str(self.xfr._max_transfers_in) + TEST_ZONE_NAME] = MockThread()
         # this one should fail
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 1)
 
     def test_command_handler_retransfer_inprogress(self):
-        self.xfr.recorder.increment(TEST_ZONE_NAME)
+        self.xfr._zones[TEST_ZONE_NAME] = MockThread()
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 1)
 

Modified: branches/trac216/src/bin/xfrin/xfrin.py.in
==============================================================================
--- branches/trac216/src/bin/xfrin/xfrin.py.in (original)
+++ branches/trac216/src/bin/xfrin/xfrin.py.in Thu Jul 22 02:48:56 2010
@@ -21,12 +21,15 @@
 import os
 import signal
 import isc
-import asyncore
 import struct
 import threading
-import socket, _socket, select
+import socket
+import select
 import random
 import time
+from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
+        ENOTCONN, ESHUTDOWN, EINTR, EISCONN, EBADF, ECONNABORTED, errorcode
+
 from optparse import OptionParser, OptionValueError
 from isc.config.ccsession import *
 try:
@@ -64,76 +67,70 @@
 class XfrinException(Exception): 
     pass
 
-def Xfrin_poll(timeout=0.0, map=None):
-    if map is None:
-        map = socket_map
-    if map:
-        abc = None
-        r = []; w = []; e = []
-        for fd, obj in list(map.items()):
-            if isinstance(obj, _socket.socket):
-                r.append(obj)
-            else:
-                is_r = obj.readable()
-                is_w = obj.writable()
-                if is_r:
-                    r.append(fd)
-                if is_w:
-                    w.append(fd)
-                if is_r or is_w:
-                    e.append(fd)
-        if [] == r == w == e:
-            time.sleep(timeout)
-            return
-        try:
-            r, w, e = select.select(r, w, e, timeout)
-        except select.error as err:
-            if err.args[0] != EINTR:
-                raise
-            else:
-                return
-        for el in r:
-            if isinstance(el, _socket.socket):
-                raise XfrinException("shutdown xfrin!")
-            obj = map.get(el)
-            if obj is None:
-                continue
-            asyncore.read(obj)
-        for fd in w:
-            obj = map.get(fd)
-            if obj is None:
-                continue
-            asyncore.write(obj)
-        for fd in e:
-            obj = map.get(fd)
-            if obj is None:
-                continue
-            _exception(obj)
-            
-asyncore.poll = Xfrin_poll
-
-class XfrinConnection(asyncore.dispatcher):
+class XfrinConnection:
     '''Do xfrin in this class. '''    
 
     def __init__(self,
-                 sock_map, zone_name, rrclass, db_file, shutdown_event,
+                 conn_socket, zone_name, rrclass, db_file, shutdown_flag,
                  master_addrinfo, verbose = False, idle_timeout = 60): 
         ''' idle_timeout: max idle time for read data from socket.
             db_file: specify the data source file.
             check_soa: when it's true, check soa first before sending xfr query
         '''
-        asyncore.dispatcher.__init__(self, map=sock_map)
-        self.create_socket(master_addrinfo[0], master_addrinfo[1])
+        self._socket = socket.socket(master_addrinfo[0], master_addrinfo[1])
+        self._socket.setblocking(1)
+        self._conn_socket = conn_socket
         self._zone_name = zone_name
-        self._sock_map = sock_map
         self._rrclass = rrclass
         self._db_file = db_file
         self._soa_rr_count = 0
         self._idle_timeout = idle_timeout
-        self.setblocking(1)
-        self._shutdown_event = shutdown_event
+        self._shutdown_flag = shutdown_flag
         self._verbose = verbose
         self._master_address = master_addrinfo[4]
+    
+    def connect(self, address):
+        err = self._socket.connect_ex(address)
+        if err in (EINPROGRESS, EALREADY, EWOULDBLOCK):
+            return 
+        if err not in (0, EISCONN):
+            raise socket.error(err, errorcode[err])
+
+    def send(self, data):
+        try:
+            result = self._socket.send(data)
+            return result
+        except socket.error as why:
+            if why.args[0] == EWOULDBLOCK:
+                return 0
+            elif why.args[0] in (ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED):
+                self.close()
+                return 0
+            else:
+                raise
+
+    def recv(self, buffer_size):
+        try:
+            data = self._socket.recv(buffer_size)
+            if not data:
+                self.close()
+                return b''
+            else:
+                return data
+        except socket.error as why:
+            # winsock sometimes throws ENOTCONN
+            if why.args[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED]:
+                self.close()
+                return b''
+            else:
+                raise
+
+    def close(self):
+        try:
+            self._socket.close()
+        except socket.error as why:
+            if why.args[0] not in (ENOTCONN, EBADF):
+                raise
 
     def connect_to_master(self):
         '''Connect to master in TCP.'''
@@ -143,7 +140,7 @@
             return True
         except socket.error as e:
             self.log_msg('Failed to connect:(%s), %s' % (self._master_address,
-                                                            str(e)))
+                                                                    str(e)))
             return False
 
     def _create_query(self, query_type):
@@ -159,6 +156,7 @@
         msg.add_question(query_question)
         return msg
 
+
     def _send_data(self, data):
         size = len(data)
         total_count = 0
@@ -177,12 +175,20 @@
         self._send_data(header_len)
         self._send_data(render.get_data())
 
-    def _asyncore_loop(self):
-        '''
-        This method is a trivial wrapper for asyncore.loop().  It's extracted from
-        _get_request_response so that we can test the rest of the code without
-        involving actual communication with a remote server.'''
-        asyncore.loop(self._idle_timeout, map=self._sock_map, count=1)
+    def _loop(self):
+        try:
+            rlist, wlist, elist = select.select([self._socket, self._conn_socket], \
+                    [], [], self._idle_timeout)
+        except select.error as err:
+            if err.args[0] != EINTR:
+                raise
+            else:
+                return
+        for s in rlist:
+            if s == self._conn_socket:
+                raise XfrinException("shutdown xfrin!")
+            self.handle_read()
+            
     
     def _get_request_response(self, size):
         recv_size = 0
@@ -190,7 +196,7 @@
         while recv_size < size:
             self._recv_time_out = True
             self._need_recv_size = size - recv_size
-            self._asyncore_loop()
+            self._loop()
             if self._recv_time_out:
                 raise XfrinException('receive data from socket time out.')
 
@@ -244,7 +250,7 @@
             self.log_msg(e)
             self.log_msg(logstr + 'failed')
             ret = XFRIN_FAIL
-            #TODO, recover data source.
+            #TODO, recover data source. and exception from self._handle_xfrin_response
         except isc.datasrc.sqlite3_ds.Sqlite3DSError as e:
             self.log_msg(e)
             self.log_msg(logstr + 'failed')
@@ -259,6 +265,7 @@
             ret = XFRIN_FAIL
         finally:
             self.close()
+            self._conn_socket.close()
         return ret
 
     def _check_response_header(self, msg):
@@ -337,55 +344,44 @@
             if self._soa_rr_count == 2:
                 break
             
-            if self._shutdown_event.is_set():
+            if self._shutdown_flag:
                 raise XfrinException('xfrin is forced to stop')
 
     def handle_read(self):
         '''Read query's response from socket. '''
         self._recvd_data = self.recv(self._need_recv_size)
         if(len(self._recvd_data)==0):
-            raise XfrinException
+            raise XfrinException("receive data from socket error!")
         self._recvd_size = len(self._recvd_data)
         self._recv_time_out = False
-
-    def handle_error(self):
-        raise XfrinException("receive data from socket error!")
-
-
-    def writable(self):
-        '''Ignore the writable socket. '''
-        return False
-
 
     def log_info(self, msg, type='info'):
         # Overwrite the log function, log nothing
         pass
 
     def log_msg(self, msg):
-       if self._verbose:
+        if self._verbose:
             sys.stdout.write('[b10-xfrin] %s\n' % str(msg))
 
 
 def process_xfrin(zone_name, rrclass, db_file, 
-                  shutdown_event, master_addrinfo, check_soa, socket_pair, verbose):
-    sock_map = {}
-    sock_map[socket_pair.fileno()] = socket_pair
-    conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
-                           shutdown_event, master_addrinfo, verbose)
+                  shutdown_flag, master_addrinfo, check_soa, conn_socket, verbose):
+    conn = XfrinConnection(conn_socket, zone_name, rrclass, db_file,
+                           shutdown_flag, master_addrinfo, verbose)
     if conn.connect_to_master():
         conn.do_xfrin(check_soa)
 
 class Xfrin:
     def __init__(self, verbose = False):
         self._cc_setup()
-        self._shutdown_event = threading.Event()
+        self._shutdown_flag = 0
         self._verbose = verbose
 
         #the item in self._zones: zone name and xfr communication thread. 
-        #The item in self._socket_pairs: a socket and xfr communication thread, the main thread uses 
+        #The item in self._conn_sockets: a socket and xfr communication thread, the main thread uses 
         #the socket to communicate with this xfr communication thread.
         self._zones = {}
-        self._socket_pairs = {}
+        self._conn_sockets = {}
 
     def _cc_setup(self):
         '''
@@ -414,14 +410,14 @@
         return create_answer(0)
 
     def shutdown(self):
-        ''' shutdown the xfrin process. the thread which is doing xfrin should be 
+        ''' shutdown the xfrin process. the thread which is doing xfrin will be 
         terminated.
         '''
-        self._filter_hash(self._socket_pairs)
-        for fd in self._socket_pairs.keys():
+        self._filter_hash(self._conn_sockets)
+        for fd in self._conn_sockets.keys():
             fd.send(b"shutdown")
 
-        self._shutdown_event.set()
+        self._shutdown_flag = 1
         main_thread = threading.currentThread()
         for th in threading.enumerate():
             if th is main_thread:
@@ -432,8 +428,7 @@
         answer = create_answer(0)
         try:
             if command == 'shutdown':
-                self._shutdown_event.set()
-                #self.shutdown()
+                self._shutdown_flag = 1
             elif command == 'retransfer' or command == 'refresh':
                 # The default RR class is IN.  We should fix this so that
                 # the class is passed in the command arg (where we specify
@@ -484,11 +479,11 @@
         return (zone_name, master_addrinfo, db_file)
 
     def startup(self):
-        while not self._shutdown_event.is_set():
+        while self._shutdown_flag == 0:
             self._cc_check_command()
 
     def _filter_hash(self, hash):
-        '''delete zone_name in self._zones or a socket in self._socket_pairs.'''
+        '''delete zone_name in self._zones or a socket in self._conn_sockets.'''
         keys = []
         for key in hash.keys():
             keys.append(key)
@@ -505,7 +500,7 @@
         # check max_transfer_in, else return quota error
         if len(self._zones) >= self._max_transfers_in:
             self._filter_hash(self._zones)
-            self._filter_hash(self._socket_pairs)
+            self._filter_hash(self._conn_sockets)
             if len(self._zones) >= self._max_transfers_in:
                 return (1, 'xfrin quota error')
 
@@ -515,17 +510,17 @@
                 del self._zones[zone_name]
             else:
                 return (1, 'zone xfrin is in progress')
-        socket_pair = socket.socketpair()
+        conn_socket = socket.socketpair()
         xfrin_thread = threading.Thread(target = process_xfrin,
                                         args = (zone_name, rrclass,
                                                 db_file,
-                                                self._shutdown_event,
-                                                master_addrinfo, check_soa, socket_pair[1],
+                                                self._shutdown_flag,
+                                                master_addrinfo, check_soa, conn_socket[1],
                                                 self._verbose))
 
         # recored the zone name which zone xfrin is in process
         self._zones[zone_name] = xfrin_thread
-        self._socket_pairs[socket_pair[0]] = xfrin_thread
+        self._conn_sockets[conn_socket[0]] = xfrin_thread
         xfrin_thread.start()
         return (0, 'zone xfrin is started')
 




More information about the bind10-changes mailing list