-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathskeleton.py
More file actions
267 lines (213 loc) · 9.83 KB
/
skeleton.py
File metadata and controls
267 lines (213 loc) · 9.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import logging
from collections import namedtuple
from io import BytesIO
from socket import error as socket_error

from gevent import socket
from gevent.pool import Pool
from gevent.server import StreamServer
# Exceptions are used to notify the connection-handling loop of problems.
class CommandError(Exception):
    """Raised when a request is malformed or names an unknown command."""


class Disconnect(Exception):
    """Raised when the peer has closed the connection (EOF on read)."""


# An error reply. A namedtuple lets the message be accessed by name
# (``err.message``) instead of by position.
Error = namedtuple('Error', ('message',))


class ProtocolHandler(object):
    """Serialize and deserialize values using a RESP-style wire protocol.

    Each value starts with a one-byte type prefix:

    * ``+`` simple string, ``-`` error, ``:`` integer,
    * ``$`` length-prefixed bulk string (``$-1`` encodes ``None``),
    * ``*`` array of values, ``%`` dict of key/value pairs.
    """

    def __init__(self):
        # Dispatch table: type-prefix character -> parser method.
        self.handlers = {
            '+': self.handle_simple_string,
            '-': self.handle_error,
            ':': self.handle_integer,
            '$': self.handle_string,
            '*': self.handle_array,
            '%': self.handle_dict,
        }

    def handle_request(self, socket_file):
        """Read one value from ``socket_file`` and return it as a Python object.

        This is the deserialization side: bytes from the wire are rebuilt
        into Python values.

        :param socket_file: a binary file-like object supporting ``read`` and
            ``readline``.
        :raises Disconnect: if the stream is at EOF, including when the peer
            leaves part-way through a nested array or dict.
        :raises CommandError: if the type prefix is unknown or the payload
            cannot be parsed.
        """
        first_byte = socket_file.read(1)
        if not first_byte:
            raise Disconnect()
        try:
            # Decoding sits inside the try so a non-UTF-8 prefix byte is
            # reported as a bad request rather than a raw UnicodeDecodeError.
            handler = self.handlers[first_byte.decode('utf-8')]
            return handler(socket_file)
        except (Disconnect, CommandError):
            # Propagate unchanged from nested reads so the caller can tell
            # "client left" apart from "bad request".
            raise
        except Exception as exc:
            raise CommandError('bad request') from exc

    def handle_simple_string(self, socket_file):
        """Parse ``+<text>\\r\\n`` and return the text as bytes."""
        return socket_file.readline().rstrip(b'\r\n')

    def handle_error(self, socket_file):
        """Parse ``-<message>\\r\\n`` and return ``Error(message_bytes)``."""
        return Error(socket_file.readline().rstrip(b'\r\n'))

    def handle_integer(self, socket_file):
        """Parse ``:<digits>\\r\\n`` and return an ``int``."""
        return int(socket_file.readline().rstrip(b'\r\n'))

    def handle_string(self, socket_file):
        """Parse ``$<length>\\r\\n<data>\\r\\n``; a length of -1 yields ``None``."""
        length = int(socket_file.readline().rstrip(b'\r\n'))
        if length == -1:
            return None  # Special-case for NULLs.
        # Read the payload plus its trailing \r\n, then drop the terminator.
        return socket_file.read(length + 2)[:-2]

    def handle_array(self, socket_file):
        """Parse ``*<count>\\r\\n`` followed by ``count`` values; return a list."""
        num_elements = int(socket_file.readline().rstrip(b'\r\n'))
        return [self.handle_request(socket_file) for _ in range(num_elements)]

    def handle_dict(self, socket_file):
        """Parse ``%<count>\\r\\n`` followed by ``count`` key/value pairs."""
        num_items = int(socket_file.readline().rstrip(b'\r\n'))
        elements = [self.handle_request(socket_file) for _ in range(num_items * 2)]
        # Even positions are keys, odd positions are the matching values.
        return dict(zip(elements[::2], elements[1::2]))

    def write_response(self, socket_file, data):
        """Serialize ``data`` and write it to ``socket_file`` in one write.

        The whole reply is built in an in-memory buffer first, so an
        unsupported type raises before anything is sent.

        :raises CommandError: if ``data`` contains an unsupported type.
        """
        buf = BytesIO()
        self._write(buf, data)
        socket_file.write(buf.getvalue())
        socket_file.flush()

    def _write(self, buf, data):
        """Recursively append the wire encoding of ``data`` to ``buf``.

        This is the serialization side: Python values are turned into
        protocol bytes. ``str`` is sent as UTF-8 bulk bytes, so a round
        trip returns ``bytes``.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        if isinstance(data, bytes):
            buf.write(('$%s\r\n' % len(data)).encode())
            buf.write(data)
            buf.write(b'\r\n')
        elif isinstance(data, int):
            # %d rather than %s: bool is an int subclass and '%s' would emit
            # "True", which handle_integer cannot parse.
            buf.write((':%d\r\n' % data).encode())
        elif isinstance(data, Error):
            # Checked before tuple because Error is a namedtuple.
            message = data.message
            if not isinstance(message, bytes):
                message = str(message).encode('utf-8')
            # Error lines are CRLF-terminated, so embedded line breaks would
            # corrupt framing for the reader.
            message = message.replace(b'\r', b' ').replace(b'\n', b' ')
            buf.write(b'-' + message + b'\r\n')
        elif isinstance(data, (list, tuple)):
            buf.write(('*%s\r\n' % len(data)).encode())
            for item in data:
                self._write(buf, item)
        elif isinstance(data, dict):
            buf.write(('%%%s\r\n' % len(data)).encode())
            for key in data:
                self._write(buf, key)
                self._write(buf, data[key])
        elif data is None:
            buf.write(b'$-1\r\n')
        else:
            raise CommandError('unrecognized type: %s' % type(data))
class Server(object):
    """In-memory key-value server speaking the RESP-style protocol.

    The entire database is the ``_kv`` dict: a value is stored under a key
    and retrieved by that key. Connections are accepted by a gevent
    ``StreamServer`` and handled by greenlets from a pool of at most
    ``max_clients``.
    """

    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        # StreamServer binds (host, port) and spawns connection_handler for
        # each accepted connection using the pool above.
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool,
        )
        self._protocol = ProtocolHandler()
        self._kv = {}
        self._commands = self.get_commands()

    def get_commands(self):
        """Return the dispatch table mapping command names to bound methods."""
        return {
            'GET': self.get,
            'SET': self.set,
            'DELETE': self.delete,
            'FLUSH': self.flush,
            'MGET': self.mget,
            'MSET': self.mset,
        }

    def connection_handler(self, conn, address):
        """Serve requests from one client until it disconnects.

        :param conn: the accepted socket.
        :param address: the peer address; used only in log messages.
        """
        # Wrap the socket in a buffered binary file so the protocol handler
        # can use read()/readline().
        socket_file = conn.makefile('rwb')
        try:
            while True:
                try:
                    data = self._protocol.handle_request(socket_file)
                except (Disconnect, socket_error):
                    break  # client left or the connection was reset
                except CommandError as exc:
                    # The request could not be parsed, so the position in the
                    # byte stream is no longer trustworthy: report the error
                    # and drop the connection.
                    self._send(socket_file, Error(exc.args[0]))
                    break
                try:
                    resp = self.get_response(data)
                except CommandError as exc:
                    resp = Error(exc.args[0])  # bad command from the client
                except Exception as exc:
                    logging.getLogger(__name__).exception(
                        'Unhandled error serving %s', address)
                    resp = Error(str(exc))
                if not self._send(socket_file, resp):
                    break
        finally:
            # The makefile() wrapper holds its own reference to the socket;
            # close it so the descriptor is released when the socket closes.
            try:
                socket_file.close()
            except (socket_error, ValueError):
                pass

    def _send(self, socket_file, resp):
        """Write ``resp`` to the client; return False if the socket failed."""
        try:
            self._protocol.write_response(socket_file, resp)
        except socket_error:
            return False
        return True

    def get_response(self, data):
        """Execute the command contained in a parsed request.

        :param data: a list ``[command, *args]``, or a whitespace-separated
            string/bytes that is split into one.
        :returns: whatever the command method returns.
        :raises CommandError: if the request is not a list or string, is
            empty, or does not name a known command.
        """
        if not isinstance(data, list):
            try:
                data = data.split()
            except AttributeError:
                raise CommandError('Request must be list or simple string')
        if not data:
            raise CommandError('Missing command')
        command = data[0]
        if isinstance(command, bytes):
            # Undecodable bytes become U+FFFD and simply fail the lookup.
            command = command.decode('utf-8', 'replace')
        if not isinstance(command, str):
            raise CommandError('Unrecognized command: %s' % (command,))
        command = command.upper()
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % command)
        return self._commands[command](*data[1:])

    def get(self, key):
        """Return the value stored under ``key``, or None if absent."""
        return self._kv.get(key)

    def set(self, key, value):
        """Store ``value`` under ``key``; always returns 1."""
        self._kv[key] = value
        return 1

    def delete(self, key):
        """Remove ``key``; return 1 if it existed, else 0."""
        if key in self._kv:
            del self._kv[key]
            return 1
        return 0

    def flush(self):
        """Remove every key; return how many were removed."""
        kvlen = len(self._kv)
        self._kv.clear()
        return kvlen

    def mget(self, *keys):
        """Return a list of values for ``keys`` (None for missing keys)."""
        return [self._kv.get(key) for key in keys]

    def mset(self, *items):
        """Store alternating key/value arguments; return the number of pairs.

        :raises CommandError: if an odd number of arguments is given, which
            would otherwise silently drop the trailing key.
        """
        if len(items) % 2:
            raise CommandError('MSET requires an even number of arguments')
        pairs = list(zip(items[::2], items[1::2]))
        self._kv.update(pairs)
        return len(pairs)

    def run(self):
        """Start serving connections; blocks forever."""
        self._server.serve_forever()
class Client(object):
    """Blocking client for ``Server``: one request, then one reply, per call."""

    def __init__(self, host='127.0.0.1', port=31337):
        self._protocol = ProtocolHandler()
        self._host = host
        self._port = port
        self._connect()

    def execute(self, *args):
        """Send ``args`` as one array request and return the parsed reply.

        :returns: the reply; a bytes reply is decoded as UTF-8, while other
            types (int, list, dict, None) are returned as parsed.
        :raises CommandError: if the server replies with an error.
        """
        self._protocol.write_response(self._fh, args)
        resp = self._protocol.handle_request(self._fh)
        if isinstance(resp, Error):
            message = resp.message
            # Error replies are parsed as bytes; decode them so the exception
            # text is readable rather than a b'...' repr.
            if isinstance(message, bytes):
                message = message.decode('utf-8', 'replace')
            raise CommandError(message)
        if isinstance(resp, bytes):
            resp = resp.decode('utf-8')
        return resp

    def get(self, key):
        return self.execute('GET', key)

    def set(self, key, value):
        return self.execute('SET', key, value)

    def delete(self, key):
        return self.execute('DELETE', key)

    def flush(self):
        return self.execute('FLUSH')

    def mget(self, *keys):
        return self.execute('MGET', *keys)

    def mset(self, *items):
        return self.execute('MSET', *items)

    def _connect(self):
        """Open a TCP connection and wrap it in a buffered binary file."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((self._host, self._port))
            fh = sock.makefile('rwb')
        except BaseException:
            # Don't leak the descriptor when the connection attempt fails.
            sock.close()
            raise
        self._socket = sock
        self._fh = fh

    def disconnect(self):
        """Close the file wrapper and the socket, best-effort.

        Each resource is closed independently, so a failure closing the file
        wrapper (e.g. a flush to a dead peer) still lets the socket close.
        Safe to call more than once.
        """
        for resource in (getattr(self, '_fh', None), getattr(self, '_socket', None)):
            if resource is None:
                continue
            try:
                resource.close()
            except (socket_error, ValueError):
                pass

    def reconnect(self):
        """Drop the current connection and open a fresh one."""
        self.disconnect()
        self._connect()
if __name__ == '__main__':
    # Patch blocking stdlib calls so they yield to other gevent greenlets,
    # then start the server with its default host/port.
    from gevent import monkey
    monkey.patch_all()
    Server().run()