root/python/stackless/stacklesssocket.py

リビジョン 41, 18.4 kB (コミッタ: sgk, コミット時期: 1 年 前)

stackless sample

Line 
1 #
2 # Stackless compatible socket module:
3 #
4 # Author: Richard Tew <richard.m.tew@gmail.com>
5 #
6 # This code was written to serve as an example of Stackless Python usage.
7 # Feel free to email me with any questions, comments, or suggestions for
8 # improvement.
9 #
10 # This wraps the asyncore module and the dispatcher class it provides in order
11 # to write a socket module replacement that uses channels to allow calls to it
12 # to block until a delayed event occurs.
13 #
14 # Not all aspects of the socket module are provided by this file.  Examples of
15 # it in use can be seen at the bottom of this file.
16 #
17 # NOTE: Versions of the asyncore module from Python 2.4 or later include bug
18 #       fixes and earlier versions will not guarantee correct behaviour.
19 #       Specifically, it monitors for errors on sockets where the version in
20 #       Python 2.3.3 does not.
21 #
22
23 # Possible improvements:
24 # - More correct error handling.  When poll finds an error on a socket, there
25 #   is currently no way to tell what the error actually is.
26 # - Launching each bit of incoming data in its own tasklet for the recvChannel
27 #   send is a little over the top.  It should be possible to append it to the
28 #   rest of the queued data instead.
29
30 import stackless
31 import asyncore
32 import socket as stdsocket # We need the "socket" name for the function we export.
33
# If we are to masquerade as the socket module, we need to provide the same
# public names and constants.
if "__all__" in stdsocket.__dict__:
    # Re-export exactly the names the standard module declares public.
    # Copying the whole module __dict__ instead would also overwrite this
    # module's own __name__, __file__ and __doc__ with those of "socket".
    __all__ = list(stdsocket.__all__)
    for k in __all__:
        globals()[k] = getattr(stdsocket, k)
else:
    # No __all__ available: fall back to copying the upper-case constants
    # and the handful of other names callers are known to need.
    for k, v in stdsocket.__dict__.items():
        if k.upper() == k:
            globals()[k] = v
    error = stdsocket.error
    timeout = stdsocket.timeout
    # WARNING: this function blocks and is not thread safe.
    # The only solution is to spawn a thread to handle all
    # getaddrinfo requests.  Implementing a stackless DNS
    # lookup service is only second best as getaddrinfo may
    # use other methods.
    getaddrinfo = stdsocket.getaddrinfo
52
53 # urllib2 apparently uses this directly.  We need to cater for that.
54 _fileobject = stdsocket._fileobject
55
56 # Someone needs to invoke asyncore.poll() regularly to keep the socket
57 # data moving.  The "ManageSockets" function here is a simple example
58 # of such a function.  It is started by StartManager(), which uses the
59 # global "managerRunning" to ensure that no more than one copy is
60 # running.
61 #
62 # If you think you can do this better, register an alternative to
63 # StartManager using stacklesssocket_manager().  Your function will be
64 # called every time a new socket is created; it's your responsibility
65 # to ensure it doesn't start multiple copies of itself unnecessarily.
66 #
67
# True while a ManageSockets tasklet is alive.  StartManager() checks this so
# that no more than one polling tasklet runs at a time.
managerRunning = False

def ManageSockets():
    """Poll asyncore until its socket map is empty.

    Each pass polls the registered sockets for activity and then yields so
    other tasklets get a chance to be scheduled.  When the loop ends for any
    reason the "managerRunning" flag is cleared.
    """
    global managerRunning

    try:
        while len(asyncore.socket_map):
            # Check the sockets for activity.
            asyncore.poll(0.05)
            # Yield to give other tasklets a chance to be scheduled.
            stackless.schedule()
    finally:
        # Clear the flag even when poll() raises or this tasklet is killed;
        # otherwise StartManager() would never start a replacement and every
        # socket would stall.
        managerRunning = False
80
def StartManager():
    """Launch the ManageSockets tasklet unless one is already running."""
    global managerRunning
    if managerRunning:
        return
    # Mark as running before the tasklet is created so repeated calls made
    # before it first executes do not start a second copy.
    managerRunning = True
    stackless.tasklet(ManageSockets)()
86
# Callable invoked by socket() each time a socket is created; defaults to
# StartManager.
_manage_sockets_func = StartManager

def stacklesssocket_manager(mgr):
    """Register "mgr" as the callable that socket() invokes on creation.

    It replaces the default StartManager.
    """
    global _manage_sockets_func
    _manage_sockets_func = mgr
92
93 #
94 # Replacement for standard socket() constructor.
95 #
def socket(family=AF_INET, type=SOCK_STREAM, proto=0):
    """Create a standard socket and return it wrapped in a stacklesssocket.

    The registered manager function is invoked after the wrapper is built.
    """
    wrapped = stacklesssocket(stdsocket.socket(family, type, proto))
    # Ensure that the sockets actually work.
    _manage_sockets_func()
    return wrapped
104
105 # This is a facade to the dispatcher object.
106 # It exists because asyncore's socket map keeps a bound reference to
107 # the dispatcher and hence the dispatcher will never get gc'ed.
108 #
109 # The rest of the world sees a 'stacklesssocket' which has no cycles
110 # and will be gc'ed correctly
111
class stacklesssocket(object):
    """Facade placed in front of a dispatcher instance.

    Attribute reads that are not found on the facade itself are forwarded to
    the dispatcher; attribute writes stay on the facade, except for
    "wrap_accept_socket", which is passed through to the dispatcher.
    """

    def __init__(self, sock):
        self.sock = sock
        self.dispatcher = dispatcher(sock)

    def __getattr__(self, name):
        # Only called when normal lookup fails.  Double-underscore names are
        # deliberately not forwarded (e.g. __repr__), and must raise
        # AttributeError rather than silently evaluate to None so hasattr()
        # and protocol probing behave.  "dispatcher" itself can only be
        # missing when __init__ did not complete; forwarding it would recurse
        # without end.
        if name.startswith("__") or name == "dispatcher":
            raise AttributeError(name)
        return getattr(self.dispatcher, name)

    def __setattr__(self, name, value):
        if name == "wrap_accept_socket":
            # We need to pass setting of this to the dispatcher.
            self.dispatcher.wrap_accept_socket = value
        else:
            # Anything else gets set locally.
            object.__setattr__(self, name, value)

    def __del__(self):
        # Close the dispatcher if it isn't already closed.  The dispatcher is
        # read from the instance dict so a half-constructed instance (one
        # whose __init__ raised) can still be finalised.
        current = self.__dict__.get("dispatcher")
        if current is not None and current._fileno is not None:
            try:
                current.close()
            finally:
                self.dispatcher = None

    # Catch this one here to make gc work correctly.
    # (Consider if stacklesssocket gets gc'ed before the _fileobject)
    def makefile(self, mode='r', bufsize=-1):
        return stdsocket._fileobject(self, mode, bufsize)
143
144
class dispatcher(asyncore.dispatcher):
    """asyncore dispatcher whose socket calls block on stackless channels.

    The caller-facing methods (accept, connect, send, sendall, sendto, recv,
    recvfrom) queue work or wait on a channel; the asyncore handle_* callbacks
    perform the actual socket I/O and wake the waiting tasklets.
    """

    # Created on first use by connect() / accept(); None until then.
    connectChannel = None
    acceptChannel = None
    recvChannel = None

    def __init__(self, sock):
        # This is worth doing.  I was passing in an invalid socket which was
        # an instance of dispatcher and it was causing tasklet death.
        if not isinstance(sock, stdsocket.socket):
            raise StandardError("Invalid socket passed to dispatcher")
        asyncore.dispatcher.__init__(self, sock)

        self.recvChannel = stackless.channel()
        self.readBufferString = ''
        self.readBufferList = []

        # Pending stream output; set to None by close().
        self.sendBuffer = ''
        # Pending datagrams as (data, address, channel, sentBytes) tuples.
        self.sendToBuffers = []

    def writable(self):
        # An unconnected stream socket must report writable so asyncore can
        # detect completion of the connect.
        if self.socket.type != SOCK_DGRAM and not self.connected:
            return True
        return bool(self.sendBuffer or self.sendToBuffers)

    def accept(self):
        """Block until handle_accept() delivers (socket, address)."""
        if not self.acceptChannel:
            self.acceptChannel = stackless.channel()
        return self.acceptChannel.receive()

    def connect(self, address):
        """Start connecting and block until handle_connect() signals."""
        asyncore.dispatcher.connect(self, address)
        # UDP sockets do not connect.
        if self.socket.type != SOCK_DGRAM and not self.connected:
            if not self.connectChannel:
                self.connectChannel = stackless.channel()
                # Prefer the sender.  Do not block when sending, given that
                # there is a tasklet known to be waiting, this will happen.
                self.connectChannel.preference = 1
            self.connectChannel.receive()

    def send(self, data):
        """Queue data for handle_write() and return len(data)."""
        if self.sendBuffer is None:
            # close() has run; report it as a socket error instead of
            # failing with a TypeError on None += str.
            raise stdsocket.error(9, 'Bad file descriptor')
        self.sendBuffer += data
        stackless.schedule()
        return len(data)

    def sendall(self, data):
        """Queue data and yield until the send buffer has drained."""
        # WARNING: this will busy wait until all data is sent
        # It should be possible to do away with the busy wait with
        # the use of a channel.
        if self.sendBuffer is None:
            raise stdsocket.error(9, 'Bad file descriptor')
        self.sendBuffer += data
        # close() sets sendBuffer to None, which also ends this loop.
        while self.sendBuffer:
            stackless.schedule()
        return len(data)

    def sendto(self, sendData, sendAddress):
        """Queue one datagram and block until handle_write() has sent it.

        Returns the byte count that handle_write() reports for the datagram.
        """
        # Every call gets its own queue entry and its own channel.  Merging
        # into an existing entry for the same address would fuse separate
        # datagrams into one, and handle_write() sends only one result per
        # entry, so a second caller sharing the channel would never wake.
        waitChannel = stackless.channel()
        self.sendToBuffers.append((sendData, sendAddress, waitChannel, 0))
        return waitChannel.receive()

    # Read at most byteCount bytes.
    def recv(self, byteCount):
        # recv() must not concatenate two or more data fragments sent with
        # send() on the remote side.  A single fragment sent with a single
        # send() call should be split into strings whose length is less than
        # or equal to 'byteCount', and returned by one or more recv() calls.
        if not self.readBufferString:
            self.readBufferString += self.recvChannel.receive()
        ret = self.readBufferString[:byteCount]
        self.readBufferString = self.readBufferString[byteCount:]
        # ret will be '' on EOF.
        return ret

    def recvfrom(self, byteCount):
        if self.socket.type == SOCK_STREAM:
            return (self.recv(byteCount), None)

        # recvfrom() must not concatenate two or more packets.
        # Each call should return the first 'byteCount' part of the packet.
        (data, address) = self.recvChannel.receive()
        return (data[:byteCount], address)

    def close(self):
        asyncore.dispatcher.close(self)
        self.connected = False
        self.accepting = False
        self.sendBuffer = None  # breaks the loop in sendall

        # Clear out all the channels with relevant errors.
        while self.acceptChannel and self.acceptChannel.balance < 0:
            self.acceptChannel.send_exception(stdsocket.error, 9, 'Bad file descriptor')
        while self.connectChannel and self.connectChannel.balance < 0:
            self.connectChannel.send_exception(stdsocket.error, 10061, 'Connection refused')
        while self.recvChannel and self.recvChannel.balance < 0:
            # The closing of a socket is indicated by receiving nothing.  An
            # exception would have been sent if the server was killed, rather
            # than closed down gracefully.
            self.recvChannel.send("")

    # asyncore doesn't support this.  Why not?
    def fileno(self):
        return self.socket.fileno()

    def handle_accept(self):
        # Only accept while a tasklet is blocked in accept().
        if self.acceptChannel and self.acceptChannel.balance < 0:
            accepted = asyncore.dispatcher.accept(self)
            if accepted is None:
                # asyncore swallowed a "would block" error; nothing to
                # hand over this time.
                return
            currentSocket, clientAddress = accepted
            currentSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            # Give them the asyncore based socket, not the standard one.
            currentSocket = self.wrap_accept_socket(currentSocket)
            stackless.tasklet(self.acceptChannel.send)((currentSocket, clientAddress))

    # Inform the blocked connect call that the connection has been made.
    def handle_connect(self):
        if self.socket.type != SOCK_DGRAM:
            # The channel only exists once connect() has had to wait, and a
            # send with nobody receiving would block the polling tasklet.
            if self.connectChannel and self.connectChannel.balance < 0:
                self.connectChannel.send(None)

    # Asyncore says it is done but the read buffer may be non-empty,
    # so we can't close yet.  Do nothing and let 'recv' trigger the close.
    def handle_close(self):
        pass

    # Some error, just close the channel and let that raise errors to
    # blocked calls.
    def handle_expt(self):
        self.close()

    def handle_read(self):
        try:
            if self.socket.type == SOCK_DGRAM:
                ret, address = self.socket.recvfrom(20000)
                stackless.tasklet(self.recvChannel.send)((ret, address))
            else:
                ret = asyncore.dispatcher.recv(self, 20000)
                # Not sure this is correct, but it seems to give the
                # right behaviour.  Namely removing the socket from
                # asyncore.
                if not ret:
                    self.close()
                stackless.tasklet(self.recvChannel.send)(ret)
        except stdsocket.error:
            # sys.exc_info() is used rather than "except E, err" so this
            # clause parses on every Python version.
            import sys
            err = sys.exc_info()[1]
            # XXX Is this correct?
            # If there's a read error assume the connection is
            # broken and drop any pending output
            if self.sendBuffer:
                self.sendBuffer = ""
            # Deliver from a separate tasklet, like the data sends above, so
            # that the polling tasklet is not blocked when nobody is in recv.
            stackless.tasklet(self.recvChannel.send_exception)(stdsocket.error, err)

    def handle_write(self):
        if self.sendBuffer:
            # Stream output goes out in slices of at most 512 bytes.
            sentBytes = asyncore.dispatcher.send(self, self.sendBuffer[:512])
            self.sendBuffer = self.sendBuffer[sentBytes:]
        elif self.sendToBuffers:
            data, address, channel, oldSentBytes = self.sendToBuffers[0]
            sentBytes = self.socket.sendto(data, address)
            totalSentBytes = oldSentBytes + sentBytes
            if len(data) > sentBytes:
                # Partially sent: keep the remainder at the head of the queue.
                self.sendToBuffers[0] = data[sentBytes:], address, channel, totalSentBytes
            else:
                del self.sendToBuffers[0]
                stackless.tasklet(channel.send)(totalSentBytes)

    # In order for incoming connections to be stackless compatible,
    # they need to be wrapped by an asyncore based dispatcher subclass.
    def wrap_accept_socket(self, currentSocket):
        return stacklesssocket(currentSocket)
322
323
if __name__ == '__main__':
    import sys
    import struct

    # Self-test driver.  With no arguments it runs the TCP, urllib and UDP
    # tests in-process; with one argument it runs a single client or server
    # using either the standard socket class or the stackless replacement.
    # Messages are printed in a single-argument form so the output is the
    # same whether print is a statement or a function.
    testAddress = "127.0.0.1", 3000
    info = -12345678
    data = struct.pack("i", info)
    dataLength = len(data)

    print("creating listen socket")

    def TestTCPServer(address, socketClass=None):
        """Accept two connections: close the first, send data on the second."""
        if not socketClass:
            socketClass = socket

        listenSocket = socketClass(AF_INET, SOCK_STREAM)
        listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        listenSocket.bind(address)
        listenSocket.listen(5)

        NUM_TESTS = 2

        i = 1
        while i < NUM_TESTS + 1:
            # No need to schedule this tasklet as the accept should yield most
            # of the time on the underlying channel.
            print("waiting for connection test %s" % i)
            currentSocket, clientAddress = listenSocket.accept()
            print("received connection %s from %s" % (i, clientAddress))

            if i == 1:
                currentSocket.close()
            elif i == 2:
                print("server test %s send" % i)
                currentSocket.send(data)
                print("server test %s recv" % i)
                if currentSocket.recv(4) != "":
                    print("server recv(1) %s FAIL" % i)
                    break
                # multiple empty recvs are fine
                if currentSocket.recv(4) != "":
                    print("server recv(2) %s FAIL" % i)
                    break
            else:
                currentSocket.close()

            print("server test %s OK" % i)
            i += 1

        if i != NUM_TESTS + 1:
            print("server: FAIL %s" % i)
        else:
            print("server: OK %s" % i)

        print("Done server")

    def TestTCPClient(address, socketClass=None):
        """Connect twice: expect EOF first, then the packed integer."""
        if not socketClass:
            socketClass = socket

        # Attempt 1:
        clientSocket = socketClass()
        clientSocket.connect(address)
        print("client connection 1 waiting to recv")
        if clientSocket.recv(5) != "":
            print("client test 1 FAIL")
        else:
            print("client test 1 OK")

        # Attempt 2: use the requested socket class here as well, so the
        # plain "client" mode really exercises the standard socket throughout.
        clientSocket = socketClass()
        clientSocket.connect(address)
        print("client connection 2 waiting to recv")
        s = clientSocket.recv(dataLength)
        if s == "":
            print("client test 2 FAIL (disconnect)")
        else:
            t = struct.unpack("i", s)
            if t[0] == info:
                print("client test 2 OK")
            else:
                print("client test 2 FAIL (wrong data)")

    def TestMonkeyPatchUrllib(uri):
        """Fetch "uri" through urllib with this module installed as "socket"."""
        # replace the system socket with this module
        oldSocket = sys.modules["socket"]
        sys.modules["socket"] = __import__(__name__)
        try:
            import urllib  # must occur after monkey-patching!
            f = urllib.urlopen(uri)
            if not isinstance(f.fp._sock, stacklesssocket):
                raise AssertionError("failed to apply monkeypatch")
            s = f.read()
            if len(s) != 0:
                print("Fetched %s bytes via replaced urllib" % len(s))
            else:
                raise AssertionError("no text received?")
        finally:
            sys.modules["socket"] = oldSocket

    def TestMonkeyPatchUDP(address):
        """Send one 512 byte datagram between two tasklets over UDP."""
        # replace the system socket with this module
        oldSocket = sys.modules["socket"]
        sys.modules["socket"] = __import__(__name__)
        try:
            def UDPServer(address):
                listenSocket = socket(AF_INET, SOCK_DGRAM)
                listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
                listenSocket.bind(address)

                # Read in 256 byte pieces until 512 bytes have been counted.
                cnt = 0
                while 1:
                    print("waiting to receive")
                    t = listenSocket.recvfrom(256)
                    cnt += len(t[0])
                    print("received %s %s" % (t[0], cnt))
                    if cnt == 512:
                        break

            def UDPClient(address):
                clientSocket = socket(AF_INET, SOCK_DGRAM)
                print("sending 512 byte packet")
                sentBytes = clientSocket.sendto("-" + ("*" * 510) + "-", address)
                print("sent 512 byte packet %s" % sentBytes)

            stackless.tasklet(UDPServer)(address)
            stackless.tasklet(UDPClient)(address)
            stackless.run()
        finally:
            sys.modules["socket"] = oldSocket

    if len(sys.argv) == 2:
        mode = sys.argv[1]
        if mode == "client":
            print("client started")
            TestTCPClient(testAddress, stdsocket.socket)
            print("client exited")
        elif mode == "slpclient":
            print("client started")
            stackless.tasklet(TestTCPClient)(testAddress)
            stackless.run()
            print("client exited")
        elif mode == "server":
            print("server started")
            TestTCPServer(testAddress, stdsocket.socket)
            print("server exited")
        elif mode == "slpserver":
            print("server started")
            stackless.tasklet(TestTCPServer)(testAddress)
            stackless.run()
            print("server exited")
        else:
            print("Usage: %s [client|server|slpclient|slpserver]" % sys.argv[0])
            # Only an unrecognised mode is an error.
            sys.exit(1)

        sys.exit(0)
    else:
        stackless.tasklet(TestTCPServer)(testAddress)
        stackless.tasklet(TestTCPClient)(testAddress)
        stackless.run()

        stackless.tasklet(TestMonkeyPatchUrllib)("http://python.org/")
        stackless.run()

        TestMonkeyPatchUDP(testAddress)

        print("result: SUCCESS")
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。