diff --git a/free_style/tf_play/batch_auto_train/gomoku_train_batch_host.py b/free_style/tf_play/batch_auto_train/gomoku_train_batch_host.py index c7473eb..6ccb9d9 100755 --- a/free_style/tf_play/batch_auto_train/gomoku_train_batch_host.py +++ b/free_style/tf_play/batch_auto_train/gomoku_train_batch_host.py @@ -342,7 +342,26 @@ def main(): worker_name = "worker_%03d" % i_worker os.chdir(worker_name) with tarfile.open("output.tar.gz") as tar: - tar.extractall() + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar) newblack_learndata = pickle.load(open('newblack.learndata')) print("%d new black learndata loaded from %s" % (len(newblack_learndata), worker_name)) update_learn_data(black_learndata, newblack_learndata) diff --git a/free_style/tf_play/batch_auto_train/gomoku_worker.py b/free_style/tf_play/batch_auto_train/gomoku_worker.py index cbbf381..e0ccdd4 100755 --- a/free_style/tf_play/batch_auto_train/gomoku_worker.py +++ b/free_style/tf_play/batch_auto_train/gomoku_worker.py @@ -232,7 +232,26 @@ def main(): import tarfile with tarfile.open('input.tar.gz') as tar: - tar.extractall() + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar) import construct_dnn model = construct_dnn.construct_dnn() model.load('tf_model/tf_model')