source: python/stackless/stacklesssocket.py @ 41

Revision 41, 18.4 KB checked in by sgk, 11 years ago (diff)

stackless sample

Line 
1#
2# Stackless compatible socket module:
3#
4# Author: Richard Tew <richard.m.tew@gmail.com>
5#
6# This code was written to serve as an example of Stackless Python usage.
7# Feel free to email me with any questions, comments, or suggestions for
8# improvement.
9#
# This wraps the asyncore module and the dispatcher class it provides in order
# to write a socket module replacement that uses channels to allow calls to it
# to block until a delayed event occurs.
13#
14# Not all aspects of the socket module are provided by this file.  Examples of
15# it in use can be seen at the bottom of this file.
16#
# NOTE: The asyncore module in Python 2.4 and later includes bug fixes;
#       correct behaviour is not guaranteed with earlier versions.
#       Specifically, the newer asyncore monitors sockets for errors, which
#       the version in Python 2.3.3 does not.
21#
22
23# Possible improvements:
# - More correct error handling.  When poll finds an error on a socket, there
#   is no way to tell what the error actually is.
# - Launching each bit of incoming data in its own tasklet on the recvChannel
#   send is a little over the top.  It should be possible to add it to the
#   rest of the queued data.
29
30import stackless
31import asyncore
32import socket as stdsocket # We need the "socket" name for the function we export.
33
# If we are to masquerade as the socket module, we need to provide the constants.
if "__all__" in stdsocket.__dict__:
    # Re-export exactly the names the real socket module declares public.
    # Take a copy so nothing done to our __all__ leaks into the stdlib module.
    # (Testing membership against stdsocket.__dict__ instead would also copy
    # dunder entries such as __name__ and __file__ over this module's own,
    # which breaks `if __name__ == '__main__'` and __import__(__name__).)
    __all__ = list(stdsocket.__all__)
    _exported = set(__all__)
    for k, v in stdsocket.__dict__.items():
        if k in _exported:
            globals()[k] = v
else:
    # Socket modules without __all__: copy the UPPER_CASE constants plus the
    # error/timeout exceptions and getaddrinfo.
    for k, v in stdsocket.__dict__.items():
        if k.upper() == k:
            globals()[k] = v
    error = stdsocket.error
    timeout = stdsocket.timeout
    # WARNING: this function blocks and is not thread safe.
    # The only solution is to spawn a thread to handle all
    # getaddrinfo requests.  Implementing a stackless DNS
    # lookup service is only second best as getaddrinfo may
    # use other methods.
    getaddrinfo = stdsocket.getaddrinfo

# urllib2 apparently uses this directly.  We need to cater for that.
# (Python 3's socket module has no _fileobject, hence the None fallback.)
_fileobject = getattr(stdsocket, "_fileobject", None)
55
# Someone needs to invoke asyncore.poll() regularly to keep the socket
# data moving.  The "ManageSockets" function here is a simple example
# of such a function.  It is started by StartManager(), which uses the
# global "managerRunning" to ensure that no more than one copy is
# running.
#
# If you think you can do this better, register an alternative to
# StartManager using stacklesssocket_manager().  Your function will be
# called every time a new socket is created; it's your responsibility
# to ensure it doesn't start multiple copies of itself unnecessarily.
#

managerRunning = False

def ManageSockets():
    """Poll asyncore until no sockets remain registered in its socket map.

    Intended to run as its own tasklet (StartManager launches it).  After
    every poll it yields so other tasklets get scheduled.  managerRunning is
    cleared on the way out -- also when poll() or schedule() raises -- so a
    later StartManager() call can launch a fresh manager instead of
    believing a dead one is still running.
    """
    global managerRunning

    try:
        while asyncore.socket_map:
            # Check the sockets for activity, waiting at most 0.05 seconds.
            asyncore.poll(0.05)
            # Yield to give other tasklets a chance to be scheduled.
            stackless.schedule()
    finally:
        managerRunning = False
80
def StartManager():
    """Launch the ManageSockets tasklet unless one is already flagged as running.

    The module-level managerRunning flag is the guard: it is set here before
    the tasklet is created and cleared by ManageSockets when it finishes, so
    calling this once per new socket starts at most one manager at a time.
    """
    global managerRunning
    if managerRunning:
        return
    managerRunning = True
    manager = stackless.tasklet(ManageSockets)
    manager()

# Hook invoked by socket() after each new socket is created.  It defaults to
# StartManager; install a replacement with stacklesssocket_manager().
_manage_sockets_func = StartManager

def stacklesssocket_manager(mgr):
    """Register *mgr* as the hook socket() calls after creating each socket.

    *mgr* is called with no arguments once per new socket, so it is
    responsible for not starting redundant copies of whatever it launches.
    """
    global _manage_sockets_func
    _manage_sockets_func = mgr
92
#
# Replacement for standard socket() constructor.
#
def socket(family=AF_INET, type=SOCK_STREAM, proto=0):
    """Create a stackless-compatible socket.

    Accepts the same arguments as the standard socket constructor, wraps a
    newly created standard socket in a stacklesssocket and returns it.
    Before returning, the registered manager hook (StartManager unless
    replaced via stacklesssocket_manager) is called so that something is
    polling asyncore and the socket actually moves data.
    """
    wrapped = stacklesssocket(stdsocket.socket(family, type, proto))
    # Ensure that the sockets actually work.
    _manage_sockets_func()
    return wrapped
104
class stacklesssocket(object):
    """Facade over a dispatcher; this is the object socket() hands out.

    asyncore's socket map keeps a bound reference to the dispatcher, so the
    dispatcher is never garbage collected while it is registered.  The rest
    of the world only sees this facade, which has no cycles and is collected
    normally; its __del__ then closes the dispatcher.
    """

    def __init__(self, sock):
        self.sock = sock
        self.dispatcher = dispatcher(sock)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.  Ordinary names are
        # forwarded to the dispatcher.  Dunder names are not forwarded (repr
        # and the copy/pickle protocols must not be answered by the
        # dispatcher) and must raise AttributeError rather than yield None,
        # otherwise hasattr() and getattr(obj, name, default) give wrong
        # answers.  'dispatcher' itself is never forwarded: if __init__ failed
        # before assigning it, forwarding would recurse without end.
        if name.startswith("__") or name == "dispatcher":
            raise AttributeError("%r object has no attribute %r"
                                 % (type(self).__name__, name))
        return getattr(self.dispatcher, name)

    def __setattr__(self, name, value):
        if name == "wrap_accept_socket":
            # The dispatcher's handle_accept uses this hook, so it must be
            # set there rather than on the facade.
            self.dispatcher.wrap_accept_socket = value
        else:
            # Anything else gets set locally.
            object.__setattr__(self, name, value)

    def __del__(self):
        # Close the dispatcher if it isn't already closed.  getattr() with a
        # default copes with a half-constructed instance (dispatcher() raised
        # inside __init__) as well as a dispatcher already cleared to None.
        disp = getattr(self, "dispatcher", None)
        if disp is not None and disp._fileno is not None:
            try:
                disp.close()
            finally:
                self.dispatcher = None

    # Catch this one here to make gc work correctly.
    # (Consider if stacklesssocket gets gc'ed before the _fileobject)
    def makefile(self, mode='r', bufsize=-1):
        return stdsocket._fileobject(self, mode, bufsize)
143
144
class dispatcher(asyncore.dispatcher):
    """asyncore dispatcher that turns socket events into channel traffic.

    Blocking calls (accept, connect, recv, recvfrom, sendto) park the calling
    tasklet on a stackless channel; send/sendall queue data and yield.  The
    handle_* callbacks, run from asyncore.poll() by the socket manager
    tasklet, deliver results or errors back through those channels.
    """
    connectChannel = None
    acceptChannel = None
    recvChannel = None

    def __init__(self, sock):
        # Reject anything that is not a real socket up front: passing an
        # invalid socket (e.g. a dispatcher instance) otherwise causes tasklet
        # death much later.  TypeError is a StandardError subclass on
        # Python 2, so existing "except StandardError" handlers still work.
        if not isinstance(sock, stdsocket.socket):
            raise TypeError("Invalid socket passed to dispatcher")
        asyncore.dispatcher.__init__(self, sock)

        self.recvChannel = stackless.channel()
        self.readBufferString = ''
        self.readBufferList = []  # NOTE(review): not read anywhere in this file.

        self.sendBuffer = ''
        # Pending datagrams, one (data, address, waitChannel, sentBytes) tuple
        # per sendto() call.
        self.sendToBuffers = []

    def writable(self):
        # An unconnected stream socket is always reported writable so that
        # asyncore's write event fires and can complete a pending connect
        # (see handle_connect).
        if self.socket.type != SOCK_DGRAM and not self.connected:
            return True
        return len(self.sendBuffer) or len(self.sendToBuffers)

    def accept(self):
        # Block until handle_accept() delivers (socket, address).
        if not self.acceptChannel:
            self.acceptChannel = stackless.channel()
        return self.acceptChannel.receive()

    def connect(self, address):
        asyncore.dispatcher.connect(self, address)
        # UDP sockets do not connect.
        if self.socket.type != SOCK_DGRAM and not self.connected:
            if not self.connectChannel:
                self.connectChannel = stackless.channel()
                # Prefer the sender: handle_connect() is only called while this
                # tasklet is waiting, so its send can complete without the
                # polling tasklet being switched away.
                self.connectChannel.preference = 1
            self.connectChannel.receive()

    def send(self, data):
        # Queue the data for handle_write() and yield once.  The full length
        # is reported even though it may still be sitting in sendBuffer.
        self.sendBuffer += data
        stackless.schedule()
        return len(data)

    def sendall(self, data):
        # WARNING: this will busy wait until all data is sent
        # It should be possible to do away with the busy wait with
        # the use of a channel.
        self.sendBuffer += data
        while self.sendBuffer:
            stackless.schedule()
        return len(data)

    def sendto(self, sendData, sendAddress):
        # Queue one datagram per call and wait for handle_write() to report
        # how many bytes went out.  Entries for the same address must not be
        # merged: that would fuse separate datagrams into one, and because
        # handle_write() answers each queue entry exactly once, every waiter
        # after the first would stay blocked forever.
        waitChannel = stackless.channel()
        self.sendToBuffers.append((sendData, sendAddress, waitChannel, 0))
        return waitChannel.receive()

    # Read at most byteCount bytes.
    def recv(self, byteCount):
        # Each call returns data from at most one chunk delivered by
        # handle_read(); chunks are never concatenated.  A chunk longer than
        # byteCount is handed out over several recv() calls.
        # NOTE(review): after EOF only the '' deliveries queued at close time
        # are available; confirm a further recv() cannot block forever.
        if not self.readBufferString:
            self.readBufferString += self.recvChannel.receive()
        ret = self.readBufferString[:byteCount]
        self.readBufferString = self.readBufferString[byteCount:]
        # ret will be '' at EOF.
        return ret

    def recvfrom(self, byteCount):
        if self.socket.type == SOCK_STREAM:
            return (self.recv(byteCount), None)

        # recvfrom() must not concatenate two or more packets.  Each call
        # returns the first 'byteCount' bytes of one packet; the rest of that
        # packet is discarded.
        (data, address) = self.recvChannel.receive()
        return (data[:byteCount], address)

    def close(self):
        asyncore.dispatcher.close(self)
        self.connected = False
        self.accepting = False
        self.sendBuffer = None  # breaks the loop in sendall

        # Clear out all the channels with relevant errors.
        while self.acceptChannel and self.acceptChannel.balance < 0:
            self.acceptChannel.send_exception(error, 9, 'Bad file descriptor')
        while self.connectChannel and self.connectChannel.balance < 0:
            self.connectChannel.send_exception(error, 10061, 'Connection refused')
        while self.recvChannel and self.recvChannel.balance < 0:
            # The closing of a socket is indicated by receiving nothing.  An
            # exception would be appropriate if the server was killed rather
            # than closed down gracefully.
            self.recvChannel.send("")

    # asyncore doesn't support this.  Why not?
    def fileno(self):
        return self.socket.fileno()

    def handle_accept(self):
        # Only accept while a tasklet is waiting in accept().
        if self.acceptChannel and self.acceptChannel.balance < 0:
            pair = asyncore.dispatcher.accept(self)
            # asyncore.dispatcher.accept() can return None (e.g. the accept
            # would block or the peer aborted); there is nothing to hand over.
            if pair is None:
                return
            currentSocket, clientAddress = pair
            currentSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            # Give them the asyncore based socket, not the standard one.
            currentSocket = self.wrap_accept_socket(currentSocket)
            stackless.tasklet(self.acceptChannel.send)((currentSocket, clientAddress))

    # Inform the blocked connect call that the connection has been made.
    def handle_connect(self):
        if self.socket.type == SOCK_DGRAM:
            return
        # asyncore can report the connection before connect() has created
        # the channel (a connect that completes immediately), and sending with
        # no receiver would block the polling tasklet.  Only wake a waiter.
        if self.connectChannel is not None and self.connectChannel.balance < 0:
            self.connectChannel.send(None)

    # Asyncore says its done but self.readBuffer may be non-empty
    # so can't close yet.  Do nothing and let 'recv' trigger the close.
    def handle_close(self):
        pass

    # Some error, just close the channel and let that raise errors to
    # blocked calls.
    def handle_expt(self):
        self.close()

    def handle_read(self):
        try:
            if self.socket.type == SOCK_DGRAM:
                ret, address = self.socket.recvfrom(20000)
                stackless.tasklet(self.recvChannel.send)((ret, address))
            else:
                ret = asyncore.dispatcher.recv(self, 20000)
                # An empty read means the peer closed.  Closing removes the
                # socket from asyncore; '' is still delivered so a blocked
                # recv() sees EOF.
                if not ret:
                    self.close()
                stackless.tasklet(self.recvChannel.send)(ret)
        except stdsocket.error as err:
            # Treat a read error as a broken connection and drop any pending
            # output.
            if self.sendBuffer:
                self.sendBuffer = ""
            # Deliver the error from its own tasklet, as the data paths above
            # do, so the polling tasklet never blocks when no reader is
            # waiting.  Re-raise with the original arguments (errno, message)
            # instead of wrapping the exception object inside a new one.
            stackless.tasklet(self.recvChannel.send_exception)(stdsocket.error, *err.args)

    def handle_write(self):
        if self.sendBuffer:
            # Stream data: write at most 512 bytes per event and keep
            # whatever was not accepted.
            sentBytes = asyncore.dispatcher.send(self, self.sendBuffer[:512])
            self.sendBuffer = self.sendBuffer[sentBytes:]
        elif self.sendToBuffers:
            # Datagrams: send the oldest queued one, then report the byte
            # count to the tasklet blocked in sendto().
            data, address, channel, oldSentBytes = self.sendToBuffers[0]
            sentBytes = self.socket.sendto(data, address)
            totalSentBytes = oldSentBytes + sentBytes
            if len(data) > sentBytes:
                self.sendToBuffers[0] = data[sentBytes:], address, channel, totalSentBytes
            else:
                del self.sendToBuffers[0]
                stackless.tasklet(channel.send)(totalSentBytes)

    # In order for incoming connections to be stackless compatible,
    # they need to be wrapped by an asyncore based dispatcher subclass.
    def wrap_accept_socket(self, currentSocket):
        return stacklesssocket(currentSocket)
322
323
if __name__ == '__main__':
    import sys
    import struct
    # Self-test harness.  With no arguments the TCP server/client pair, the
    # urllib monkey-patch test and the UDP test all run in-process; with one
    # argument a single TCP side runs, using plain or stackless sockets.
    testAddress = "127.0.0.1", 3000
    info = -12345678
    data = struct.pack("i", info)
    dataLength = len(data)

    print("creating listen socket")
    def TestTCPServer(address, socketClass=None):
        """Serve NUM_TESTS connections on *address*, one scenario each.

        Connection 1 is closed straight away.  On connection 2 the packed
        test value is sent, then two recv() calls must both return '' (the
        client has closed).  socketClass defaults to this module's socket().
        """
        if not socketClass:
            socketClass = socket

        listenSocket = socketClass(AF_INET, SOCK_STREAM)
        listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        listenSocket.bind(address)
        listenSocket.listen(5)

        NUM_TESTS = 2

        i = 1
        while i < NUM_TESTS + 1:
            # No need to schedule this tasklet as the accept should yield most
            # of the time on the underlying channel.
            print("waiting for connection test %d" % i)
            currentSocket, clientAddress = listenSocket.accept()
            print("received connection %d from %s" % (i, clientAddress))

            try:
                if i == 2:
                    print("server test %d send" % i)
                    currentSocket.send(data)
                    print("server test %d recv" % i)
                    if currentSocket.recv(4) != "":
                        print("server recv(1) %d FAIL" % i)
                        break
                    # multiple empty recvs are fine
                    if currentSocket.recv(4) != "":
                        print("server recv(2) %d FAIL" % i)
                        break
            finally:
                # Every connection is closed once its scenario is over
                # (connection 1 immediately), including on failure.
                currentSocket.close()

            print("server test %d OK" % i)
            i += 1

        listenSocket.close()

        if i != NUM_TESTS + 1:
            print("server: FAIL %d" % i)
        else:
            print("server: OK %d" % i)

        print("Done server")

    def TestTCPClient(address, socketClass=None):
        """Connect to TestTCPServer twice and check what each connection yields.

        socketClass defaults to this module's socket().
        """
        if not socketClass:
            socketClass = socket

        # Attempt 1: the server closes at once, so recv must return ''.
        clientSocket = socketClass()
        clientSocket.connect(address)
        print("client connection 1 waiting to recv")
        if clientSocket.recv(5) != "":
            print("client test 1 FAIL")
        else:
            print("client test 1 OK")
        clientSocket.close()

        # Attempt 2: the server sends one packed int.  socketClass is used
        # here too, so the plain-socket "client" mode really uses plain
        # sockets for both attempts.
        clientSocket = socketClass()
        clientSocket.connect(address)
        print("client connection 2 waiting to recv")
        s = clientSocket.recv(dataLength)
        if s == "":
            print("client test 2 FAIL (disconnect)")
        else:
            t = struct.unpack("i", s)
            if t[0] == info:
                print("client test 2 OK")
            else:
                print("client test 2 FAIL (wrong data)")
        # Closing lets the server's pending recv() calls see EOF.
        clientSocket.close()

    def TestMonkeyPatchUrllib(uri):
        """Fetch *uri* with urllib while this module stands in for socket."""
        # replace the system socket with this module
        oldSocket = sys.modules["socket"]
        sys.modules["socket"] = __import__(__name__)
        try:
            import urllib  # must occur after monkey-patching!
            f = urllib.urlopen(uri)
            if not isinstance(f.fp._sock, stacklesssocket):
                raise AssertionError("failed to apply monkeypatch")
            s = f.read()
            if len(s) != 0:
                print("Fetched %d bytes via replaced urllib" % len(s))
            else:
                raise AssertionError("no text received?")
        finally:
            sys.modules["socket"] = oldSocket

    def TestMonkeyPatchUDP(address):
        """Send one 512-byte datagram between two stackless UDP sockets."""
        # replace the system socket with this module
        oldSocket = sys.modules["socket"]
        sys.modules["socket"] = __import__(__name__)
        try:
            PACKET_SIZE = 512

            def UDPServer(address):
                listenSocket = socket(AF_INET, SOCK_DGRAM)
                listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
                listenSocket.bind(address)

                cnt = 0
                while 1:
                    print("waiting to receive")
                    # recvfrom() returns at most byteCount bytes of ONE
                    # datagram and discards the rest, so the buffer must hold
                    # a whole packet or the count never reaches PACKET_SIZE.
                    t = listenSocket.recvfrom(PACKET_SIZE)
                    cnt += len(t[0])
                    print("received %s %d" % (t[0], cnt))
                    if cnt >= PACKET_SIZE:
                        break
                listenSocket.close()

            def UDPClient(address):
                clientSocket = socket(AF_INET, SOCK_DGRAM)
                print("sending 512 byte packet")
                sentBytes = clientSocket.sendto("-" + ("*" * 510) + "-", address)
                print("sent 512 byte packet %s" % sentBytes)
                clientSocket.close()

            stackless.tasklet(UDPServer)(address)
            stackless.tasklet(UDPClient)(address)
            stackless.run()
        finally:
            sys.modules["socket"] = oldSocket

    if len(sys.argv) == 2:
        mode = sys.argv[1]
        if mode == "client":
            print("client started")
            TestTCPClient(testAddress, stdsocket.socket)
            print("client exited")
        elif mode == "slpclient":
            print("client started")
            stackless.tasklet(TestTCPClient)(testAddress)
            stackless.run()
            print("client exited")
        elif mode == "server":
            print("server started")
            TestTCPServer(testAddress, stdsocket.socket)
            print("server exited")
        elif mode == "slpserver":
            print("server started")
            stackless.tasklet(TestTCPServer)(testAddress)
            stackless.run()
            print("server exited")
        else:
            print("Usage: %s [client|server|slpclient|slpserver]" % sys.argv[0])
            # Only an unknown mode is an error; the modes above exit with 0.
            sys.exit(1)
    else:
        stackless.tasklet(TestTCPServer)(testAddress)
        stackless.tasklet(TestTCPClient)(testAddress)
        stackless.run()

        stackless.tasklet(TestMonkeyPatchUrllib)("http://python.org/")
        stackless.run()

        TestMonkeyPatchUDP(testAddress)

        print("result: SUCCESS")
Note: See TracBrowser for help on using the repository browser.