[svn] commit: r3288 - in /branches/trac216/src/bin/xfrin: tests/xfrin_test.py xfrin.py.in

BIND 10 source code commits bind10-changes at lists.isc.org
Wed Oct 20 03:08:55 UTC 2010


Author: shentingting
Date: Wed Oct 20 03:08:54 2010
New Revision: 3288

Log:
Set the socket to non-blocking mode before connecting.

Modified:
    branches/trac216/src/bin/xfrin/tests/xfrin_test.py
    branches/trac216/src/bin/xfrin/xfrin.py.in

Modified: branches/trac216/src/bin/xfrin/tests/xfrin_test.py
==============================================================================
--- branches/trac216/src/bin/xfrin/tests/xfrin_test.py (original)
+++ branches/trac216/src/bin/xfrin/tests/xfrin_test.py Wed Oct 20 03:08:54 2010
@@ -65,7 +65,8 @@
 
     def check_command(self):
         return True
-class MockSession: pass
+class MockSession():
+    def group_sendmsg(self, arg1, arg2):pass 
 
 class MockXfrin(Xfrin):
     # This is a class attribute of a callable object that specifies a non
@@ -129,7 +130,7 @@
     def connect_to_master(self):
         return self._socket
 
-    def _select(self):
+    def _select(self, rl, wl, el):
         if self.force_time_out:
             return 0
         if self.force_close:
@@ -224,6 +225,10 @@
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
 
+    def test_connect(self):
+        self.assertRaises(Exception, self.conn.connect,
+                (TEST_MASTER_IPV4_ADDRESS,53))
+
     def test_send(self):
         self.conn._socket.close()
         self.conn._socket = MockSocket()
@@ -261,19 +266,20 @@
     def test_select_readok(self):
         self.mock_xfrsockets[0].send(self.conn.mock_test_data)
         self.conn._idle_timeout = 3 # timeout shouldn't occur
-        self.assertEqual(1, super(MockXfrinConnection, self.conn)._select())
+        self.assertEqual(1, super(MockXfrinConnection, self.conn)._select(\
+                [self.conn._socket, self.conn._conn_socket], [], [] ))
         self.assertEqual(self.conn.mock_test_data,
                          self.conn._socket.recv(len(self.conn.mock_test_data)))
 
     def test_select_timeout(self):
         self.conn._idle_timeout = 0.1
-        self.assertEqual(0, super(MockXfrinConnection, self.conn)._select())
+        self.assertEqual(0, super(MockXfrinConnection, self.conn)._select([], [], []))
 
     def test_select_shutdown(self):
         self.conn._idle_timeout = 3 # timeout shouldn't occur
         self.conn_sockets[0].send(b"shutdown")
         self.assertRaises(XfrinException, super(MockXfrinConnection,
-                                                self.conn)._select)
+                            self.conn)._select, [self.conn._conn_socket], [], [])
     def test_init_ip6(self):
         # This test simply creates a new XfrinConnection object with an
         # IPv6 address, tries to bind it to an IPv6 wildcard address/port

Modified: branches/trac216/src/bin/xfrin/xfrin.py.in
==============================================================================
--- branches/trac216/src/bin/xfrin/xfrin.py.in (original)
+++ branches/trac216/src/bin/xfrin/xfrin.py.in Wed Oct 20 03:08:54 2010
@@ -27,7 +27,7 @@
 import select
 import random
 import time
-from errno import EWOULDBLOCK, ECONNRESET, \
+from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
         ENOTCONN, ESHUTDOWN, EINTR, EISCONN, EBADF, ECONNABORTED, errorcode
 
 from optparse import OptionParser, OptionValueError
@@ -84,6 +84,7 @@
             check_soa: when it's true, check soa first before sending xfr query
         '''
         self._socket = socket.socket(master_addrinfo[0], master_addrinfo[1])
+        self._socket.setblocking(0)
         self._conn_socket = conn_socket
         self._zone_name = zone_name
         self._rrclass = rrclass
@@ -92,6 +93,15 @@
         self._idle_timeout = idle_timeout
         self._verbose = verbose
         self._master_address = master_addrinfo[4]
+
+    def connect(self,address): 
+        try:
+            self._socket.connect(address)
+        except socket.error as why:
+            if why.args[0] in (EINPROGRESS, EALREADY, EWOULDBLOCK):
+                return 
+            else:
+                raise
 
     def send(self, data):
         try:
@@ -133,7 +143,7 @@
         '''Connect to master in TCP.'''
 
         try:
-            self._socket.connect(self._master_address)
+            self.connect(self._master_address)
             return True
         except socket.error as e:
             self.log_msg('Failed to connect:(%s), %s' % (self._master_address,
@@ -169,20 +179,23 @@
         msg.to_wire(render)
         header_len = struct.pack('H', socket.htons(render.get_length()))
 
+        if self._select([], [self._socket], []) == 0:
+            raise XfrinException('socket is not writable!')
+        ret = self._socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+        if ret!=0:
+            raise XfrinException("error when writing:" + str(os.strerror(ret)))
+
         self._send_data(header_len)
         self._send_data(render.get_data())
 
-    def _select(self):
+    def _select(self, rl, wl, el):
         '''
         This method is a trivial asynchronous I/O routine using select.
-        It's extracted from _get_request_response so that we can test the
-        rest of the code without involving actual communication with a remote
-        server.'''
+        It checks which socket is writable or readable.'''
         while True:
             try:
-                rlist, wlist, elist = select.select([self._socket,
-                                                     self._conn_socket],
-                                                    [], [], self._idle_timeout)
+                (rlist, wlist, elist) = select.select(rl, wl,\
+                        el, self._idle_timeout)
             except select.error as err:
                 if err.args[0] != EINTR:
                     raise
@@ -194,15 +207,20 @@
                     raise XfrinException("shutdown xfrin!")
                 else:
                     return
-            return len(rlist)
+            if len(rl):
+                return len(rlist)
+            return len(wlist)
     
     def _get_request_response(self, size):
         recv_size = 0
         data = b''
         while recv_size < size:
             self._need_recv_size = size - recv_size
-            if self._select() == 0:
+            if self._select([self._socket, self._conn_socket], [],[]) == 0:
                 raise XfrinException('receive data from socket time out.')
+            ret = self._socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+            if ret!=0:
+                raise XfrinException("error when reading:" + str(os.strerror(ret)))
             recvd_data = self.recv(self._need_recv_size)
             recvd_size = len(recvd_data)
             if recvd_size == 0:




More information about the bind10-changes mailing list