-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmehtftp.py
173 lines (141 loc) · 4.41 KB
/
mehtftp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#!/usr/bin/env python3
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE-examples file in the root directory of this source tree.
import argparse
import logging
import os
from fbtftp.base_handler import BaseHandler
from fbtftp.base_handler import ResponseData
from fbtftp.base_server import BaseServer
from jinja2 import FileSystemLoader
from jinja2.environment import Environment
class FileResponseData(ResponseData):
    """Response backed by a file on disk, streamed in binary mode."""

    def __init__(self, path):
        # Open first, then take the size from the open descriptor, so size()
        # always describes the same file that read() returns bytes from
        # (a separate os.stat(path) could race with the file being replaced).
        # A missing file still raises FileNotFoundError, now from open().
        self._reader = open(path, 'rb')
        try:
            self._size = os.fstat(self._reader.fileno()).st_size
        except OSError:
            # Don't leak the handle if fstat fails.
            self._reader.close()
            raise

    def read(self, n):
        """Return up to ``n`` bytes from the file (``b''`` at EOF)."""
        return self._reader.read(n)

    def size(self):
        """Return the file size in bytes, as seen when it was opened."""
        return self._size

    def close(self):
        """Close the underlying file handle."""
        self._reader.close()
class JinjaResponseData(ResponseData):
    """Response backed by a Jinja2 template that is rendered once, in memory.

    The template at ``path`` is rendered with ``attr`` (an empty context when
    omitted) and encoded with ``encoding``. ``read`` then serves the resulting
    bytes chunk by chunk.
    """

    def __init__(self, path, attr=None, encoding='utf-8'):
        # Stat first so a missing file raises FileNotFoundError (the same
        # error FileResponseData surfaces) instead of jinja2's
        # TemplateNotFound.
        os.stat(path)
        folder, name = os.path.split(path)
        env = Environment(loader=FileSystemLoader(folder))
        context = attr if attr is not None else {}
        # NOTE(review): which template variables should be supplied is not
        # visible in this file. With no attr, undefined variables render
        # according to jinja2's default Undefined handling.
        self._data = env.get_template(name).render(context).encode(encoding)
        self._offset = 0

    def read(self, n):
        """Return the next ``n`` bytes or fewer; ``b''`` once all are sent."""
        chunk = self._data[self._offset:self._offset + n]
        self._offset += len(chunk)
        return chunk

    def size(self):
        """Return the length in bytes of the rendered template."""
        return len(self._data)

    def close(self):
        """Release the rendered buffer; there is no file handle to close."""
        self._data = b''
        self._offset = 0
def create_config_from_template(template_path, attr=None):
    """Render the Jinja2 template stored at ``template_path``.

    Args:
        template_path: filesystem path of the template. Its directory becomes
            the loader's search path and its basename the template name.
        attr: mapping of template variables passed to ``render``. ``None``
            (the default) renders with an empty context.

    Returns:
        The rendered template as a ``str``.
    """
    # Use os.path explicitly: this module imports ``os``, not ``os.path`` as
    # a bare ``path`` name, so ``path.split`` raised NameError.
    folder, filename = os.path.split(template_path)
    env = Environment(loader=FileSystemLoader(folder))
    config = env.get_template(filename)
    return config.render(attr if attr is not None else {})
def print_session_stats(stats):
    """Log a per-transfer summary at INFO level.

    ``main`` passes this to ``StaticServer`` as the handler stats callback.
    ``stats`` must provide ``peer``, ``file_path``, ``error``, ``duration()``
    (seconds, logged as ms), ``packets_sent``, ``packets_acked``,
    ``bytes_sent``, ``options``, ``blksize``, ``retransmits`` and
    ``server_addr``. The port is read from ``peer[1]`` and ``server_addr[1]``.
    """
    # Arguments are handed to logging rather than %-formatted up front. The
    # text is built only if the record is emitted, and a tuple-valued field
    # can no longer be mistaken for a format-argument tuple.
    logging.info('Stats: for %r requesting %r', stats.peer, stats.file_path)
    logging.info('Error: %r', stats.error)
    logging.info('Time spent: %dms', stats.duration() * 1e3)
    logging.info('Packets sent: %d', stats.packets_sent)
    logging.info('Packets ACKed: %d', stats.packets_acked)
    logging.info('Bytes sent: %d', stats.bytes_sent)
    logging.info('Options: %r', stats.options)
    logging.info('Blksize: %r', stats.blksize)
    logging.info('Retransmits: %d', stats.retransmits)
    logging.info('Server port: %d', stats.server_addr[1])
    logging.info('Client port: %d', stats.peer[1])
def print_server_stats(stats):
    """Log server-wide counters for the last stats interval at INFO level.

    ``main`` passes this to ``StaticServer`` as the server stats callback
    (see the ServerStats class). ``stats`` must provide ``interval`` and
    ``get_and_reset_all_counters()``, which returns a mapping of counter
    names to integers.
    """
    # NOTE: read via get_and_reset_all_counters() so the counters used here
    # are reset and the next cycle starts fresh.
    counters = stats.get_and_reset_all_counters()
    logging.info('Server stats - every %d seconds', stats.interval)
    if 'process_count' in counters:
        logging.info(
            'Number of spawned TFTP workers in stats time frame : %d',
            counters['process_count'],
        )
class StaticHandler(BaseHandler):
    """Per-request TFTP handler that serves rendered templates under ``root``.

    ``StaticServer.get_handler`` creates one of these for each request.
    """

    def __init__(self, server_addr, peer, path, options, root, stats_callback):
        # Log through the logging module (configured in main) instead of a
        # bare print() to stdout.
        logging.debug(
            'DATA IN STATIC HANDLER: %s %s %s %s %s',
            server_addr, peer, path, options, root,
        )
        self._root = root
        super().__init__(server_addr, peer, path, options, stats_callback)

    def get_response_data(self):
        """Build the response for the requested path, confined to ``root``.

        Raises:
            FileNotFoundError: if the requested path resolves outside
                ``root``, for example via ``..`` segments or an absolute path.
                This keeps network peers from reading arbitrary host files.
        """
        root = os.path.abspath(self._root)
        # os.path.join discards ``root`` when the requested path is absolute,
        # and abspath collapses '..'. Compare the normalized target to root.
        target = os.path.abspath(os.path.join(root, self._path))
        if os.path.commonpath([root, target]) != root:
            # NOTE(review): how the base handler reports this to the client
            # is defined in fbtftp, not here.
            raise FileNotFoundError(
                'requested path escapes the TFTP root: %r' % self._path
            )
        return JinjaResponseData(target)
class StaticServer(BaseServer):
    """TFTP server that delegates every request to a ``StaticHandler``.

    ``root`` and ``handler_stats_callback`` are kept for the per-request
    handlers. The remaining arguments are forwarded unchanged to
    ``BaseServer``.
    """

    def __init__(self, address, port, retries, timeout, root,
                 handler_stats_callback, server_stats_callback=None):
        self._handler_stats_callback = handler_stats_callback
        self._root = root
        super().__init__(
            address, port, retries, timeout, server_stats_callback
        )

    def get_handler(self, server_addr, peer, path, options):
        """Create the handler for one request, wiring in root and callback."""
        handler = StaticHandler(
            server_addr=server_addr,
            peer=peer,
            path=path,
            options=options,
            root=self._root,
            stats_callback=self._handler_stats_callback,
        )
        return handler
def get_arguments(argv=None):
    """Parse the TFTP server's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``ip`` (str), ``port`` (int), ``retries``
        (int), ``timeout_s`` (int) and ``root`` (str) attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ip',
        type=str,
        default='::',
        help='IP address to bind to'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=69,
        help='port to bind to'
    )
    parser.add_argument(
        '--retries',
        type=int,
        default=5,
        help='number of per-packet retries'
    )
    parser.add_argument(
        '--timeout_s',
        type=int,
        default=2,
        help='timeout for packet retransmission'
    )
    parser.add_argument(
        '--root',
        type=str,
        default='',
        help='root of the static filesystem'
    )
    return parser.parse_args(argv)
def main():
    """Parse arguments, run the TFTP server, and close it on Ctrl-C."""
    args = get_arguments()
    # setLevel() alone is not enough. With no handler on the root logger,
    # records fall through to logging.lastResort, which only emits WARNING
    # and above, so the INFO stats lines were silently dropped. basicConfig
    # installs a stderr handler. The explicit setLevel keeps DEBUG even if a
    # handler already existed, in which case basicConfig is a no-op.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger().setLevel(logging.DEBUG)
    server = StaticServer(
        args.ip,
        args.port,
        args.retries,
        args.timeout_s,
        args.root,
        print_session_stats,
        print_server_stats,
    )
    try:
        server.run()
    except KeyboardInterrupt:
        server.close()


if __name__ == '__main__':
    main()