[svn] commit: r2807 - in /branches/trac216/src/bin/xfrin: tests/xfrin_test.py xfrin.py.in

BIND 10 source code commits bind10-changes at lists.isc.org
Thu Aug 26 02:40:09 UTC 2010


Author: jinmei
Date: Thu Aug 26 02:40:09 2010
New Revision: 2807

Log:
proposed changes to the use of select.select.

Modified:
    branches/trac216/src/bin/xfrin/tests/xfrin_test.py
    branches/trac216/src/bin/xfrin/xfrin.py.in

Modified: branches/trac216/src/bin/xfrin/tests/xfrin_test.py
==============================================================================
--- branches/trac216/src/bin/xfrin/tests/xfrin_test.py (original)
+++ branches/trac216/src/bin/xfrin/tests/xfrin_test.py Thu Aug 26 02:40:09 2010
@@ -105,7 +105,7 @@
         return len(data)
 
     def close(self):
-        return True
+        pass
 
 class MockXfrinConnection(XfrinConnection):
     def __init__(self, conn_socket, zone_name, rrclass, db_file, shutdown_flag,
@@ -114,31 +114,37 @@
                          master_addr)
         self.query_data = b''
         self.reply_data = b''
+        self.mock_test_data = b"hello bind10"
         self.force_time_out = False
         self.force_close = False
         self.qlen = None
         self.qid = None
         self.response_generator = None
+        self.closed = False
 
     def connect_to_master(self):
         return self._socket
 
-    def _loop(self):
+    def _select(self):
+        if self.force_time_out:
+            return 0
         if self.force_close:
             self.close()
             self._conn_socket.close()
-        elif not self.force_time_out:
-            self.handle_read()
-
-    def mock_handle_read(self):
-        return True
+        return 1
 
     def recv(self, size):
+        if self.closed:
+            raise socket.error('recv attempt on a closed socket')
         data = self.reply_data[:size]
         self.reply_data = self.reply_data[size:]
         if len(data) < size:
             raise XfrinTestException('cannot get reply data')
         return data
+
+    def close(self):
+        self.closed = True
+        super().close()
 
     def send(self, data):
         if self.qlen != None and len(self.query_data) >= self.qlen:
@@ -187,12 +193,15 @@
 
 class TestXfrinConnection(unittest.TestCase):
     def setUp(self):
-        conn_socket = socket.socketpair()
+        self.conn_sockets = socket.socketpair()
+        self.mock_xfrsockets = socket.socketpair()
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
-        self.conn = MockXfrinConnection(conn_socket[1], 'example.com.',
+        self.conn = MockXfrinConnection(self.conn_sockets[1], 'example.com.',
                                         TEST_RRCLASS, TEST_DB_FILE,
                                         0, TEST_MASTER_IPV4_ADDRINFO)
+        # replace the XFR socket with our local mock
+        self.conn._socket = self.mock_xfrsockets[1]
         self.axfr_after_soa = False
         self.soa_response_params = {
             'questions': [example_soa_question],
@@ -204,6 +213,10 @@
 
     def tearDown(self):
         self.conn.close()
+        self.conn_sockets[0].close()
+        self.conn_sockets[1].close()
+        self.mock_xfrsockets[0].close()
+        self.mock_xfrsockets[1].close()
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
 
@@ -215,23 +228,24 @@
     def test_send(self):
         self.conn._socket.close()
         self.conn._socket = MockSocket()
-        self.assertEqual(len(b"hello bind10"),
+        self.assertEqual(len(self.conn.mock_test_data),
                          super(MockXfrinConnection,
-                               self.conn).send(b"hello bind10"))
+                               self.conn).send(self.conn.mock_test_data))
 
     def test_send_exception(self):
         self.conn._socket.close()
         self.conn._socket = MockSocket()
         self.assertRaises(socket.error,
                           super(MockXfrinConnection, self.conn).send,
-                          "hello bind10")
+                          "not binary data")
 
     def test_recv(self):
         self.conn._socket.close()
         self.conn._socket = MockSocket()
-        super(MockXfrinConnection, self.conn).send(b"hello bind10")
-        self.assertEqual(b"hello bind10",
-                         super(MockXfrinConnection, self.conn).recv(20))
+        super(MockXfrinConnection, self.conn).send(self.conn.mock_test_data)
+        self.assertEqual(self.conn.mock_test_data,
+                         super(MockXfrinConnection,
+                               self.conn).recv(len(self.conn.mock_test_data)))
 
     def test_recv_nodata(self):
         self.conn._socket.close()
@@ -241,17 +255,26 @@
     def test_recv_exception(self):
         self.conn._socket.close()
         self.conn._socket = MockSocket()
-        super(MockXfrinConnection, self.conn).send(b"hello bind10")
+        super(MockXfrinConnection, self.conn).send(self.conn.mock_test_data)
         self.assertRaises(socket.error,
                           super(MockXfrinConnection, self.conn).recv, -1)
 
-    def test_loop(self):
+    def test_select_readok(self):
+        self.mock_xfrsockets[0].send(self.conn.mock_test_data)
+        self.conn._idle_timeout = 3 # timeout shouldn't occur
+        self.assertEqual(1, super(MockXfrinConnection, self.conn)._select())
+        self.assertEqual(self.conn.mock_test_data,
+                         self.conn._socket.recv(len(self.conn.mock_test_data)))
+
+    def test_select_timeout(self):
         self.conn._idle_timeout = 0.1
-
-        self.conn.handle_read = self.conn.mock_handle_read
-        self.assertRaises(XfrinException,
-                          super(MockXfrinConnection, self.conn)._loop)
-
+        self.assertEqual(0, super(MockXfrinConnection, self.conn)._select())
+
+    def test_select_shutdown(self):
+        self.conn._idle_timeout = 3 # timeout shouldn't occur
+        self.conn_sockets[0].send(b"shutdown")
+        self.assertRaises(XfrinException, super(MockXfrinConnection,
+                                                self.conn)._select)
 
     def test_init_ip6(self):
         # This test simply creates a new XfrinConnection object with an
@@ -354,7 +377,7 @@
     def test_response_remote_close(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.force_close = True
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(socket.error, self._handle_xfrin_response)
 
     def test_response_bad_message(self):
         self.conn.response_generator = self._create_broken_response_data
@@ -440,6 +463,7 @@
         bogus_data = b'xxxx'
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data += bogus_data
+
 class MockThread:
     def is_alive(self):
         return True

Modified: branches/trac216/src/bin/xfrin/xfrin.py.in
==============================================================================
--- branches/trac216/src/bin/xfrin/xfrin.py.in (original)
+++ branches/trac216/src/bin/xfrin/xfrin.py.in Thu Aug 26 02:40:09 2010
@@ -175,33 +175,39 @@
         self._send_data(header_len)
         self._send_data(render.get_data())
 
-    def _loop(self):
-        try:
-            rlist, wlist, elist = select.select([self._socket, self._conn_socket], \
-                    [], [], self._idle_timeout)
-        except select.error as err:
-            if err.args[0] != EINTR:
-                raise
-            else:
-                return
-        for s in rlist:
-            if s == self._conn_socket:
+    def _select(self):
+        '''
+        This method is a trivial asynchronous I/O routine using select.
+        It's extracted from _get_request_response so that we can test the
+        rest of the code without involving actual communication with a remote
+        server.'''
+        while True:
+            try:
+                rlist, wlist, elist = select.select([self._socket,
+                                                     self._conn_socket],
+                                                    [], [], self._idle_timeout)
+            except select.error as err:
+                if err.args[0] != EINTR:
+                    raise
+                else:
+                    continue
+            if self._conn_socket in rlist:
                 raise XfrinException("shutdown xfrin!")
-            self.handle_read()
-            
+            return len(rlist)
     
     def _get_request_response(self, size):
         recv_size = 0
         data = b''
         while recv_size < size:
-            self._recv_time_out = True
             self._need_recv_size = size - recv_size
-            self._loop()
-            if self._recv_time_out:
+            if self._select() == 0:
                 raise XfrinException('receive data from socket time out.')
-
-            recv_size += self._recvd_size
-            data += self._recvd_data
+            recvd_data = self.recv(self._need_recv_size)
+            recvd_size = len(recvd_data)
+            if recvd_size == 0:
+                raise XfrinException("receive data from socket error!")
+            recv_size += recvd_size
+            data += recvd_data
 
         return data
 
@@ -347,14 +353,6 @@
             if self._shutdown_flag:
                 raise XfrinException('xfrin is forced to stop')
 
-    def handle_read(self):
-        '''Read query's response from socket. '''
-        self._recvd_data = self.recv(self._need_recv_size)
-        if(len(self._recvd_data)==0):
-            raise XfrinException("receive data from socket error!")
-        self._recvd_size = len(self._recvd_data)
-        self._recv_time_out = False
-
     def log_info(self, msg, type='info'):
         # Overwrite the log function, log nothing
         pass




More information about the bind10-changes mailing list