@@ -488,13 +488,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
488488 for key in server_keys [server ]: # These are mangled keys
489489 cmd = self ._encode_cmd ('delete' , key , headers , noreply , b'\r \n ' )
490490 write (cmd )
491- try :
491+ with _socket_guard ( server , ( socket . error ,)) as sg :
492492 server .send_cmds (b'' .join (bigcmd ))
493- except socket . error as msg :
493+ if sg . interrupted :
494494 rc = 0
495- if isinstance (msg , tuple ):
496- msg = msg [1 ]
497- server .mark_dead (msg )
498495 dead_servers .append (server )
499496
500497 # if noreply, just return
@@ -506,13 +503,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
506503 del server_keys [server ]
507504
508505 for server , keys in six .iteritems (server_keys ):
509- try :
506+ with _socket_guard ( server , ( socket . error ,)) as sg :
510507 for key in keys :
511508 server .expect (b"DELETED" )
512- except socket .error as msg :
513- if isinstance (msg , tuple ):
514- msg = msg [1 ]
515- server .mark_dead (msg )
509+ if sg .interrupted :
516510 rc = 0
517511 return rc
518512
@@ -558,7 +552,7 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
558552 headers = None
559553 fullcmd = self ._encode_cmd (cmd , key , headers , noreply )
560554
561- try :
555+ with _socket_guard ( server , ( socket . error ,)) :
562556 server .send_cmd (fullcmd )
563557 if noreply :
564558 return 1
@@ -567,10 +561,6 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
567561 return 1
568562 self .debuglog ('%s expected %s, got: %r'
569563 % (cmd , ' or ' .join (expected ), line ))
570- except socket .error as msg :
571- if isinstance (msg , tuple ):
572- msg = msg [1 ]
573- server .mark_dead (msg )
574564 return 0
575565
576566 def incr (self , key , delta = 1 , noreply = False ):
@@ -633,19 +623,14 @@ def _incrdecr(self, cmd, key, delta, noreply=False):
633623 return None
634624 self ._statlog (cmd )
635625 fullcmd = self ._encode_cmd (cmd , key , str (delta ), noreply )
636- try :
626+ with _socket_guard ( server , ( socket . error ,)) :
637627 server .send_cmd (fullcmd )
638628 if noreply :
639629 return
640630 line = server .readline ()
641631 if line is None or line .strip () == b'NOT_FOUND' :
642632 return None
643633 return int (line )
644- except socket .error as msg :
645- if isinstance (msg , tuple ):
646- msg = msg [1 ]
647- server .mark_dead (msg )
648- return None
649634
650635 def add (self , key , val , time = 0 , min_compress_len = 0 , noreply = False ):
651636 '''Add new key with value.
@@ -902,7 +887,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
902887 for server in six .iterkeys (server_keys ):
903888 bigcmd = []
904889 write = bigcmd .append
905- try :
890+ with _socket_guard ( server , ( socket . error ,)) as sg :
906891 for key in server_keys [server ]: # These are mangled keys
907892 store_info = self ._val_to_store_info (
908893 mapping [prefixed_to_orig_key [key ]],
@@ -917,10 +902,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
917902 else :
918903 notstored .append (prefixed_to_orig_key [key ])
919904 server .send_cmds (b'' .join (bigcmd ))
920- except socket .error as msg :
921- if isinstance (msg , tuple ):
922- msg = msg [1 ]
923- server .mark_dead (msg )
905+ if sg .interrupted :
924906 dead_servers .append (server )
925907
926908 # if noreply, just return early
@@ -936,17 +918,13 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
936918 return list (mapping .keys ())
937919
938920 for server , keys in six .iteritems (server_keys ):
939- try :
921+ with _socket_guard ( server , ( _Error , socket . error )) :
940922 for key in keys :
941923 if server .readline () == b'STORED' :
942924 continue
943925 else :
944926 # un-mangle.
945927 notstored .append (prefixed_to_orig_key [key ])
946- except (_Error , socket .error ) as msg :
947- if isinstance (msg , tuple ):
948- msg = msg [1 ]
949- server .mark_dead (msg )
950928 return notstored
951929
952930 def _val_to_store_info (self , val , min_compress_len ):
@@ -1032,15 +1010,11 @@ def _unsafe_set():
10321010 fullcmd = self ._encode_cmd (cmd , key , headers , noreply ,
10331011 b'\r \n ' , encoded_val )
10341012
1035- try :
1013+ with _socket_guard ( server , ( socket . error ,)) :
10361014 server .send_cmd (fullcmd )
10371015 if noreply :
10381016 return True
10391017 return server .expect (b"STORED" , raise_exception = True ) == b"STORED"
1040- except socket .error as msg :
1041- if isinstance (msg , tuple ):
1042- msg = msg [1 ]
1043- server .mark_dead (msg )
10441018 return 0
10451019
10461020 try :
@@ -1065,7 +1039,7 @@ def _get(self, cmd, key):
10651039 def _unsafe_get ():
10661040 self ._statlog (cmd )
10671041
1068- try :
1042+ with _socket_guard ( server , ( _Error , socket . error )) :
10691043 cmd_bytes = cmd .encode ('utf-8' ) if six .PY3 else cmd
10701044 fullcmd = b'' .join ((cmd_bytes , b' ' , key ))
10711045 server .send_cmd (fullcmd )
@@ -1085,16 +1059,9 @@ def _unsafe_get():
10851059 if not rkey :
10861060 return None
10871061 try :
1088- value = self ._recv_value (server , flags , rlen )
1062+ return self ._recv_value (server , flags , rlen )
10891063 finally :
10901064 server .expect (b"END" , raise_exception = True )
1091- except (_Error , socket .error ) as msg :
1092- if isinstance (msg , tuple ):
1093- msg = msg [1 ]
1094- server .mark_dead (msg )
1095- return None
1096-
1097- return value
10981065
10991066 try :
11001067 return _unsafe_get ()
@@ -1185,13 +1152,10 @@ def get_multi(self, keys, key_prefix=''):
11851152 # send out all requests on each server before reading anything
11861153 dead_servers = []
11871154 for server in six .iterkeys (server_keys ):
1188- try :
1155+ with _socket_guard ( server , ( socket . error ,)) as sg :
11891156 fullcmd = b"get " + b" " .join (server_keys [server ])
11901157 server .send_cmd (fullcmd )
1191- except socket .error as msg :
1192- if isinstance (msg , tuple ):
1193- msg = msg [1 ]
1194- server .mark_dead (msg )
1158+ if sg .interrupted :
11951159 dead_servers .append (server )
11961160
11971161 # if any servers died on the way, don't expect them to respond.
@@ -1200,7 +1164,7 @@ def get_multi(self, keys, key_prefix=''):
12001164
12011165 retvals = {}
12021166 for server in six .iterkeys (server_keys ):
1203- try :
1167+ with _socket_guard ( server , ( _Error , socket . error )) :
12041168 line = server .readline ()
12051169 while line and line != b'END' :
12061170 rkey , flags , rlen = self ._expectvalue (server , line )
@@ -1210,10 +1174,6 @@ def get_multi(self, keys, key_prefix=''):
12101174 # un-prefix returned key.
12111175 retvals [prefixed_to_orig_key [rkey ]] = val
12121176 line = server .readline ()
1213- except (_Error , socket .error ) as msg :
1214- if isinstance (msg , tuple ):
1215- msg = msg [1 ]
1216- server .mark_dead (msg )
12171177 return retvals
12181178
12191179 def _expect_cas_value (self , server , line = None , raise_exception = False ):
@@ -1394,15 +1354,10 @@ def _get_socket(self):
13941354 s = socket .socket (self .family , socket .SOCK_STREAM )
13951355 if hasattr (s , 'settimeout' ):
13961356 s .settimeout (self .socket_timeout )
1397- try :
1357+ with _socket_guard (self , (socket .error ,),
1358+ msg_tmpl = 'connect: {}' ) as sg :
13981359 s .connect (self .address )
1399- except socket .timeout as msg :
1400- self .mark_dead ("connect: %s" % msg )
1401- return None
1402- except socket .error as msg :
1403- if isinstance (msg , tuple ):
1404- msg = msg [1 ]
1405- self .mark_dead ("connect: %s" % msg )
1360+ if sg .interrupted :
14061361 return None
14071362 self .socket = s
14081363 self .buffer = b''
@@ -1497,6 +1452,30 @@ def __str__(self):
14971452 return "unix:%s%s" % (self .address , d )
14981453
14991454
1455+ class _socket_guard (object ):
1456+ def __init__ (self , server , exceptions , msg_tmpl = '{}' ):
1457+ self ._server = server
1458+ self ._exceptions = exceptions
1459+ self ._msg_tmpl = msg_tmpl
1460+ self .interrupted = False
1461+
1462+ def __enter__ (self ):
1463+ return self
1464+
1465+ def __exit__ (self , exc_type , exc , exc_tb ):
1466+ if exc is not None :
1467+ self .interrupted = True
1468+
1469+ if isinstance (exc , self ._exceptions ):
1470+ msg = self ._msg_tmpl .format (exc )
1471+ self ._server .mark_dead (msg )
1472+ return True
1473+ elif exc is not None :
1474+ self ._server .close_socket ()
1475+
1476+ return False
1477+
1478+
15001479def _doctest ():
15011480 import doctest
15021481 import memcache
0 commit comments