From b9b4819a7b26e19f409144630be3039f71c33ca2 Mon Sep 17 00:00:00 2001 From: RishabGoel Date: Fri, 14 May 2021 13:53:14 -0400 Subject: [PATCH 1/6] initial commit for error injection --- .../__pycache__/add_errors.cpython-36.pyc | Bin 0 -> 4441 bytes .../__pycache__/err_expr_utils.cpython-36.pyc | Bin 0 -> 1293 bytes .../__pycache__/misc_utils.cpython-36.pyc | Bin 0 -> 2032 bytes error_genration/add_errors.py | 175 ++++++++++++++++++ error_genration/calculate_stats.py | 119 ++++++++++++ error_genration/err_expr_utils.py | 38 ++++ error_genration/misc_utils.py | 52 ++++++ 7 files changed, 384 insertions(+) create mode 100644 error_genration/__pycache__/add_errors.cpython-36.pyc create mode 100644 error_genration/__pycache__/err_expr_utils.cpython-36.pyc create mode 100644 error_genration/__pycache__/misc_utils.cpython-36.pyc create mode 100644 error_genration/add_errors.py create mode 100644 error_genration/calculate_stats.py create mode 100644 error_genration/err_expr_utils.py create mode 100644 error_genration/misc_utils.py diff --git a/error_genration/__pycache__/add_errors.cpython-36.pyc b/error_genration/__pycache__/add_errors.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ecb0c2b2ec265dddd7e5484b580f3e546aed217 GIT binary patch literal 4441 zcmai1TaO$^6|SnjOwV=K>-BoQxh=vY8}G>Vb#}N+NLyFFX(>;+0nfgpw`~^ZX$x{|qB@el6i%{6o30+%yp>NCDeh}7C6KP}E z44cDN*c!IO_OKInIK9D1mxjyXGU~4M-eut_xgl#Zc$bGOvMw9wIW3#Ag>qH4We4RM z`H)#Y08U7VfmbY0Wg zOLU~-Oy6GcB!+iFnUig%;EJW02E`EaC?q8PlLJUtl2ds((!eePb8Df_M4 zv6PhximlNdvrNqgbwL{zE(8c1m8IDfh2U$PVLUMoyg-prj0Z66C8nWsct$TB18vT5kacTe zgPg*msW4tVij_7Z&B7JHaIlv};}MKcPPLbW0s1OH)_VCk(}@b}7N`VLi^NKFXvYq%Rtx!#BC$8+@6&VvTqBI$E22g}Yomin$9^5$K4pEUJEo zO3C4T7FQZ}+ZQv`i4;(GsCD40S2U18^?gsmO=ny>2TWa-0rZbq(jyCGFPA{x6g{|ntG?IK~_CNO|CI&O`XGQuq>A~>-O_ES$D7R9BmBxn^Ba-!z79}pL=fe zi6|PzgDi@kxM+CeF44%<*Np=Pbbo{Nf_i~ZEGhjvqYAu-fPF7!n=zN2*w*QO9zH= zMN7CSS73?;%%r*)?GZrfM@FEtT0^{r%4A(O;h4))uwY4Q_esHr{te5AGgdrKZO!)u z@g70Ir*)MRL`8)9c-u`Vyw7FO#58v@6upRNEr44biPZsag$2MyOpAyq3rpVwGp6 zF(^RQI?j4Y=n*#;Ve`d&ID{#7&s0?D2_g`CK#KzC zm0URNj8wj-;$c+eW7Pw-9JkX<5%;S{NIXGP*ZN=!iRh~kdUcfMJsW9G=*TQvTvlP% z0`q%VTp(ouFiBAX)V9ohJRP*5w69~t)aO2s$OTbQP@=D}`ui6uggVJ7)O>{UV+tq& z8UzlB8m7*j54E%JBGB+DB&dRF6l6St`4x?~rtPv>wxmz?XtCn=Xtf_`wPYQ^;}hA0Uo^{F8I-NEUbbf*XE#{(q|0>c zprQW8*%bNIQ46_KtL%Kl`mhwx(D`ZSCdnHd$MDg6c<}r5 z`7T!vp<)ed*u675x!$xE*^D>$@?36O`lfJG#T-X{>$t8i9@n~UQ&Y((jeCi5sPh>T zUmt5M&PXeD6GxF zR6(u1fEGNhP+y}rT8P!D)`STC0RoLu;mRBh7C%)~Qw^Hcf-wAZQEKUU{(Os@m{x2; zFun>7B=#d;Z6Qc?__82hT}9rr&eijnvtY!}c^64a-69<_28Ml3MoL94iF3MeZ+uTtUi1PLC+~&^a7oW-p*of z?PI(V^WgiG%-94Ym?PuJK}M&(f#U%Z#70NX#m446{?ySDN|5R>)|)mR?7G(Vk&$My zxINIf?B^OkMmB(O1=g(6=hob0Rm-<)3*%DL*!V-W&`aapBsHELKI!#mPgo$jt!NjO z<9wC#BNDmR3ad6LSQBevl`A5Ldtz&Q!I}gHGBu&IID*g&Svj;Y6Cc19dq~{Dd9H?5 zXY4nxeZbyh@A1AMUU)YEFIsLTzn@&d9FKe-mOD?j*!Rinr?qJSo7KuH%k>?TUfr3+{fdy+She-#fy?17HnT%GdwK}wpxkC2HRR{T~Ry4S08zAHk% z81D`T+I}S69k!kcn?NY;Ek%@#hv>pxA?b}ZGQCq!R&+dQ4N<9>vaoxNOX;A{l=WYp z(E+N{96j5^rDhB|&X5Vrd3AN(`SES4b{G9dE# zb-wISP5lUcJsO}NVavrshqg3P#pEdrS;F#C{6f%$1Xw|aGe^SFrR`h!Fxegs(s=v$ zHnwGNW5_FdNIr-XFW5AEZWCP7>_rv!zALd6zK3?L;NXO(>|FTk;vTE6)9~vg=zew| z0LG=f^VjIB?rb~|>IPaSIL=5x1@zr0=EU#{wH`7eFG9b%OdGeCkM4xfvD~Hf$kIOY zB4ns{)DJK{T)D5_t^OGIW+Y6tc?Fj-%-FEbA1$t#Tf;%oi^h787FB{ls;$0~QlIeA zlk?H8t`gN68(=8VR{^F=HA_~I%de?Uc6Sr#C{rC~0|X>ys8Tzj>z eGBil1A=VKA02v3C>Jg{w1HOQ6sk&a?u83lsVtk=*)HrKwo{W&yh2XQI3=VNmiDv9J2Rg7W@c|?rTOc})!UyAA-~DOa^ZXcpLz|z z38yK^Xh1U-u*?Y@JX69Mcg~1#xjUvW$(Zumwd~F0brRHgo%=BhJl^2VGZNJK5?_Yh z=PR#ByEOstr0tqkEOaQNoEKS`BD-iT_VA)hg==sMs$BN{SACmf)U+PF{`&20eR;82kZXoQLZvP(CnBTp0V zK=IkAUe)EBswNMRX3%i}YgXU8*7x(x|NL{UFMs6cU9kRgLj&LOtPgg1+ZyfW|3-f0 zy`z;shL)^6C~y66N}n6rl$exU1(?##lx_nlNVoB#vJ^RQHv;!0l20w}BBu77&dz&g zIZ@#$mNWzimWrV+huYRNX!gXQkRnvWo^jJ87q>_@TiB`B#?S!RCPf~qB+er}lmd&R zS^$LJrA_M6eb$9t-iMn$UgN?d&D;4I2)6VT_$HcZ2gfg==M-83Mv&Lwo@rpcf*K9z zm2*yDXTXPdII(%VDmfJs^!ksPwkPqzw19~u>ZKygqf97?@0vzF%y2AYVTrTRwVhyj zHjO+48}(OGAf-yKqxYK)XEx>}!f=#A0s%ccA9p}mVUH2Wy@mNdg7aa!VK#3~(oUAB x{%ne7q1Y^JkhNKW55ZDyAs~m=(Lx+HBJ616ce8>IQ?ZZDVISXMkE!qY%|9s(HM{@- literal 0 HcmV?d00001 diff --git a/error_genration/__pycache__/misc_utils.cpython-36.pyc b/error_genration/__pycache__/misc_utils.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa0453c5ef464253e489137c17eccdce297191b0 GIT binary patch literal 2032 zcmb7EOK;mo5Z+yq67{g6IE|6Mfub)Ku@$tJ9tqAol0ZVIFGF6JCcd0lw z3YuRV$!4&oSp9d%;1T* zVsl~bh`9jD##?{_m>Otz2;c*VF`Ki;Ujh)|O@x)sH3KvOlkVpHl{j^e*qmzZ zaT$L-DAn*}x72ZFMoJHKVou`YXq>BbFwCcgR>mYx)NWG5<18B}t#ciXRi@)I$+N*U zF~ew9CaKvg&fm26)5Mf??SLFLwZ|91Se4N*mui%2HB`oeF?;)IK8#ayfWhik*f)6! zb$NzE=jBPBt;0*YkZ3h~?_prs+g`C(^H%d@{R|vUnV=~xG@%KEZlVi)n6UhSo8E@G z6+ya1H`)+gw51Pg7^IU_9?{+=44Xs#C8GYyIqg!oEQ}QvGHtQc_SfO453Nn!BDFPb zfo~nCn#2wT+ADODm6el%{k2zBLp{a5sTy%ns7wZ~_Ms!(T4e%RGU3|Mn)|T3xm{xv z_M+|(P{LE){m-yb49q(ajNjv3?(#P8p})s_qQ~`ZtiLAi)`U@cl0)C2S;8o<=O`Y| zN-tA6|qUPvPkj@L|C9-GECw##sVu6PM1 z!kT3Xo}DQciA((%i7!apAfvy5WrR?b#KHEyeb0A&r|tX$tNgJ! literal 0 HcmV?d00001 diff --git a/error_genration/add_errors.py b/error_genration/add_errors.py new file mode 100644 index 00000000..f631c0fc --- /dev/null +++ b/error_genration/add_errors.py @@ -0,0 +1,175 @@ +import os +import random +import copy +import redbaron as rb +import numpy as np + + +import concurrent.futures as cf + +from error_genration.misc_utils import ( + get_random_int, + load_dataa, + get_codeforeces_paths, + write_csv, +) +from error_genration.err_expr_utils import zerro_error_perturbation + + +def add_perturbation(perturb_node, expr_ass, expr_ass_line, expr_err, expr_err_line): + # import pdb;pdb.set_trace() + print("yo") + perturb_node.at(expr_ass_line).insert_before(expr_ass) + print("yo1") + perturb_node.at(expr_err_line).insert_before(expr_err) + + +def get_ass_expr_lines(code_lines, apart=0.75): + ln = len(code_lines) + assert ln > 0 + assign_upper_range = max(1, int((1 - apart) * ln)) + if assign_upper_range == 1: + ass_line = 1 + else: + ass_line = get_random_int(1, assign_upper_range) + + for counter in range(20): + if ( + to_include(code_lines[ass_line - 1]) + or ass_line == ln + or assign_upper_range == 1 + ): + break + ass_line = get_random_int(1, int((1 - apart) * ln)) + if not to_include(code_lines[ass_line - 1]): + return -1, -1 + + expr_line = min(ln, ass_line + int(apart * ln)) + for counter in range(20): + if to_include(code_lines[expr_line - 1]) or expr_line == ln: + break + expr_line = get_random_int(min(ln, ass_line + int(apart * ln)), ln) + if not to_include(code_lines[expr_line - 1]): + return -1, -1 + return ass_line, expr_line + + +def get_parent_node(line_to_perturb, red): + selected_node = red.at(line_to_perturb) + print(selected_node.dumps()) + if "if __name__ == '__main__':" in selected_node.parent.dumps(): + if len(selected_node.dumps().split("\n")) < 5: + return None, False + else: + return selected_node.parent, True + else: + out_node = selected_node + while out_node != red and "def" not in out_node.dumps(): + if out_node.dumps() == out_node.parent.dumps(): + break + out_node = out_node.parent + return out_node, True + + +def to_include(line): + for token in ["if", "while", "for", "def", "class", "import", "else"]: + if token in line: + return False + if not line.strip(): + return False + line_rb = rb.RedBaron(line.strip()) + if isinstance(line_rb[0], rb.nodes.CommentNode) or isinstance( + line_rb[0], rb.nodes.EndlNode + ): + return False + return True + + +def get_perturb_node(red, program_source, program_ln): + perturb_node = None + for counter in range(20): + line_to_perturb = get_random_int(1, program_ln) + if to_include(program_source[line_to_perturb - 1]): + break + if to_include(program_source[line_to_perturb - 1]): + perturb_node, found_correct_location = get_parent_node(line_to_perturb, red) + if not perturb_node: + perturb_node = red + return perturb_node + + +def perturb_program(program_fp, suffx="perturbed"): + output_fp = program_fp.replace(".txt", f"_{suffx}.txt") + + program = load_dataa(program_fp).strip() + + try: + red = rb.RedBaron(program) + program_lines = program.split("\n") + program_ln = len(program_lines) + perturb_node = get_perturb_node(red, program_lines, program_ln) + + [expr_ass, expr_err], is_err = zerro_error_perturbation() + perturb_node_lines = perturb_node.dumps().split("\n") + expr_ass_line, expr_err_line = get_ass_expr_lines( + perturb_node_lines, apart=0.75 + ) + + if ( + not perturb_node_lines[expr_ass_line - 1] + or not perturb_node_lines[expr_err_line - 1] + or expr_ass_line == -1 + ): + return "", -1, "Not found a good line" + add_perturbation(perturb_node, expr_ass, expr_ass_line, expr_err, expr_err_line) + # write_csv(program.dumps(), output_fp) + except Exception as e: + return "", -1, f"{e}" + return output_fp, is_err, None + + +def perturb_program_wrapper(paths): + + output_paths = [] + errors = [] + for path in paths: + print(path) + if path.endswith(".txt") or path.endswith(".py"): + out_path, label, error = perturb_program(path, suffx="perturbed") + if error: + errors.append(f"{path}:\n {error}") + else: + output_paths.append(f"{out_path},{label}") + else: + errors.append(f"{path}:\n format error") + return (output_paths, errors) + + +def concurrent_program_perturbation(paths, num_processes, out_path="./"): + per_process_paths = np.array_split(paths, num_processes) + output_paths, errors = [], [] + with cf.ProcessPoolExecutor() as executor: + results = [ + executor.submit(perturb_program_wrapper, per_process_paths[process_num]) + for process_num in range(num_processes) + ] + for completed in cf.as_completed(results): + res_out_paths, res_errs = completed.result() + output_paths.extend(res_out_paths) + errors.append(res_errs) + # write_csv("\n".join(output_paths), f"{out_path}/label_file.csv") + # write_csv("\n".join(errors), f"{out_path}/error_file.csv") + + +def main(): + code_forces_paths = get_codeforeces_paths( + "/home/mila/r/rishab.goel/description2code_current/codeforces" + ) + # concurrent_program_perturbation(code_forces_paths[:4], 3) + res_out_paths, res_errs = perturb_program_wrapper(code_forces_paths[:100]) + # import pdb;pdb.set_trace() + + +if __name__ == "__main__": + # data = load_dataa("/home/mila/r/rishab.goel/description2code_current/codechef/easy/ACBALL/solutions_python/10211792.txt") + main() diff --git a/error_genration/calculate_stats.py b/error_genration/calculate_stats.py new file mode 100644 index 00000000..40e7cb8b --- /dev/null +++ b/error_genration/calculate_stats.py @@ -0,0 +1,119 @@ +import redbaron as rb +import os +from collections import defaultdict +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import json +import math + + +def load_file(file_name): + with open(file_name, "r") as file: + code = file.read() + return code + + +def get_codeforeces_paths(base_path): + problem_paths = [ + os.path.join(base_path, problem_name_dir) + for problem_name_dir in os.listdir(base_path) + if os.path.isdir(os.path.join(base_path, problem_name_dir)) + ] + print(len(problem_paths)) + solution_paths = [] + for problem_path in problem_paths: + solutions_path = os.path.join(problem_path, "solutions_python") + if os.path.exists(solutions_path): + solutions_path = [ + os.path.join(solutions_path, sol_name) + for sol_name in os.listdir(solutions_path) + ] + solution_paths.append(solutions_path) + + solution_paths = [sol_path for path in solution_paths for sol_path in path] + return solution_paths + + +def ceil(x): + return int(math.ceil(x / 10.0)) * 10 + + +def get_depth_stats_wrapper(files, block_types=["def", "class"]): + stats_dict = {block_type: defaultdict(int) for block_type in block_types} + stats_dict["len"] = defaultdict(int) + err_files = [] + print(len(files)) + for file in files: + code = load_file(file) + code_ln = len(code.strip().split("\n")) + stats_dict["len"][ceil(code_ln)] += 1 + is_err = get_depth_stats(code, stats_dict, block_types) + if is_err: + err_files.append(file) + return stats_dict, err_files + + +def get_depth_stats(code, stats_dict, block_types): + try: + red = rb.RedBaron(code) + except Exception as e: + return True + for block_type in block_types: + curr_max = 0 + for block in red.find_all(block_type): + depth = 0 + while block != red: + depth += 1 + # while block.dumps()==block.parent.dumps(): + # import pdb;pdb.set_trace() + block = block.parent + curr_max = max(curr_max, depth) + stats_dict[block_type][curr_max] += 1 + return False + + +def plot_data(x, y, xlabel, ylabel, name): + plt.clf() + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.bar(x, y) + plt.savefig(name + ".pdf") + # plt.show() + + +def save_csv(data, name): + with open(name, "w") as file: + file.write("\n".join(data)) + + +def save_json(data, name): + with open(name, "w") as file: + json.dump(data, file) + + +def plot_wrapper(data): + for key, value in data.items(): + x, y = zip(*value.items()) + plot_data(x, y, key, "Count", key) + + +def get_codechef_paths(base_path): + pass + + +def main(): + paths = get_codeforeces_paths( + "/home/mila/r/rishab.goel/description2code_current/codeforces" + ) + codechef_paths = get_codechef_paths( + "/home/mila/r/rishab.goel/description2code_current/codechef" + ) + data, err_files = get_depth_stats_wrapper(paths) + save_csv(err_files, "error.csv") + save_csv(data, "calculated_stats.json") + plot_wrapper(data) + + +main() diff --git a/error_genration/err_expr_utils.py b/error_genration/err_expr_utils.py new file mode 100644 index 00000000..d642e620 --- /dev/null +++ b/error_genration/err_expr_utils.py @@ -0,0 +1,38 @@ +from error_genration.misc_utils import ( + get_random_int, + get_random_list_sample, + get_random_float, + get_random_int, +) + +variable_names = [f"tmp{i}" for i in range(10)] + [chr(ord("a") + i) for i in range(26)] +num_range = [1, 100] + + +def get_zerro_expression_signature(var1, var2, val1, val2, val3, is_zerro_err): + output_expr = [f"{var1}={val1}\n"] + before_sub = get_random_int(0, 1) + if before_sub: + line = ( + f"{var2}={val2}/{var1}-{val1}\n" + if is_zerro_err + else f"{var2}={val2}/{var1}-{val3}\n" + ) + else: + line = ( + f"{var2}={val2}/{val1}-{var1}\n" + if is_zerro_err + else f"{var2}={val2}/{val3}-{var1}\n" + ) + output_expr.append(line) + return output_expr + + +def zerro_error_perturbation(): + sampled_vars = get_random_list_sample(variable_names, 2) + samples_vals = get_random_float(*num_range, size=3) + is_zerro_err = get_random_int(0, 1) + return ( + get_zerro_expression_signature(*sampled_vars, *samples_vals, is_zerro_err), + is_zerro_err, + ) diff --git a/error_genration/misc_utils.py b/error_genration/misc_utils.py new file mode 100644 index 00000000..fef5d99e --- /dev/null +++ b/error_genration/misc_utils.py @@ -0,0 +1,52 @@ +import random +import numpy as np +import os + + +def get_codeforeces_paths(base_path): + problem_paths = [ + os.path.join(base_path, problem_name_dir) + for problem_name_dir in os.listdir(base_path) + if os.path.isdir(os.path.join(base_path, problem_name_dir)) + ] + print(len(problem_paths)) + solution_paths = [] + for problem_path in problem_paths: + solutions_path = os.path.join(problem_path, "solutions_python") + if os.path.exists(solutions_path): + solutions_path = [ + os.path.join(solutions_path, sol_name) + for sol_name in os.listdir(solutions_path) + ] + solution_paths.append(solutions_path) + + solution_paths = [sol_path for path in solution_paths for sol_path in path] + return solution_paths + + +def set_seeds(seed=10): + random.seed(seed) + np.random.seed(seed) + + +def load_dataa(fp): + with open(fp, "r", encoding="utf-8") as file: + data = file.read().strip() + return data + + +def write_csv(data, fp): + with open(fp, "w") as file: + file.write(data) + + +def get_random_int(lower_limit, upper_limit): + return random.randint(lower_limit, upper_limit) + + +def get_random_float(lower_limit, upper_limit, size=None): + return np.random.uniform(lower_limit, upper_limit, size=size) + + +def get_random_list_sample(lst, num_samples): + return random.sample(lst, num_samples) From c27c35015ed8a72d939f9d7699ab4567607293ee Mon Sep 17 00:00:00 2001 From: RishabGoel Date: Fri, 14 May 2021 13:58:50 -0400 Subject: [PATCH 2/6] added licence --- error_genration/add_errors.py | 17 ++++++++++++++--- error_genration/calculate_stats.py | 14 ++++++++++++++ error_genration/err_expr_utils.py | 14 ++++++++++++++ error_genration/misc_utils.py | 14 ++++++++++++++ 4 files changed, 56 insertions(+), 3 deletions(-) diff --git a/error_genration/add_errors.py b/error_genration/add_errors.py index f631c0fc..2706f79c 100644 --- a/error_genration/add_errors.py +++ b/error_genration/add_errors.py @@ -1,3 +1,17 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import random import copy @@ -17,10 +31,7 @@ def add_perturbation(perturb_node, expr_ass, expr_ass_line, expr_err, expr_err_line): - # import pdb;pdb.set_trace() - print("yo") perturb_node.at(expr_ass_line).insert_before(expr_ass) - print("yo1") perturb_node.at(expr_err_line).insert_before(expr_err) diff --git a/error_genration/calculate_stats.py b/error_genration/calculate_stats.py index 40e7cb8b..99a121e2 100644 --- a/error_genration/calculate_stats.py +++ b/error_genration/calculate_stats.py @@ -1,3 +1,17 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import redbaron as rb import os from collections import defaultdict diff --git a/error_genration/err_expr_utils.py b/error_genration/err_expr_utils.py index d642e620..0e636458 100644 --- a/error_genration/err_expr_utils.py +++ b/error_genration/err_expr_utils.py @@ -1,3 +1,17 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from error_genration.misc_utils import ( get_random_int, get_random_list_sample, diff --git a/error_genration/misc_utils.py b/error_genration/misc_utils.py index fef5d99e..59380c1c 100644 --- a/error_genration/misc_utils.py +++ b/error_genration/misc_utils.py @@ -1,3 +1,17 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random import numpy as np import os From 7b4db670d0e6132abac21f168ec1a8eb567256b3 Mon Sep 17 00:00:00 2001 From: Rishab Goel Date: Fri, 4 Jun 2021 14:59:17 -0400 Subject: [PATCH 3/6] added trace based 0 div err --- error_generation/add_code.py | 42 ++++ error_generation/get_trace.py | 56 ++++++ error_generation/main.py | 35 ++++ .../misc_utils.py | 19 +- error_generation/trace_code.py | 28 +++ .../__pycache__/add_errors.cpython-36.pyc | Bin 4441 -> 0 bytes .../__pycache__/err_expr_utils.cpython-36.pyc | Bin 1293 -> 0 bytes .../__pycache__/misc_utils.cpython-36.pyc | Bin 2032 -> 0 bytes error_genration/add_errors.py | 186 ------------------ error_genration/calculate_stats.py | 133 ------------- error_genration/err_expr_utils.py | 52 ----- 11 files changed, 176 insertions(+), 375 deletions(-) create mode 100644 error_generation/add_code.py create mode 100644 error_generation/get_trace.py create mode 100644 error_generation/main.py rename {error_genration => error_generation}/misc_utils.py (74%) create mode 100644 error_generation/trace_code.py delete mode 100644 error_genration/__pycache__/add_errors.cpython-36.pyc delete mode 100644 error_genration/__pycache__/err_expr_utils.cpython-36.pyc delete mode 100644 error_genration/__pycache__/misc_utils.cpython-36.pyc delete mode 100644 error_genration/add_errors.py delete mode 100644 error_genration/calculate_stats.py delete mode 100644 error_genration/err_expr_utils.py diff --git a/error_generation/add_code.py b/error_generation/add_code.py new file mode 100644 index 00000000..aa8b4797 --- /dev/null +++ b/error_generation/add_code.py @@ -0,0 +1,42 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import redbaron as rb +from misc_utils import get_random_list_sample, load_json, load_data, write_csv + +def get_perturb_line_step(code_trace): + perturb_line = get_random_list_sample(code_trace.keys(), 1)[0] + perturb_step = get_random_list_sample(code_trace[perturb_line],1)[0] + # import pdb;pdb.set_trace() + perturb_var = get_random_list_sample(perturb_step.keys(),1)[0] + perturb_val = perturb_step[perturb_var] + return int(perturb_line), perturb_var, int(perturb_val) + +def get_perturb_expression(perturb_var, perturb_val): + return 'tmp1 = 1/'+str(perturb_val)+ '-'+ perturb_var, True + +def perturb_program(red, code_trace): + perturb_line, perturb_var, perturb_val = get_perturb_line_step(code_trace) + perturb_expression, is_err_present = get_perturb_expression(perturb_var, perturb_val) + # print perturb_expression, perturb_line + # import pdb;pdb.set_trace() + red.at(perturb_line).insert_after(perturb_expression) + +def add_error(org_code_fp, code_trace_fp, suffx): + code_trace = load_json(code_trace_fp) + err_code_fp = org_code_fp.replace(".txt", "_"+suffx+".txt") + program = load_data(org_code_fp).strip() + red = rb.RedBaron(program) + _ = perturb_program(red, code_trace) + write_csv(red.dumps(), err_code_fp) \ No newline at end of file diff --git a/error_generation/get_trace.py b/error_generation/get_trace.py new file mode 100644 index 00000000..3369a2d9 --- /dev/null +++ b/error_generation/get_trace.py @@ -0,0 +1,56 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import sys +from collections import defaultdict +import subprocess + +def postprocess_and_save(json_fp, offset, processed_suffix): + data = json.load(open(json_fp,"rb")) + processed_data = {} + for key, val in data.items(): + val = [v for v in val if v] + if val: + processed_data[int(key)-offset] = val + out_path=json_fp.replace(".json", "_"+str(processed_suffix)+".json") + open(out_path, 'w').write(json.dumps(processed_data)) + + +def run_for_errors(python_filepath, data_path, trace_path, stdin_file, stdout_file, stderr_file, processed_suffix = 'processed'): + # Assumes the input is stdin when called. + trace_source = open(trace_path, 'r').read() + python_source = open(python_filepath, 'r').read() + python_source = python_source.replace('__name__ == "__main__"', 'True') + python_source = python_source.replace("__name__ == '__main__'", 'True') + python_source = ( + 'import json\n' + + 'import sys\n' + + 'def main__errorchecker__():\n' + + '\n'.join(' ' + line for line in python_source.split('\n')) + + '\n' + + trace_source + + '\nsys.settrace(trace_calls)\n' + + 'main__errorchecker__()\n' + + 'print "yo"\n' + + 'open("' + data_path + '","w").write(json.dumps(data, indent=4, sort_keys=True))\n' + ) + try: + subprocess_call = subprocess.check_call(['python', '-c', python_source], stdin=open(stdin_file, 'rb'), stdout=open(stdout_file, 'wb'), stderr=open(stderr_file, 'wb')) + except Exception as e: + return False + postprocess_and_save(data_path, 3, processed_suffix) + return True + diff --git a/error_generation/main.py b/error_generation/main.py new file mode 100644 index 00000000..5359c9cc --- /dev/null +++ b/error_generation/main.py @@ -0,0 +1,35 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from misc_utils import get_codeforeces_paths +from get_trace import run_for_errors +from add_code import add_error + +def main(base_path, trace_code_path, process_suffix="processed"): + code_inp_data_paths = get_codeforeces_paths(base_path) + for code_path, inp_paths in code_inp_data_paths: + for idx, inp_path in enumerate(inp_paths): + err_path = inp_path.replace(".txt", "_error.txt") + out_path = inp_path.replace(".txt", "_out.txt") + out_code_path = code_path.replace(".txt", "_"+str(idx)+"_perturbed.txt") + data_trace_path = code_path.replace(".txt", "_"+str(idx)+"_trace.json") + trace_successful = run_for_errors(code_path, data_trace_path, trace_code_path, inp_path, out_path, err_path, process_suffix) + if trace_successful: + data_trace_path = data_trace_path.replace(".json", "_"+process_suffix+".json") + add_error(code_path, data_trace_path, "zero_err") + + +main("/Users/rishabgoel/Downloads/description2code_current/codeforces", "trace_code.py") \ No newline at end of file diff --git a/error_genration/misc_utils.py b/error_generation/misc_utils.py similarity index 74% rename from error_genration/misc_utils.py rename to error_generation/misc_utils.py index 59380c1c..7b35ebff 100644 --- a/error_genration/misc_utils.py +++ b/error_generation/misc_utils.py @@ -15,6 +15,12 @@ import random import numpy as np import os +import json + +def get_codeforces_inp_data_paths(base_path): + inp_data_directory = os.path.join(base_path, "samples") + inp_data_paths = [inp_data_directory+"/"+path for path in os.listdir(inp_data_directory) if "input" in path] + return inp_data_paths def get_codeforeces_paths(base_path): @@ -27,9 +33,11 @@ def get_codeforeces_paths(base_path): solution_paths = [] for problem_path in problem_paths: solutions_path = os.path.join(problem_path, "solutions_python") + inp_data_paths = get_codeforces_inp_data_paths(problem_path) + # import pdb;pdb.set_trace() if os.path.exists(solutions_path): solutions_path = [ - os.path.join(solutions_path, sol_name) + (os.path.join(solutions_path, sol_name), inp_data_paths) for sol_name in os.listdir(solutions_path) ] solution_paths.append(solutions_path) @@ -43,11 +51,14 @@ def set_seeds(seed=10): np.random.seed(seed) -def load_dataa(fp): - with open(fp, "r", encoding="utf-8") as file: +def load_data(fp): + with open(fp, "r") as file: data = file.read().strip() return data +def load_json(fp): + with open(fp, "r") as file: + return json.load(file) def write_csv(data, fp): with open(fp, "w") as file: @@ -63,4 +74,4 @@ def get_random_float(lower_limit, upper_limit, size=None): def get_random_list_sample(lst, num_samples): - return random.sample(lst, num_samples) + return random.sample(lst, num_samples) \ No newline at end of file diff --git a/error_generation/trace_code.py b/error_generation/trace_code.py new file mode 100644 index 00000000..f5632285 --- /dev/null +++ b/error_generation/trace_code.py @@ -0,0 +1,28 @@ +from collections import defaultdict + +data = defaultdict(list) + +def trace_lines(frame, event, arg): + if event != 'line': + return + co = frame.f_code + func_name = co.co_name + line_no = frame.f_lineno + filename = co.co_filename + if filename=="": + locals_dict = {} + for key, value in frame.f_locals.items(): + try: + json.dumps(value) + locals_dict[key] = value + except Exception as e: + _ = "" + data[line_no].append(locals_dict) + + +def trace_calls(frame, event, arg): + if event!="call": + return + co = frame.f_code + func_name = co.co_name + return trace_lines diff --git a/error_genration/__pycache__/add_errors.cpython-36.pyc b/error_genration/__pycache__/add_errors.cpython-36.pyc deleted file mode 100644 index 2ecb0c2b2ec265dddd7e5484b580f3e546aed217..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4441 zcmai1TaO$^6|SnjOwV=K>-BoQxh=vY8}G>Vb#}N+NLyFFX(>;+0nfgpw`~^ZX$x{|qB@el6i%{6o30+%yp>NCDeh}7C6KP}E z44cDN*c!IO_OKInIK9D1mxjyXGU~4M-eut_xgl#Zc$bGOvMw9wIW3#Ag>qH4We4RM z`H)#Y08U7VfmbY0Wg zOLU~-Oy6GcB!+iFnUig%;EJW02E`EaC?q8PlLJUtl2ds((!eePb8Df_M4 zv6PhximlNdvrNqgbwL{zE(8c1m8IDfh2U$PVLUMoyg-prj0Z66C8nWsct$TB18vT5kacTe zgPg*msW4tVij_7Z&B7JHaIlv};}MKcPPLbW0s1OH)_VCk(}@b}7N`VLi^NKFXvYq%Rtx!#BC$8+@6&VvTqBI$E22g}Yomin$9^5$K4pEUJEo zO3C4T7FQZ}+ZQv`i4;(GsCD40S2U18^?gsmO=ny>2TWa-0rZbq(jyCGFPA{x6g{|ntG?IK~_CNO|CI&O`XGQuq>A~>-O_ES$D7R9BmBxn^Ba-!z79}pL=fe zi6|PzgDi@kxM+CeF44%<*Np=Pbbo{Nf_i~ZEGhjvqYAu-fPF7!n=zN2*w*QO9zH= zMN7CSS73?;%%r*)?GZrfM@FEtT0^{r%4A(O;h4))uwY4Q_esHr{te5AGgdrKZO!)u z@g70Ir*)MRL`8)9c-u`Vyw7FO#58v@6upRNEr44biPZsag$2MyOpAyq3rpVwGp6 zF(^RQI?j4Y=n*#;Ve`d&ID{#7&s0?D2_g`CK#KzC zm0URNj8wj-;$c+eW7Pw-9JkX<5%;S{NIXGP*ZN=!iRh~kdUcfMJsW9G=*TQvTvlP% z0`q%VTp(ouFiBAX)V9ohJRP*5w69~t)aO2s$OTbQP@=D}`ui6uggVJ7)O>{UV+tq& z8UzlB8m7*j54E%JBGB+DB&dRF6l6St`4x?~rtPv>wxmz?XtCn=Xtf_`wPYQ^;}hA0Uo^{F8I-NEUbbf*XE#{(q|0>c zprQW8*%bNIQ46_KtL%Kl`mhwx(D`ZSCdnHd$MDg6c<}r5 z`7T!vp<)ed*u675x!$xE*^D>$@?36O`lfJG#T-X{>$t8i9@n~UQ&Y((jeCi5sPh>T zUmt5M&PXeD6GxF zR6(u1fEGNhP+y}rT8P!D)`STC0RoLu;mRBh7C%)~Qw^Hcf-wAZQEKUU{(Os@m{x2; zFun>7B=#d;Z6Qc?__82hT}9rr&eijnvtY!}c^64a-69<_28Ml3MoL94iF3MeZ+uTtUi1PLC+~&^a7oW-p*of z?PI(V^WgiG%-94Ym?PuJK}M&(f#U%Z#70NX#m446{?ySDN|5R>)|)mR?7G(Vk&$My zxINIf?B^OkMmB(O1=g(6=hob0Rm-<)3*%DL*!V-W&`aapBsHELKI!#mPgo$jt!NjO z<9wC#BNDmR3ad6LSQBevl`A5Ldtz&Q!I}gHGBu&IID*g&Svj;Y6Cc19dq~{Dd9H?5 zXY4nxeZbyh@A1AMUU)YEFIsLTzn@&d9FKe-mOD?j*!Rinr?qJSo7KuH%k>?TUfr3+{fdy+She-#fy?17HnT%GdwK}wpxkC2HRR{T~Ry4S08zAHk% z81D`T+I}S69k!kcn?NY;Ek%@#hv>pxA?b}ZGQCq!R&+dQ4N<9>vaoxNOX;A{l=WYp z(E+N{96j5^rDhB|&X5Vrd3AN(`SES4b{G9dE# zb-wISP5lUcJsO}NVavrshqg3P#pEdrS;F#C{6f%$1Xw|aGe^SFrR`h!Fxegs(s=v$ zHnwGNW5_FdNIr-XFW5AEZWCP7>_rv!zALd6zK3?L;NXO(>|FTk;vTE6)9~vg=zew| z0LG=f^VjIB?rb~|>IPaSIL=5x1@zr0=EU#{wH`7eFG9b%OdGeCkM4xfvD~Hf$kIOY zB4ns{)DJK{T)D5_t^OGIW+Y6tc?Fj-%-FEbA1$t#Tf;%oi^h787FB{ls;$0~QlIeA zlk?H8t`gN68(=8VR{^F=HA_~I%de?Uc6Sr#C{rC~0|X>ys8Tzj>z eGBil1A=VKA02v3C>Jg{w1HOQ6sk&a?u83lsVtk=*)HrKwo{W&yh2XQI3=VNmiDv9J2Rg7W@c|?rTOc})!UyAA-~DOa^ZXcpLz|z z38yK^Xh1U-u*?Y@JX69Mcg~1#xjUvW$(Zumwd~F0brRHgo%=BhJl^2VGZNJK5?_Yh z=PR#ByEOstr0tqkEOaQNoEKS`BD-iT_VA)hg==sMs$BN{SACmf)U+PF{`&20eR;82kZXoQLZvP(CnBTp0V zK=IkAUe)EBswNMRX3%i}YgXU8*7x(x|NL{UFMs6cU9kRgLj&LOtPgg1+ZyfW|3-f0 zy`z;shL)^6C~y66N}n6rl$exU1(?##lx_nlNVoB#vJ^RQHv;!0l20w}BBu77&dz&g zIZ@#$mNWzimWrV+huYRNX!gXQkRnvWo^jJ87q>_@TiB`B#?S!RCPf~qB+er}lmd&R zS^$LJrA_M6eb$9t-iMn$UgN?d&D;4I2)6VT_$HcZ2gfg==M-83Mv&Lwo@rpcf*K9z zm2*yDXTXPdII(%VDmfJs^!ksPwkPqzw19~u>ZKygqf97?@0vzF%y2AYVTrTRwVhyj zHjO+48}(OGAf-yKqxYK)XEx>}!f=#A0s%ccA9p}mVUH2Wy@mNdg7aa!VK#3~(oUAB x{%ne7q1Y^JkhNKW55ZDyAs~m=(Lx+HBJ616ce8>IQ?ZZDVISXMkE!qY%|9s(HM{@- diff --git a/error_genration/__pycache__/misc_utils.cpython-36.pyc b/error_genration/__pycache__/misc_utils.cpython-36.pyc deleted file mode 100644 index fa0453c5ef464253e489137c17eccdce297191b0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2032 zcmb7EOK;mo5Z+yq67{g6IE|6Mfub)Ku@$tJ9tqAol0ZVIFGF6JCcd0lw z3YuRV$!4&oSp9d%;1T* zVsl~bh`9jD##?{_m>Otz2;c*VF`Ki;Ujh)|O@x)sH3KvOlkVpHl{j^e*qmzZ zaT$L-DAn*}x72ZFMoJHKVou`YXq>BbFwCcgR>mYx)NWG5<18B}t#ciXRi@)I$+N*U zF~ew9CaKvg&fm26)5Mf??SLFLwZ|91Se4N*mui%2HB`oeF?;)IK8#ayfWhik*f)6! zb$NzE=jBPBt;0*YkZ3h~?_prs+g`C(^H%d@{R|vUnV=~xG@%KEZlVi)n6UhSo8E@G z6+ya1H`)+gw51Pg7^IU_9?{+=44Xs#C8GYyIqg!oEQ}QvGHtQc_SfO453Nn!BDFPb zfo~nCn#2wT+ADODm6el%{k2zBLp{a5sTy%ns7wZ~_Ms!(T4e%RGU3|Mn)|T3xm{xv z_M+|(P{LE){m-yb49q(ajNjv3?(#P8p})s_qQ~`ZtiLAi)`U@cl0)C2S;8o<=O`Y| zN-tA6|qUPvPkj@L|C9-GECw##sVu6PM1 z!kT3Xo}DQciA((%i7!apAfvy5WrR?b#KHEyeb0A&r|tX$tNgJ! diff --git a/error_genration/add_errors.py b/error_genration/add_errors.py deleted file mode 100644 index 2706f79c..00000000 --- a/error_genration/add_errors.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (C) 2021 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import random -import copy -import redbaron as rb -import numpy as np - - -import concurrent.futures as cf - -from error_genration.misc_utils import ( - get_random_int, - load_dataa, - get_codeforeces_paths, - write_csv, -) -from error_genration.err_expr_utils import zerro_error_perturbation - - -def add_perturbation(perturb_node, expr_ass, expr_ass_line, expr_err, expr_err_line): - perturb_node.at(expr_ass_line).insert_before(expr_ass) - perturb_node.at(expr_err_line).insert_before(expr_err) - - -def get_ass_expr_lines(code_lines, apart=0.75): - ln = len(code_lines) - assert ln > 0 - assign_upper_range = max(1, int((1 - apart) * ln)) - if assign_upper_range == 1: - ass_line = 1 - else: - ass_line = get_random_int(1, assign_upper_range) - - for counter in range(20): - if ( - to_include(code_lines[ass_line - 1]) - or ass_line == ln - or assign_upper_range == 1 - ): - break - ass_line = get_random_int(1, int((1 - apart) * ln)) - if not to_include(code_lines[ass_line - 1]): - return -1, -1 - - expr_line = min(ln, ass_line + int(apart * ln)) - for counter in range(20): - if to_include(code_lines[expr_line - 1]) or expr_line == ln: - break - expr_line = get_random_int(min(ln, ass_line + int(apart * ln)), ln) - if not to_include(code_lines[expr_line - 1]): - return -1, -1 - return ass_line, expr_line - - -def get_parent_node(line_to_perturb, red): - selected_node = red.at(line_to_perturb) - print(selected_node.dumps()) - if "if __name__ == '__main__':" in selected_node.parent.dumps(): - if len(selected_node.dumps().split("\n")) < 5: - return None, False - else: - return selected_node.parent, True - else: - out_node = selected_node - while out_node != red and "def" not in out_node.dumps(): - if out_node.dumps() == out_node.parent.dumps(): - break - out_node = out_node.parent - return out_node, True - - -def to_include(line): - for token in ["if", "while", "for", "def", "class", "import", "else"]: - if token in line: - return False - if not line.strip(): - return False - line_rb = rb.RedBaron(line.strip()) - if isinstance(line_rb[0], rb.nodes.CommentNode) or isinstance( - line_rb[0], rb.nodes.EndlNode - ): - return False - return True - - -def get_perturb_node(red, program_source, program_ln): - perturb_node = None - for counter in range(20): - line_to_perturb = get_random_int(1, program_ln) - if to_include(program_source[line_to_perturb - 1]): - break - if to_include(program_source[line_to_perturb - 1]): - perturb_node, found_correct_location = get_parent_node(line_to_perturb, red) - if not perturb_node: - perturb_node = red - return perturb_node - - -def perturb_program(program_fp, suffx="perturbed"): - output_fp = program_fp.replace(".txt", f"_{suffx}.txt") - - program = load_dataa(program_fp).strip() - - try: - red = rb.RedBaron(program) - program_lines = program.split("\n") - program_ln = len(program_lines) - perturb_node = get_perturb_node(red, program_lines, program_ln) - - [expr_ass, expr_err], is_err = zerro_error_perturbation() - perturb_node_lines = perturb_node.dumps().split("\n") - expr_ass_line, expr_err_line = get_ass_expr_lines( - perturb_node_lines, apart=0.75 - ) - - if ( - not perturb_node_lines[expr_ass_line - 1] - or not perturb_node_lines[expr_err_line - 1] - or expr_ass_line == -1 - ): - return "", -1, "Not found a good line" - add_perturbation(perturb_node, expr_ass, expr_ass_line, expr_err, expr_err_line) - # write_csv(program.dumps(), output_fp) - except Exception as e: - return "", -1, f"{e}" - return output_fp, is_err, None - - -def perturb_program_wrapper(paths): - - output_paths = [] - errors = [] - for path in paths: - print(path) - if path.endswith(".txt") or path.endswith(".py"): - out_path, label, error = perturb_program(path, suffx="perturbed") - if error: - errors.append(f"{path}:\n {error}") - else: - output_paths.append(f"{out_path},{label}") - else: - errors.append(f"{path}:\n format error") - return (output_paths, errors) - - -def concurrent_program_perturbation(paths, num_processes, out_path="./"): - per_process_paths = np.array_split(paths, num_processes) - output_paths, errors = [], [] - with cf.ProcessPoolExecutor() as executor: - results = [ - executor.submit(perturb_program_wrapper, per_process_paths[process_num]) - for process_num in range(num_processes) - ] - for completed in cf.as_completed(results): - res_out_paths, res_errs = completed.result() - output_paths.extend(res_out_paths) - errors.append(res_errs) - # write_csv("\n".join(output_paths), f"{out_path}/label_file.csv") - # write_csv("\n".join(errors), f"{out_path}/error_file.csv") - - -def main(): - code_forces_paths = get_codeforeces_paths( - "/home/mila/r/rishab.goel/description2code_current/codeforces" - ) - # concurrent_program_perturbation(code_forces_paths[:4], 3) - res_out_paths, res_errs = perturb_program_wrapper(code_forces_paths[:100]) - # import pdb;pdb.set_trace() - - -if __name__ == "__main__": - # data = load_dataa("/home/mila/r/rishab.goel/description2code_current/codechef/easy/ACBALL/solutions_python/10211792.txt") - main() diff --git a/error_genration/calculate_stats.py b/error_genration/calculate_stats.py deleted file mode 100644 index 99a121e2..00000000 --- a/error_genration/calculate_stats.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (C) 2021 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import redbaron as rb -import os -from collections import defaultdict -import matplotlib - -matplotlib.use("agg") -import matplotlib.pyplot as plt -import json -import math - - -def load_file(file_name): - with open(file_name, "r") as file: - code = file.read() - return code - - -def get_codeforeces_paths(base_path): - problem_paths = [ - os.path.join(base_path, problem_name_dir) - for problem_name_dir in os.listdir(base_path) - if os.path.isdir(os.path.join(base_path, problem_name_dir)) - ] - print(len(problem_paths)) - solution_paths = [] - for problem_path in problem_paths: - solutions_path = os.path.join(problem_path, "solutions_python") - if os.path.exists(solutions_path): - solutions_path = [ - os.path.join(solutions_path, sol_name) - for sol_name in os.listdir(solutions_path) - ] - solution_paths.append(solutions_path) - - solution_paths = [sol_path for path in solution_paths for sol_path in path] - return solution_paths - - -def ceil(x): - return int(math.ceil(x / 10.0)) * 10 - - -def get_depth_stats_wrapper(files, block_types=["def", "class"]): - stats_dict = {block_type: defaultdict(int) for block_type in block_types} - stats_dict["len"] = defaultdict(int) - err_files = [] - print(len(files)) - for file in files: - code = load_file(file) - code_ln = len(code.strip().split("\n")) - stats_dict["len"][ceil(code_ln)] += 1 - is_err = get_depth_stats(code, stats_dict, block_types) - if is_err: - err_files.append(file) - return stats_dict, err_files - - -def get_depth_stats(code, stats_dict, block_types): - try: - red = rb.RedBaron(code) - except Exception as e: - return True - for block_type in block_types: - curr_max = 0 - for block in red.find_all(block_type): - depth = 0 - while block != red: - depth += 1 - # while block.dumps()==block.parent.dumps(): - # import pdb;pdb.set_trace() - block = block.parent - curr_max = max(curr_max, depth) - stats_dict[block_type][curr_max] += 1 - return False - - -def plot_data(x, y, xlabel, ylabel, name): - plt.clf() - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.bar(x, y) - plt.savefig(name + ".pdf") - # plt.show() - - -def save_csv(data, name): - with open(name, "w") as file: - file.write("\n".join(data)) - - -def save_json(data, name): - with open(name, "w") as file: - json.dump(data, file) - - -def plot_wrapper(data): - for key, value in data.items(): - x, y = zip(*value.items()) - plot_data(x, y, key, "Count", key) - - -def get_codechef_paths(base_path): - pass - - -def main(): - paths = get_codeforeces_paths( - "/home/mila/r/rishab.goel/description2code_current/codeforces" - ) - codechef_paths = get_codechef_paths( - "/home/mila/r/rishab.goel/description2code_current/codechef" - ) - data, err_files = get_depth_stats_wrapper(paths) - save_csv(err_files, "error.csv") - save_csv(data, "calculated_stats.json") - plot_wrapper(data) - - -main() diff --git a/error_genration/err_expr_utils.py b/error_genration/err_expr_utils.py deleted file mode 100644 index 0e636458..00000000 --- a/error_genration/err_expr_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (C) 2021 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from error_genration.misc_utils import ( - get_random_int, - get_random_list_sample, - get_random_float, - get_random_int, -) - -variable_names = [f"tmp{i}" for i in range(10)] + [chr(ord("a") + i) for i in range(26)] -num_range = [1, 100] - - -def get_zerro_expression_signature(var1, var2, val1, val2, val3, is_zerro_err): - output_expr = [f"{var1}={val1}\n"] - before_sub = get_random_int(0, 1) - if before_sub: - line = ( - f"{var2}={val2}/{var1}-{val1}\n" - if is_zerro_err - else f"{var2}={val2}/{var1}-{val3}\n" - ) - else: - line = ( - f"{var2}={val2}/{val1}-{var1}\n" - if is_zerro_err - else f"{var2}={val2}/{val3}-{var1}\n" - ) - output_expr.append(line) - return output_expr - - -def zerro_error_perturbation(): - sampled_vars = get_random_list_sample(variable_names, 2) - samples_vals = get_random_float(*num_range, size=3) - is_zerro_err = get_random_int(0, 1) - return ( - get_zerro_expression_signature(*sampled_vars, *samples_vals, is_zerro_err), - is_zerro_err, - ) From 456ae1fdf14514a66d5030ff316355d4022510c1 Mon Sep 17 00:00:00 2001 From: Rishab Goel Date: Wed, 16 Jun 2021 18:13:54 -0400 Subject: [PATCH 4/6] fixed bugs and improved zero pertubation additon --- error_generation/add_code.py | 99 ++++++++++++++++++++++++++-------- error_generation/get_trace.py | 86 +++++++++++++++++------------ error_generation/main.py | 49 ++++++++++++----- error_generation/misc_utils.py | 63 +++++++++++++++++----- error_generation/trace_code.py | 7 +-- 5 files changed, 220 insertions(+), 84 deletions(-) diff --git a/error_generation/add_code.py b/error_generation/add_code.py index aa8b4797..7a07e0d0 100644 --- a/error_generation/add_code.py +++ b/error_generation/add_code.py @@ -13,30 +13,85 @@ # limitations under the License. import redbaron as rb -from misc_utils import get_random_list_sample, load_json, load_data, write_csv +from misc_utils import ( + get_random_list_sample, + load_json, + load_data, + write_csv, + get_random_int, + get_random_float, +) + +VARIABLE_NAMES = ["tmp{}".format(i) for i in range(10)] + [ + chr(ord("a") + i) for i in range(26) +] +NUM_RANGE = [1, 100] + def get_perturb_line_step(code_trace): - perturb_line = get_random_list_sample(code_trace.keys(), 1)[0] - perturb_step = get_random_list_sample(code_trace[perturb_line],1)[0] - # import pdb;pdb.set_trace() - perturb_var = get_random_list_sample(perturb_step.keys(),1)[0] - perturb_val = perturb_step[perturb_var] - return int(perturb_line), perturb_var, int(perturb_val) + perturb_line = get_random_list_sample(code_trace.keys(), 1)[0] + perturb_step = get_random_list_sample(code_trace[perturb_line], 1)[0] + # import pdb;pdb.set_trace() + perturb_var = get_random_list_sample(perturb_step.keys(), 1)[0] + perturb_val = perturb_step[perturb_var] + return int(perturb_line), perturb_var, int(perturb_val) + + +def get_zero_perturb_expression(perturb_var, perturb_val): + assign_var = get_random_list_sample(VARIABLE_NAMES, 1)[0] + is_zerro_err = get_random_int(0, 1) + if is_zerro_err: + numerator = get_random_float(*NUM_RANGE, size=1)[0] + return ( + assign_var + + "=" + + str(int(numerator)) + + "/" + + str(perturb_val) + + "-" + + perturb_var, + is_zerro_err, + ) + else: + perturb_val_offset, numerator = get_random_float(*NUM_RANGE, size=2) + perturb_val = perturb_val + int(perturb_val_offset) + return ( + assign_var + + "=" + + str(int(numerator)) + + "/" + + str(perturb_val) + + "-" + + perturb_var, + is_zerro_err, + ) -def get_perturb_expression(perturb_var, perturb_val): - return 'tmp1 = 1/'+str(perturb_val)+ '-'+ perturb_var, True def perturb_program(red, code_trace): - perturb_line, perturb_var, perturb_val = get_perturb_line_step(code_trace) - perturb_expression, is_err_present = get_perturb_expression(perturb_var, perturb_val) - # print perturb_expression, perturb_line - # import pdb;pdb.set_trace() - red.at(perturb_line).insert_after(perturb_expression) - -def add_error(org_code_fp, code_trace_fp, suffx): - code_trace = load_json(code_trace_fp) - err_code_fp = org_code_fp.replace(".txt", "_"+suffx+".txt") - program = load_data(org_code_fp).strip() - red = rb.RedBaron(program) - _ = perturb_program(red, code_trace) - write_csv(red.dumps(), err_code_fp) \ No newline at end of file + perturb_line, perturb_var, perturb_val = get_perturb_line_step(code_trace) + perturb_expression, is_err_present = get_zero_perturb_expression( + perturb_var, perturb_val + ) + red.at(perturb_line).insert_after(perturb_expression) + return is_err_present + + +def add_error(org_code_fp, code_trace_fp, err_code_fp, suffx): + code_trace = load_json(code_trace_fp) + # To keep this function generic the name of the output + # code file has the error type and indicator whether the + # the error is present or not as suffix. + err_code_fp = err_code_fp.replace(".txt", "_" + suffx + ".txt") + program = load_data(org_code_fp).strip() + red = rb.RedBaron(program) + try: + is_zerro_err = perturb_program(red, code_trace) + err_code_fp = err_code_fp.replace(".txt", "_" + str(is_zerro_err) + ".txt") + except Exception as e: + # We can handle the exception as we want. + # But for the time being we can return False. + # import pdb;pdb.set_trace() + return False + + write_csv(red.dumps(), err_code_fp) + return True diff --git a/error_generation/get_trace.py b/error_generation/get_trace.py index 3369a2d9..c4e5c345 100644 --- a/error_generation/get_trace.py +++ b/error_generation/get_trace.py @@ -18,39 +18,59 @@ from collections import defaultdict import subprocess -def postprocess_and_save(json_fp, offset, processed_suffix): - data = json.load(open(json_fp,"rb")) - processed_data = {} - for key, val in data.items(): - val = [v for v in val if v] - if val: - processed_data[int(key)-offset] = val - out_path=json_fp.replace(".json", "_"+str(processed_suffix)+".json") - open(out_path, 'w').write(json.dumps(processed_data)) +def postprocess_and_save(json_fp, offset, processed_suffix): + """Here we offset the lines of the trace to take into account + additional lines added to get the trace of the function. + """ + data = json.load(open(json_fp, "rb")) + processed_data = {} + for key, val in data.items(): + val = [v for v in val if v] + if val: + processed_data[int(key) - offset] = val + out_path = json_fp.replace(".json", "_" + str(processed_suffix) + ".json") + open(out_path, "w").write(json.dumps(processed_data)) -def run_for_errors(python_filepath, data_path, trace_path, stdin_file, stdout_file, stderr_file, processed_suffix = 'processed'): - # Assumes the input is stdin when called. - trace_source = open(trace_path, 'r').read() - python_source = open(python_filepath, 'r').read() - python_source = python_source.replace('__name__ == "__main__"', 'True') - python_source = python_source.replace("__name__ == '__main__'", 'True') - python_source = ( - 'import json\n' - + 'import sys\n' - + 'def main__errorchecker__():\n' - + '\n'.join(' ' + line for line in python_source.split('\n')) - + '\n' - + trace_source - + '\nsys.settrace(trace_calls)\n' - + 'main__errorchecker__()\n' - + 'print "yo"\n' - + 'open("' + data_path + '","w").write(json.dumps(data, indent=4, sort_keys=True))\n' - ) - try: - subprocess_call = subprocess.check_call(['python', '-c', python_source], stdin=open(stdin_file, 'rb'), stdout=open(stdout_file, 'wb'), stderr=open(stderr_file, 'wb')) - except Exception as e: - return False - postprocess_and_save(data_path, 3, processed_suffix) - return True +def run_for_errors( + python_filepath, + data_path, + trace_path, + stdin_file, + stdout_file, + stderr_file, + processed_suffix="processed", + offset=3, +): + # Assumes the input is stdin when called. + # import pdb;pdb.set_trace() + trace_source = open(trace_path, "r").read() + python_source = open(python_filepath, "r").read() + python_source = python_source.replace('__name__ == "__main__"', "True") + python_source = python_source.replace("__name__ == '__main__'", "True") + python_source = ( + "import json\n" + + "import sys\n" + + "def main__errorchecker__():\n" + + "\n".join(" " + line for line in python_source.split("\n")) + + "\n" + + trace_source + + "\nsys.settrace(trace_calls)\n" + + "main__errorchecker__()\n" + + 'open("' + + data_path + + '","w").write(json.dumps(data, indent=4, sort_keys=True))\n' + ) + try: + subprocess_call = subprocess.check_call( + ["python", "-c", python_source], + stdin=open(stdin_file, "rb"), + stdout=open(stdout_file, "wb"), + stderr=open(stderr_file, "wb"), + ) + except Exception as e: + raise e + return False + postprocess_and_save(data_path, offset, processed_suffix) + return True diff --git a/error_generation/main.py b/error_generation/main.py index 5359c9cc..8c8407ae 100644 --- a/error_generation/main.py +++ b/error_generation/main.py @@ -18,18 +18,43 @@ from get_trace import run_for_errors from add_code import add_error + def main(base_path, trace_code_path, process_suffix="processed"): - code_inp_data_paths = get_codeforeces_paths(base_path) - for code_path, inp_paths in code_inp_data_paths: - for idx, inp_path in enumerate(inp_paths): - err_path = inp_path.replace(".txt", "_error.txt") - out_path = inp_path.replace(".txt", "_out.txt") - out_code_path = code_path.replace(".txt", "_"+str(idx)+"_perturbed.txt") - data_trace_path = code_path.replace(".txt", "_"+str(idx)+"_trace.json") - trace_successful = run_for_errors(code_path, data_trace_path, trace_code_path, inp_path, out_path, err_path, process_suffix) - if trace_successful: - data_trace_path = data_trace_path.replace(".json", "_"+process_suffix+".json") - add_error(code_path, data_trace_path, "zero_err") + code_inp_data_paths = get_codeforeces_paths(base_path) + for ( + code_path, + inp_paths, + perturbed_code_path, + trace_data_path, + sol_err_out_path, + ) in code_inp_data_paths: + for idx, inp_path in enumerate(inp_paths): + # print inp_path + err_path = sol_err_out_path.replace( + ".txt", "_error" + "_" + str(idx) + ".txt" + ) + out_path = sol_err_out_path.replace( + ".txt", "_out" + "_" + str(idx) + ".txt" + ) + out_code_path = perturbed_code_path.replace(".txt", "_" + str(idx) + ".txt") + data_trace_path = trace_data_path.replace( + ".json", "_trace_" + str(idx) + ".json" + ) + trace_successful = run_for_errors( + code_path, + data_trace_path, + trace_code_path, + inp_path, + out_path, + err_path, + process_suffix, + ) + # print trace_successful + if trace_successful: + data_trace_path = data_trace_path.replace( + ".json", "_" + process_suffix + ".json" + ) + _ = add_error(code_path, data_trace_path, out_code_path, "zero_err") -main("/Users/rishabgoel/Downloads/description2code_current/codeforces", "trace_code.py") \ No newline at end of file +main("/Users/rishabgoel/Documents/compressive-ipagnn/data/codeforces", "trace_code.py") diff --git a/error_generation/misc_utils.py b/error_generation/misc_utils.py index 7b35ebff..07fefd09 100644 --- a/error_generation/misc_utils.py +++ b/error_generation/misc_utils.py @@ -17,33 +17,66 @@ import os import json + def get_codeforces_inp_data_paths(base_path): inp_data_directory = os.path.join(base_path, "samples") - inp_data_paths = [inp_data_directory+"/"+path for path in os.listdir(inp_data_directory) if "input" in path] + inp_data_paths = [ + inp_data_directory + "/" + path + for path in os.listdir(inp_data_directory) + if "input" in path + ] return inp_data_paths +def create_dirs(dir): + if not os.path.exists(dir): + os.makedirs(dir) + + def get_codeforeces_paths(base_path): problem_paths = [ os.path.join(base_path, problem_name_dir) for problem_name_dir in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, problem_name_dir)) ] - print(len(problem_paths)) - solution_paths = [] + # print(len(problem_paths)) + data_paths = [] for problem_path in problem_paths: - solutions_path = os.path.join(problem_path, "solutions_python") + prob_solutions_path = os.path.join(problem_path, "solutions_python") + # print prob_solutions_path + perturbed_prob_solutions_path = os.path.join( + problem_path, "perturbed_solutions_python" + ) + create_dirs(perturbed_prob_solutions_path) + trace_path = os.path.join(problem_path, "trace") + create_dirs(trace_path) + err_out_path = os.path.join(problem_path, "err_out") + create_dirs(err_out_path) inp_data_paths = get_codeforces_inp_data_paths(problem_path) # import pdb;pdb.set_trace() - if os.path.exists(solutions_path): - solutions_path = [ - (os.path.join(solutions_path, sol_name), inp_data_paths) - for sol_name in os.listdir(solutions_path) - ] - solution_paths.append(solutions_path) - - solution_paths = [sol_path for path in solution_paths for sol_path in path] - return solution_paths + if os.path.exists(prob_solutions_path): + solution_paths = [] + for sol_name in os.listdir(prob_solutions_path): + code_path = os.path.join(prob_solutions_path, sol_name) + sol_name_json = sol_name.replace(".txt", ".json") + perturbed_code_path = os.path.join( + perturbed_prob_solutions_path, sol_name + ) + trace_code_path = os.path.join(trace_path, sol_name_json) + sol_err_out_path = os.path.join(err_out_path, sol_name) + solution_paths.append( + ( + code_path, + inp_data_paths, + perturbed_code_path, + trace_code_path, + sol_err_out_path, + ) + ) + data_paths.append(solution_paths) + + data_paths = [sol_path for path in data_paths for sol_path in path] + return data_paths def set_seeds(seed=10): @@ -56,10 +89,12 @@ def load_data(fp): data = file.read().strip() return data + def load_json(fp): with open(fp, "r") as file: return json.load(file) + def write_csv(data, fp): with open(fp, "w") as file: file.write(data) @@ -74,4 +109,4 @@ def get_random_float(lower_limit, upper_limit, size=None): def get_random_list_sample(lst, num_samples): - return random.sample(lst, num_samples) \ No newline at end of file + return random.sample(lst, num_samples) diff --git a/error_generation/trace_code.py b/error_generation/trace_code.py index f5632285..bb870e6e 100644 --- a/error_generation/trace_code.py +++ b/error_generation/trace_code.py @@ -2,14 +2,15 @@ data = defaultdict(list) + def trace_lines(frame, event, arg): - if event != 'line': + if event != "line": return co = frame.f_code func_name = co.co_name line_no = frame.f_lineno filename = co.co_filename - if filename=="": + if filename == "": locals_dict = {} for key, value in frame.f_locals.items(): try: @@ -21,7 +22,7 @@ def trace_lines(frame, event, arg): def trace_calls(frame, event, arg): - if event!="call": + if event != "call": return co = frame.f_code func_name = co.co_name From 1e6b98be927135cbb3e17fe02ed3b57eaa082e68 Mon Sep 17 00:00:00 2001 From: Rishab Goel Date: Fri, 25 Jun 2021 10:56:34 -0400 Subject: [PATCH 5/6] added code for more error injection types --- __init__.py | 0 error_generation/__init__.py | 0 error_generation/add_code.py | 75 +++---- error_generation/config.yaml | 10 + error_generation/error_expression_factory.py | 225 +++++++++++++++++++ error_generation/get_trace.py | 3 +- error_generation/main.py | 42 ++-- error_generation/misc_utils.py | 31 ++- error_generation/trace_code.py | 18 +- tests/__init__.py | 0 tests/test_err_expr_factory.py | 147 ++++++++++++ 11 files changed, 485 insertions(+), 66 deletions(-) create mode 100644 __init__.py create mode 100644 error_generation/__init__.py create mode 100644 error_generation/config.yaml create mode 100644 error_generation/error_expression_factory.py create mode 100644 tests/__init__.py create mode 100644 tests/test_err_expr_factory.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/error_generation/__init__.py b/error_generation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/error_generation/add_code.py b/error_generation/add_code.py index 7a07e0d0..0ce74bb5 100644 --- a/error_generation/add_code.py +++ b/error_generation/add_code.py @@ -20,6 +20,7 @@ write_csv, get_random_int, get_random_float, + get_valid_code_trace, ) VARIABLE_NAMES = ["tmp{}".format(i) for i in range(10)] + [ @@ -28,65 +29,53 @@ NUM_RANGE = [1, 100] -def get_perturb_line_step(code_trace): +def get_perturb_line_step(code_trace_org, err_suffx): + # Not all the variable types are valid for all the errors. + # For instance for index out of range an int var is not valid. + code_trace = get_valid_code_trace(code_trace_org, err_suffx) + if not code_trace: + return None, None, None perturb_line = get_random_list_sample(code_trace.keys(), 1)[0] perturb_step = get_random_list_sample(code_trace[perturb_line], 1)[0] - # import pdb;pdb.set_trace() perturb_var = get_random_list_sample(perturb_step.keys(), 1)[0] perturb_val = perturb_step[perturb_var] - return int(perturb_line), perturb_var, int(perturb_val) + return int(perturb_line), perturb_var, perturb_val - -def get_zero_perturb_expression(perturb_var, perturb_val): - assign_var = get_random_list_sample(VARIABLE_NAMES, 1)[0] - is_zerro_err = get_random_int(0, 1) - if is_zerro_err: - numerator = get_random_float(*NUM_RANGE, size=1)[0] - return ( - assign_var - + "=" - + str(int(numerator)) - + "/" - + str(perturb_val) - + "-" - + perturb_var, - is_zerro_err, - ) - else: - perturb_val_offset, numerator = get_random_float(*NUM_RANGE, size=2) - perturb_val = perturb_val + int(perturb_val_offset) - return ( - assign_var - + "=" - + str(int(numerator)) - + "/" - + str(perturb_val) - + "-" - + perturb_var, - is_zerro_err, - ) - - -def perturb_program(red, code_trace): - perturb_line, perturb_var, perturb_val = get_perturb_line_step(code_trace) - perturb_expression, is_err_present = get_zero_perturb_expression( +def perturb_program(red, code_trace, err_suffx, error_expr_factory_obj): + perturb_line, perturb_var, perturb_val = get_perturb_line_step(code_trace, err_suffx) + if perturb_line is None: + return 0 + perturb_expression, is_err_present = error_expr_factory_obj.add_err( err_suffx, perturb_var, perturb_val ) - red.at(perturb_line).insert_after(perturb_expression) + # TODO(rishab): Need to be careful to ensure that that the insertion + # line is not an AssignmentNode in RedBaron. + if err_suffx == "math_domain_err": + # The sqrt function needs to be imported so that sqrt function + # can be called. I am not sure if we can just add the expression + # without proper imports. + import_statement, perturb_expression = perturb_expression.split(";") + red.at(perturb_line).insert_before(import_statement, offset=perturb_line-1) + red.at(perturb_line+1).insert_after(perturb_expression) + else: + red.at(perturb_line).insert_after(perturb_expression) return is_err_present -def add_error(org_code_fp, code_trace_fp, err_code_fp, suffx): +def add_error(org_code_fp, code_trace_fp, err_code_fp, err_suffx, error_expr_factory_obj): + # We can optimize the code by passing the read file. + # But for now to ensure isolation, I am doing it + # explicitly. code_trace = load_json(code_trace_fp) # To keep this function generic the name of the output # code file has the error type and indicator whether the # the error is present or not as suffix. - err_code_fp = err_code_fp.replace(".txt", "_" + suffx + ".txt") + err_code_fp = err_code_fp.replace(".txt", "-" + err_suffx + ".txt") program = load_data(org_code_fp).strip() red = rb.RedBaron(program) try: - is_zerro_err = perturb_program(red, code_trace) - err_code_fp = err_code_fp.replace(".txt", "_" + str(is_zerro_err) + ".txt") + is_err_present = perturb_program(red, code_trace, err_suffx, error_expr_factory_obj) + err_code_fp = err_code_fp.replace(".txt", "-" + str(is_err_present) + ".txt") except Exception as e: # We can handle the exception as we want. # But for the time being we can return False. @@ -94,4 +83,4 @@ def add_error(org_code_fp, code_trace_fp, err_code_fp, suffx): return False write_csv(red.dumps(), err_code_fp) - return True + return True \ No newline at end of file diff --git a/error_generation/config.yaml b/error_generation/config.yaml new file mode 100644 index 00000000..4591a2a0 --- /dev/null +++ b/error_generation/config.yaml @@ -0,0 +1,10 @@ +base_path: /Users/rishabgoel/Documents/compressive-ipagnn/data/codeforces +trace_code_path: error_generation/trace_code.py +process_suffix: processed +errors: + - zero_err + - assert_err + - not_subscriptable_err + - idx_out_range_err + - math_domain_err + - not_iterable_err \ No newline at end of file diff --git a/error_generation/error_expression_factory.py b/error_generation/error_expression_factory.py new file mode 100644 index 00000000..27628cb3 --- /dev/null +++ b/error_generation/error_expression_factory.py @@ -0,0 +1,225 @@ +# Copyright (C) 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from misc_utils import ( + get_random_list_sample, + load_json, + load_data, + write_csv, + get_random_int, + get_random_float, +) + +class ErrorFactory: + """TODO (rishab): + 1. implement methods for var name not defined and + operand mismatch. + 2. make the expressions more complex. + """ + VARIABLE_NAMES = ["tmp{}".format(i) for i in range(10)] + [ + chr(ord("a") + i) for i in range(26) + ] + NUM_RANGE = [1, 100] + + def __init__(self): + self._builders = { + "zero_err": self.get_zero_perturb_expression, + "assert_err": self.get_assert_perturb_expression, + "not_subscriptable_err": self.get_not_subscriptable_perturb_expression, + "idx_out_range_err": self.get_index_range_perturb_expression, + "undef_var_err": self.get_undef_name_perturb_expression, # Caution: not implemented properly. + "math_domain_err": self.get_math_domain_perturb_expression, + "not_iterable_err": self.get_int_not_iterable_perturb_expression + } + + def get_zero_perturb_expression(self, perturb_var, perturb_val): + assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] + is_zerro_err = get_random_int(0, 1) + # is_zerro_err = 1 + if is_zerro_err: + numerator = get_random_float(*self.NUM_RANGE, size=1)[0] + return ( + assign_var + + "=" + + str(int(numerator)) + + "/(" + + str(perturb_val) + + "-" + + perturb_var + + ')', + is_zerro_err, + ) + else: + perturb_val_offset, numerator = get_random_float(*self.NUM_RANGE, size=2) + perturb_val = perturb_val + int(perturb_val_offset) + return ( + assign_var + + "=" + + str(int(numerator)) + + "/(" + + str(perturb_val) + + "-" + + perturb_var + + ')', + is_zerro_err, + ) + def get_assert_perturb_expression(self, perturb_var, perturb_val): + is_assert_err = get_random_int(0, 1) + # is_assert_err = 1 + if is_assert_err: + perturb_val_offset = get_random_float(*self.NUM_RANGE, size=1)[0] + perturb_val = perturb_val + int(perturb_val_offset) + return ( + "assert " + + perturb_var + + "==" + + str(perturb_val), + is_assert_err, + ) + else: + return ( + "assert " + + perturb_var + + "==" + + str(perturb_val), + is_assert_err, + ) + def get_not_subscriptable_perturb_expression(self, perturb_var, perturb_val): + is_not_subscriptable_err = get_random_int(0, 1) + # is_not_subscriptable_err = 1 + if is_not_subscriptable_err: + random_val, numerator = get_random_float(*self.NUM_RANGE, size=2) + return ( + perturb_var + + "[" + + str(int(numerator)) + + "] = " + + str(int(random_val)), + is_not_subscriptable_err, + ) + else: + return ( + "", + is_not_subscriptable_err, + ) + + def get_index_range_perturb_expression(self, perturb_var, perturb_val): + """This will occur very less frequently and hence we perhaps + need to rethink how to handle generate the error. + """ + is_index_range_err = get_random_int(0, 1) + # is_index_range_err = 1 + if is_index_range_err: + random_ass = get_random_float(*self.NUM_RANGE, size=1)[0] + return ( + perturb_var + + "[" + + str(len(perturb_val)) + + "] = " + + str(int(random_ass)), + is_index_range_err, + ) + else: + valid_idx = int(get_random_float(*[0, len(perturb_val)-1], size=1)[0]) + random_ass = get_random_float(*self.NUM_RANGE, size=1)[0] + return ( + perturb_var + + "[" + + str(valid_idx) + + "] = " + + str(random_ass), + is_index_range_err, + ) + + def get_undef_name_perturb_expression(self, perturb_var, perturb_val): + """Not implemented as per our requirements.""" + is_undef_name_err = get_random_int(0, 1) + # is_undef_name_err = 1 + if is_undef_name_err: + undef_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] + return ( + perturb_var + + "=" + + undef_var + + "+" + + str(perturb_val), + is_undef_name_err, + ) + else: + return ( + "", + is_undef_name_err, + ) + + def get_math_domain_perturb_expression(self, perturb_var, perturb_val): + """The current implementation may cause unforeseen issues when the + is_math_domain_err is 0 as the assign_var can be a part of the program. Also, we may + perhaps need to refine how we import math module.""" + is_math_domain_err = get_random_int(0, 1) + # is_math_domain_err = 1 + if is_math_domain_err: + assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] + if perturb_val>=0: + random_ass = str(-1*int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + else: + random_ass = str(int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + return ( + "import math;" + + assign_var + + "=" + + "math.sqrt("+str(random_ass)+")", + is_math_domain_err, + ) + else: + assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] + if perturb_val>=0: + random_ass = str(int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + else: + random_ass = str(-1*int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + return ( + "import math;" + + assign_var + + "=" + + "math.sqrt("+str(random_ass)+")", + is_math_domain_err, + ) + def _relevant_operand_val_type(self, val, is_same): + pass + + def get_operand_type_mismatch_perturb_expression(self, perturb_var, perturb_val): + pass + + def get_int_not_iterable_perturb_expression(self, perturb_var, perturb_val): + """TODO: 1. Add more variants of the for loop. + 2. Add logic to include the while loop. + """ + is_int_not_iterable_err = get_random_int(0, 1) + # is_int_not_iterable_err = 1 + if is_int_not_iterable_err: + assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] + random_ass = int(get_random_float(*self.NUM_RANGE, size=1)[0]) + return ( + "{}=[{}+val for val in {}]".format(assign_var, random_ass, perturb_var), + is_int_not_iterable_err + ) + else: + return "", is_int_not_iterable_err + + def add_err(self, err_type, perturb_var, perturb_val): + expr_builder = self._builders.get(err_type.lower(), None) + if not expr_builder: + raise ValueError(err_type + " is not a valid error generation function.") + return expr_builder(perturb_var, perturb_val) + \ No newline at end of file diff --git a/error_generation/get_trace.py b/error_generation/get_trace.py index c4e5c345..36c71d69 100644 --- a/error_generation/get_trace.py +++ b/error_generation/get_trace.py @@ -49,6 +49,7 @@ def run_for_errors( python_source = open(python_filepath, "r").read() python_source = python_source.replace('__name__ == "__main__"', "True") python_source = python_source.replace("__name__ == '__main__'", "True") + # TODO(rishab): Clean the python_source variable. python_source = ( "import json\n" + "import sys\n" @@ -70,7 +71,7 @@ def run_for_errors( stderr=open(stderr_file, "wb"), ) except Exception as e: - raise e + # raise e return False postprocess_and_save(data_path, offset, processed_suffix) return True diff --git a/error_generation/main.py b/error_generation/main.py index 8c8407ae..016e89eb 100644 --- a/error_generation/main.py +++ b/error_generation/main.py @@ -14,13 +14,24 @@ import os -from misc_utils import get_codeforeces_paths -from get_trace import run_for_errors -from add_code import add_error +from error_generation.misc_utils import get_codeforeces_paths, load_yaml, set_seeds +from error_generation.get_trace import run_for_errors +from error_generation.add_code import add_error +from error_generation.error_expression_factory import ErrorFactory +""" +TODO(rishab): Setup code to include codechef as well. -def main(base_path, trace_code_path, process_suffix="processed"): - code_inp_data_paths = get_codeforeces_paths(base_path) +Run instructions: +In the compressive-ipagnn folder run the following command: +python -m error_generation.main +""" + +def main(config_fp): + config = load_yaml(config_fp) + set_seeds() + code_inp_data_paths = get_codeforeces_paths(config["base_path"]) + error_expr_factory_obj = ErrorFactory() for ( code_path, inp_paths, @@ -40,21 +51,24 @@ def main(base_path, trace_code_path, process_suffix="processed"): data_trace_path = trace_data_path.replace( ".json", "_trace_" + str(idx) + ".json" ) - trace_successful = run_for_errors( + is_trace_successful = run_for_errors( code_path, data_trace_path, - trace_code_path, + config["trace_code_path"], inp_path, out_path, err_path, - process_suffix, + config["process_suffix"], ) - # print trace_successful - if trace_successful: + if is_trace_successful: data_trace_path = data_trace_path.replace( - ".json", "_" + process_suffix + ".json" + ".json", "_" + config["process_suffix"] + ".json" ) - _ = add_error(code_path, data_trace_path, out_code_path, "zero_err") - + for err_suffix in config["errors"]: + # import pdb;pdb.set_trace() + _ = add_error(code_path, data_trace_path, out_code_path, err_suffix, error_expr_factory_obj) + # break + # break -main("/Users/rishabgoel/Documents/compressive-ipagnn/data/codeforces", "trace_code.py") +if __name__ == '__main__': + main("error_generation/config.yaml") diff --git a/error_generation/misc_utils.py b/error_generation/misc_utils.py index 07fefd09..3a01ac64 100644 --- a/error_generation/misc_utils.py +++ b/error_generation/misc_utils.py @@ -16,7 +16,19 @@ import numpy as np import os import json - +import yaml +import copy +from collections import defaultdict + +ALLOWED_TYPE = { + "zero_err": [int, float], + "assert_err": [int, float], + "not_subscriptable_err": [int, float], + "idx_out_range_err": [list], + "undef_var_err": [int, float, str, list], + "math_domain_err": [int, float], + "not_iterable_err": [int, float] + } def get_codeforces_inp_data_paths(base_path): inp_data_directory = os.path.join(base_path, "samples") @@ -78,6 +90,20 @@ def get_codeforeces_paths(base_path): data_paths = [sol_path for path in data_paths for sol_path in path] return data_paths +def is_valid_type(val, type): + return isinstance(val, type) + +def get_valid_code_trace(code_trace, err_suffx): + code_trace_filtered = defaultdict(list) + for line in code_trace: + for step_idx in range(len(code_trace[line])): + new_step_dict = {} + for var in code_trace[line][step_idx]: + if any(is_valid_type(code_trace[line][step_idx][var], typ) for typ in ALLOWED_TYPE[err_suffx]): + new_step_dict[var] = code_trace[line][step_idx][var] + if new_step_dict: + code_trace_filtered[line].append(new_step_dict) + return code_trace_filtered def set_seeds(seed=10): random.seed(seed) @@ -94,6 +120,9 @@ def load_json(fp): with open(fp, "r") as file: return json.load(file) +def load_yaml(fp): + with open(fp, "r") as file: + return yaml.load(file, Loader=yaml.FullLoader) def write_csv(data, fp): with open(fp, "w") as file: diff --git a/error_generation/trace_code.py b/error_generation/trace_code.py index bb870e6e..129b749a 100644 --- a/error_generation/trace_code.py +++ b/error_generation/trace_code.py @@ -1,3 +1,7 @@ +# Get line by line trace for variables in a program. +# It is adapted from the examples in the following link: +# https://pymotw.com/2/sys/tracing.html + from collections import defaultdict data = defaultdict(list) @@ -7,23 +11,23 @@ def trace_lines(frame, event, arg): if event != "line": return co = frame.f_code - func_name = co.co_name + # func_name = co.co_name line_no = frame.f_lineno filename = co.co_filename if filename == "": - locals_dict = {} + local_data_dict = {} for key, value in frame.f_locals.items(): try: json.dumps(value) - locals_dict[key] = value + local_data_dict[key] = value except Exception as e: - _ = "" - data[line_no].append(locals_dict) + continue + data[line_no].append(local_data_dict) def trace_calls(frame, event, arg): if event != "call": return - co = frame.f_code - func_name = co.co_name + # co = frame.f_code + # func_name = co.co_name return trace_lines diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_err_expr_factory.py b/tests/test_err_expr_factory.py new file mode 100644 index 00000000..9e489aaa --- /dev/null +++ b/tests/test_err_expr_factory.py @@ -0,0 +1,147 @@ +import pytest +import subprocess +from error_generation.error_expression_factory import ErrorFactory + +class TestErrorFactory(): + """The test suite is not complete but will test some obvious + issues in code.""" + def setup(self): + self.error_factory = ErrorFactory() + self.test_var_name = "test_var" + self.test_var_val = 23 + self.test_lst_var_name = "test_lst_var" + self.test_lst_var_val = [1,2,3,4] + self.test_var_assign = self.test_var_name + "=" + str(self.test_var_val) + "\n" + self.test_lst_var_assign = self.test_lst_var_name + "=" + str(self.test_lst_var_val) + "\n" + + def test_get_zero_perturb_expression(self): + + expr, is_err_present = self.error_factory.add_err("zero_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + + while not is_err_present: + subprocess_call = subprocess.call( + ["python", "-c", expr], stderr=subprocess.PIPE + ) + assert subprocess_call == 0 + expr, is_err_present = self.error_factory.add_err("zero_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + + with pytest.raises(subprocess.CalledProcessError) as exc: + try: + subprocess_call = subprocess.check_output( + ["python", "-c", expr], stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as exception: + if "ZeroDivisionError" in exception.output: + raise exception + + def test_get_assert_perturb_expression(self): + expr, is_err_present = self.error_factory.add_err("assert_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + while not is_err_present: + subprocess_call = subprocess.call( + ["python", "-c", expr], stderr=subprocess.PIPE + ) + assert subprocess_call == 0 + expr, is_err_present = self.error_factory.add_err("assert_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + with pytest.raises(subprocess.CalledProcessError) as exc: + try: + subprocess_call = subprocess.check_output( + ["python", "-c", expr], stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as exception: + if "AssertionError" in exception.output: + raise exception + + def test_get_not_subscriptable_perturb_expression(self): + expr, is_err_present = self.error_factory.add_err("not_subscriptable_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + while not is_err_present: + subprocess_call = subprocess.call( + ["python", "-c", expr], stderr=subprocess.PIPE + ) + assert subprocess_call == 0 + expr, is_err_present = self.error_factory.add_err("not_subscriptable_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + with pytest.raises(subprocess.CalledProcessError) as exc: + try: + subprocess_call = subprocess.check_output( + ["python", "-c", expr], stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as exception: + if "TypeError" in exception.output and "not support item assignment" in exception.output: + raise exception + + def test_get_index_range_perturb_expression(self): + expr, is_err_present = self.error_factory.add_err("idx_out_range_err", self.test_lst_var_name, self.test_lst_var_val) + expr = self.test_lst_var_assign + expr + + while not is_err_present: + subprocess_call = subprocess.call( + ["python", "-c", expr], stderr=subprocess.PIPE + ) + if subprocess_call!=0: + import pdb;pdb.set_trace() + assert subprocess_call == 0 + + expr, is_err_present = self.error_factory.add_err("idx_out_range_err", self.test_lst_var_name, self.test_lst_var_val) + expr = self.test_lst_var_assign + expr + + with pytest.raises(subprocess.CalledProcessError) as exc: + try: + subprocess_call = subprocess.check_output( + ["python", "-c", expr], stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as exception: + if "IndexError" in exception.output and "index out of range" in exception.output: + raise exception + + def test_get_math_domain_perturb_expression(self): + expr, is_err_present = self.error_factory.add_err("math_domain_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + + while not is_err_present: + subprocess_call = subprocess.call( + ["python", "-c", expr], stderr=subprocess.PIPE + ) + assert subprocess_call == 0 + expr, is_err_present = self.error_factory.add_err("math_domain_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + + with pytest.raises(subprocess.CalledProcessError) as exc: + try: + subprocess_call = subprocess.check_output( + ["python", "-c", expr], stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as exception: + if "ValueError" in exception.output and "math domain error" in exception.output: + raise exception + + def test_get_int_not_iterable_perturb_expression(self): + expr, is_err_present = self.error_factory.add_err("not_iterable_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + + while not is_err_present: + subprocess_call = subprocess.call( + ["python", "-c", expr], stderr=subprocess.PIPE + ) + assert subprocess_call == 0 + expr, is_err_present = self.error_factory.add_err("not_iterable_err", self.test_var_name, self.test_var_val) + expr = self.test_var_assign + expr + + with pytest.raises(subprocess.CalledProcessError) as exc: + try: + subprocess_call = subprocess.check_output( + ["python", "-c", expr], stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as exception: + if "TypeError" in exception.output and "object is not iterable" in exception.output: + raise exception + + def test_get_operand_type_mismatch_perturb_expression(self): + pass + + def test_get_undef_name_perturb_expression(self): + pass \ No newline at end of file From 2ad297e7a7beb89fef9577d882fd946db44f55b6 Mon Sep 17 00:00:00 2001 From: Rishab Goel Date: Fri, 25 Jun 2021 11:01:55 -0400 Subject: [PATCH 6/6] re-formatted the code --- error_generation/add_code.py | 23 ++-- error_generation/error_expression_factory.py | 104 ++++++++++--------- error_generation/main.py | 12 ++- error_generation/misc_utils.py | 27 +++-- error_generation/trace_code.py | 2 +- tests/test_err_expr_factory.py | 104 +++++++++++++------ 6 files changed, 172 insertions(+), 100 deletions(-) diff --git a/error_generation/add_code.py b/error_generation/add_code.py index 0ce74bb5..81fe7c1e 100644 --- a/error_generation/add_code.py +++ b/error_generation/add_code.py @@ -41,12 +41,15 @@ def get_perturb_line_step(code_trace_org, err_suffx): perturb_val = perturb_step[perturb_var] return int(perturb_line), perturb_var, perturb_val + def perturb_program(red, code_trace, err_suffx, error_expr_factory_obj): - perturb_line, perturb_var, perturb_val = get_perturb_line_step(code_trace, err_suffx) + perturb_line, perturb_var, perturb_val = get_perturb_line_step( + code_trace, err_suffx + ) if perturb_line is None: return 0 - perturb_expression, is_err_present = error_expr_factory_obj.add_err( err_suffx, - perturb_var, perturb_val + perturb_expression, is_err_present = error_expr_factory_obj.add_err( + err_suffx, perturb_var, perturb_val ) # TODO(rishab): Need to be careful to ensure that that the insertion # line is not an AssignmentNode in RedBaron. @@ -55,14 +58,16 @@ def perturb_program(red, code_trace, err_suffx, error_expr_factory_obj): # can be called. I am not sure if we can just add the expression # without proper imports. import_statement, perturb_expression = perturb_expression.split(";") - red.at(perturb_line).insert_before(import_statement, offset=perturb_line-1) - red.at(perturb_line+1).insert_after(perturb_expression) + red.at(perturb_line).insert_before(import_statement, offset=perturb_line - 1) + red.at(perturb_line + 1).insert_after(perturb_expression) else: red.at(perturb_line).insert_after(perturb_expression) return is_err_present -def add_error(org_code_fp, code_trace_fp, err_code_fp, err_suffx, error_expr_factory_obj): +def add_error( + org_code_fp, code_trace_fp, err_code_fp, err_suffx, error_expr_factory_obj +): # We can optimize the code by passing the read file. # But for now to ensure isolation, I am doing it # explicitly. @@ -74,7 +79,9 @@ def add_error(org_code_fp, code_trace_fp, err_code_fp, err_suffx, error_expr_fac program = load_data(org_code_fp).strip() red = rb.RedBaron(program) try: - is_err_present = perturb_program(red, code_trace, err_suffx, error_expr_factory_obj) + is_err_present = perturb_program( + red, code_trace, err_suffx, error_expr_factory_obj + ) err_code_fp = err_code_fp.replace(".txt", "-" + str(is_err_present) + ".txt") except Exception as e: # We can handle the exception as we want. @@ -83,4 +90,4 @@ def add_error(org_code_fp, code_trace_fp, err_code_fp, err_suffx, error_expr_fac return False write_csv(red.dumps(), err_code_fp) - return True \ No newline at end of file + return True diff --git a/error_generation/error_expression_factory.py b/error_generation/error_expression_factory.py index 27628cb3..f4aed201 100644 --- a/error_generation/error_expression_factory.py +++ b/error_generation/error_expression_factory.py @@ -22,28 +22,30 @@ get_random_float, ) + class ErrorFactory: """TODO (rishab): - 1. implement methods for var name not defined and - operand mismatch. - 2. make the expressions more complex. + 1. implement methods for var name not defined and + operand mismatch. + 2. make the expressions more complex. """ + VARIABLE_NAMES = ["tmp{}".format(i) for i in range(10)] + [ chr(ord("a") + i) for i in range(26) ] NUM_RANGE = [1, 100] - + def __init__(self): self._builders = { - "zero_err": self.get_zero_perturb_expression, - "assert_err": self.get_assert_perturb_expression, - "not_subscriptable_err": self.get_not_subscriptable_perturb_expression, - "idx_out_range_err": self.get_index_range_perturb_expression, - "undef_var_err": self.get_undef_name_perturb_expression, # Caution: not implemented properly. - "math_domain_err": self.get_math_domain_perturb_expression, - "not_iterable_err": self.get_int_not_iterable_perturb_expression + "zero_err": self.get_zero_perturb_expression, + "assert_err": self.get_assert_perturb_expression, + "not_subscriptable_err": self.get_not_subscriptable_perturb_expression, + "idx_out_range_err": self.get_index_range_perturb_expression, + "undef_var_err": self.get_undef_name_perturb_expression, # Caution: not implemented properly. + "math_domain_err": self.get_math_domain_perturb_expression, + "not_iterable_err": self.get_int_not_iterable_perturb_expression, } - + def get_zero_perturb_expression(self, perturb_var, perturb_val): assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] is_zerro_err = get_random_int(0, 1) @@ -58,7 +60,7 @@ def get_zero_perturb_expression(self, perturb_var, perturb_val): + str(perturb_val) + "-" + perturb_var - + ')', + + ")", is_zerro_err, ) else: @@ -72,9 +74,10 @@ def get_zero_perturb_expression(self, perturb_var, perturb_val): + str(perturb_val) + "-" + perturb_var - + ')', + + ")", is_zerro_err, ) + def get_assert_perturb_expression(self, perturb_var, perturb_val): is_assert_err = get_random_int(0, 1) # is_assert_err = 1 @@ -82,31 +85,22 @@ def get_assert_perturb_expression(self, perturb_var, perturb_val): perturb_val_offset = get_random_float(*self.NUM_RANGE, size=1)[0] perturb_val = perturb_val + int(perturb_val_offset) return ( - "assert " - + perturb_var - + "==" - + str(perturb_val), + "assert " + perturb_var + "==" + str(perturb_val), is_assert_err, ) else: return ( - "assert " - + perturb_var - + "==" - + str(perturb_val), + "assert " + perturb_var + "==" + str(perturb_val), is_assert_err, ) + def get_not_subscriptable_perturb_expression(self, perturb_var, perturb_val): is_not_subscriptable_err = get_random_int(0, 1) # is_not_subscriptable_err = 1 if is_not_subscriptable_err: random_val, numerator = get_random_float(*self.NUM_RANGE, size=2) return ( - perturb_var - + "[" - + str(int(numerator)) - + "] = " - + str(int(random_val)), + perturb_var + "[" + str(int(numerator)) + "] = " + str(int(random_val)), is_not_subscriptable_err, ) else: @@ -132,14 +126,10 @@ def get_index_range_perturb_expression(self, perturb_var, perturb_val): is_index_range_err, ) else: - valid_idx = int(get_random_float(*[0, len(perturb_val)-1], size=1)[0]) + valid_idx = int(get_random_float(*[0, len(perturb_val) - 1], size=1)[0]) random_ass = get_random_float(*self.NUM_RANGE, size=1)[0] return ( - perturb_var - + "[" - + str(valid_idx) - + "] = " - + str(random_ass), + perturb_var + "[" + str(valid_idx) + "] = " + str(random_ass), is_index_range_err, ) @@ -150,11 +140,7 @@ def get_undef_name_perturb_expression(self, perturb_var, perturb_val): if is_undef_name_err: undef_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] return ( - perturb_var - + "=" - + undef_var - + "+" - + str(perturb_val), + perturb_var + "=" + undef_var + "+" + str(perturb_val), is_undef_name_err, ) else: @@ -171,30 +157,51 @@ def get_math_domain_perturb_expression(self, perturb_var, perturb_val): # is_math_domain_err = 1 if is_math_domain_err: assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] - if perturb_val>=0: - random_ass = str(-1*int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + if perturb_val >= 0: + random_ass = ( + str(-1 * int(get_random_float(*self.NUM_RANGE, size=1)[0])) + + "*" + + perturb_var + ) else: - random_ass = str(int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + random_ass = ( + str(int(get_random_float(*self.NUM_RANGE, size=1)[0])) + + "*" + + perturb_var + ) return ( "import math;" + assign_var + "=" - + "math.sqrt("+str(random_ass)+")", + + "math.sqrt(" + + str(random_ass) + + ")", is_math_domain_err, ) else: assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0] - if perturb_val>=0: - random_ass = str(int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + if perturb_val >= 0: + random_ass = ( + str(int(get_random_float(*self.NUM_RANGE, size=1)[0])) + + "*" + + perturb_var + ) else: - random_ass = str(-1*int(get_random_float(*self.NUM_RANGE, size=1)[0])) + "*" + perturb_var + random_ass = ( + str(-1 * int(get_random_float(*self.NUM_RANGE, size=1)[0])) + + "*" + + perturb_var + ) return ( "import math;" + assign_var + "=" - + "math.sqrt("+str(random_ass)+")", + + "math.sqrt(" + + str(random_ass) + + ")", is_math_domain_err, ) + def _relevant_operand_val_type(self, val, is_same): pass @@ -203,7 +210,7 @@ def get_operand_type_mismatch_perturb_expression(self, perturb_var, perturb_val) def get_int_not_iterable_perturb_expression(self, perturb_var, perturb_val): """TODO: 1. Add more variants of the for loop. - 2. Add logic to include the while loop. + 2. Add logic to include the while loop. """ is_int_not_iterable_err = get_random_int(0, 1) # is_int_not_iterable_err = 1 @@ -212,7 +219,7 @@ def get_int_not_iterable_perturb_expression(self, perturb_var, perturb_val): random_ass = int(get_random_float(*self.NUM_RANGE, size=1)[0]) return ( "{}=[{}+val for val in {}]".format(assign_var, random_ass, perturb_var), - is_int_not_iterable_err + is_int_not_iterable_err, ) else: return "", is_int_not_iterable_err @@ -222,4 +229,3 @@ def add_err(self, err_type, perturb_var, perturb_val): if not expr_builder: raise ValueError(err_type + " is not a valid error generation function.") return expr_builder(perturb_var, perturb_val) - \ No newline at end of file diff --git a/error_generation/main.py b/error_generation/main.py index 016e89eb..90adecbf 100644 --- a/error_generation/main.py +++ b/error_generation/main.py @@ -27,6 +27,7 @@ python -m error_generation.main """ + def main(config_fp): config = load_yaml(config_fp) set_seeds() @@ -66,9 +67,16 @@ def main(config_fp): ) for err_suffix in config["errors"]: # import pdb;pdb.set_trace() - _ = add_error(code_path, data_trace_path, out_code_path, err_suffix, error_expr_factory_obj) + _ = add_error( + code_path, + data_trace_path, + out_code_path, + err_suffix, + error_expr_factory_obj, + ) # break # break -if __name__ == '__main__': + +if __name__ == "__main__": main("error_generation/config.yaml") diff --git a/error_generation/misc_utils.py b/error_generation/misc_utils.py index 3a01ac64..7b899d32 100644 --- a/error_generation/misc_utils.py +++ b/error_generation/misc_utils.py @@ -21,14 +21,15 @@ from collections import defaultdict ALLOWED_TYPE = { - "zero_err": [int, float], - "assert_err": [int, float], - "not_subscriptable_err": [int, float], - "idx_out_range_err": [list], - "undef_var_err": [int, float, str, list], - "math_domain_err": [int, float], - "not_iterable_err": [int, float] - } + "zero_err": [int, float], + "assert_err": [int, float], + "not_subscriptable_err": [int, float], + "idx_out_range_err": [list], + "undef_var_err": [int, float, str, list], + "math_domain_err": [int, float], + "not_iterable_err": [int, float], +} + def get_codeforces_inp_data_paths(base_path): inp_data_directory = os.path.join(base_path, "samples") @@ -90,21 +91,27 @@ def get_codeforeces_paths(base_path): data_paths = [sol_path for path in data_paths for sol_path in path] return data_paths + def is_valid_type(val, type): return isinstance(val, type) + def get_valid_code_trace(code_trace, err_suffx): code_trace_filtered = defaultdict(list) for line in code_trace: for step_idx in range(len(code_trace[line])): new_step_dict = {} for var in code_trace[line][step_idx]: - if any(is_valid_type(code_trace[line][step_idx][var], typ) for typ in ALLOWED_TYPE[err_suffx]): + if any( + is_valid_type(code_trace[line][step_idx][var], typ) + for typ in ALLOWED_TYPE[err_suffx] + ): new_step_dict[var] = code_trace[line][step_idx][var] if new_step_dict: code_trace_filtered[line].append(new_step_dict) return code_trace_filtered + def set_seeds(seed=10): random.seed(seed) np.random.seed(seed) @@ -120,10 +127,12 @@ def load_json(fp): with open(fp, "r") as file: return json.load(file) + def load_yaml(fp): with open(fp, "r") as file: return yaml.load(file, Loader=yaml.FullLoader) + def write_csv(data, fp): with open(fp, "w") as file: file.write(data) diff --git a/error_generation/trace_code.py b/error_generation/trace_code.py index 129b749a..11c31e79 100644 --- a/error_generation/trace_code.py +++ b/error_generation/trace_code.py @@ -1,4 +1,4 @@ -# Get line by line trace for variables in a program. +# Get line by line trace for variables in a program. # It is adapted from the examples in the following link: # https://pymotw.com/2/sys/tracing.html diff --git a/tests/test_err_expr_factory.py b/tests/test_err_expr_factory.py index 9e489aaa..4205e616 100644 --- a/tests/test_err_expr_factory.py +++ b/tests/test_err_expr_factory.py @@ -2,29 +2,37 @@ import subprocess from error_generation.error_expression_factory import ErrorFactory -class TestErrorFactory(): + +class TestErrorFactory: """The test suite is not complete but will test some obvious - issues in code.""" + issues in code.""" + def setup(self): self.error_factory = ErrorFactory() self.test_var_name = "test_var" self.test_var_val = 23 self.test_lst_var_name = "test_lst_var" - self.test_lst_var_val = [1,2,3,4] + self.test_lst_var_val = [1, 2, 3, 4] self.test_var_assign = self.test_var_name + "=" + str(self.test_var_val) + "\n" - self.test_lst_var_assign = self.test_lst_var_name + "=" + str(self.test_lst_var_val) + "\n" + self.test_lst_var_assign = ( + self.test_lst_var_name + "=" + str(self.test_lst_var_val) + "\n" + ) def test_get_zero_perturb_expression(self): - - expr, is_err_present = self.error_factory.add_err("zero_err", self.test_var_name, self.test_var_val) + + expr, is_err_present = self.error_factory.add_err( + "zero_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr - + while not is_err_present: subprocess_call = subprocess.call( ["python", "-c", expr], stderr=subprocess.PIPE ) assert subprocess_call == 0 - expr, is_err_present = self.error_factory.add_err("zero_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "zero_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr with pytest.raises(subprocess.CalledProcessError) as exc: @@ -35,16 +43,20 @@ def test_get_zero_perturb_expression(self): except subprocess.CalledProcessError as exception: if "ZeroDivisionError" in exception.output: raise exception - + def test_get_assert_perturb_expression(self): - expr, is_err_present = self.error_factory.add_err("assert_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "assert_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr while not is_err_present: subprocess_call = subprocess.call( ["python", "-c", expr], stderr=subprocess.PIPE ) assert subprocess_call == 0 - expr, is_err_present = self.error_factory.add_err("assert_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "assert_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr with pytest.raises(subprocess.CalledProcessError) as exc: try: @@ -54,16 +66,20 @@ def test_get_assert_perturb_expression(self): except subprocess.CalledProcessError as exception: if "AssertionError" in exception.output: raise exception - + def test_get_not_subscriptable_perturb_expression(self): - expr, is_err_present = self.error_factory.add_err("not_subscriptable_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "not_subscriptable_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr while not is_err_present: subprocess_call = subprocess.call( ["python", "-c", expr], stderr=subprocess.PIPE ) assert subprocess_call == 0 - expr, is_err_present = self.error_factory.add_err("not_subscriptable_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "not_subscriptable_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr with pytest.raises(subprocess.CalledProcessError) as exc: try: @@ -71,22 +87,31 @@ def test_get_not_subscriptable_perturb_expression(self): ["python", "-c", expr], stderr=subprocess.STDOUT ) except subprocess.CalledProcessError as exception: - if "TypeError" in exception.output and "not support item assignment" in exception.output: + if ( + "TypeError" in exception.output + and "not support item assignment" in exception.output + ): raise exception - + def test_get_index_range_perturb_expression(self): - expr, is_err_present = self.error_factory.add_err("idx_out_range_err", self.test_lst_var_name, self.test_lst_var_val) + expr, is_err_present = self.error_factory.add_err( + "idx_out_range_err", self.test_lst_var_name, self.test_lst_var_val + ) expr = self.test_lst_var_assign + expr while not is_err_present: subprocess_call = subprocess.call( ["python", "-c", expr], stderr=subprocess.PIPE ) - if subprocess_call!=0: - import pdb;pdb.set_trace() + if subprocess_call != 0: + import pdb + + pdb.set_trace() assert subprocess_call == 0 - expr, is_err_present = self.error_factory.add_err("idx_out_range_err", self.test_lst_var_name, self.test_lst_var_val) + expr, is_err_present = self.error_factory.add_err( + "idx_out_range_err", self.test_lst_var_name, self.test_lst_var_val + ) expr = self.test_lst_var_assign + expr with pytest.raises(subprocess.CalledProcessError) as exc: @@ -95,11 +120,16 @@ def test_get_index_range_perturb_expression(self): ["python", "-c", expr], stderr=subprocess.STDOUT ) except subprocess.CalledProcessError as exception: - if "IndexError" in exception.output and "index out of range" in exception.output: + if ( + "IndexError" in exception.output + and "index out of range" in exception.output + ): raise exception - + def test_get_math_domain_perturb_expression(self): - expr, is_err_present = self.error_factory.add_err("math_domain_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "math_domain_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr while not is_err_present: @@ -107,20 +137,27 @@ def test_get_math_domain_perturb_expression(self): ["python", "-c", expr], stderr=subprocess.PIPE ) assert subprocess_call == 0 - expr, is_err_present = self.error_factory.add_err("math_domain_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "math_domain_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr - + with pytest.raises(subprocess.CalledProcessError) as exc: try: subprocess_call = subprocess.check_output( ["python", "-c", expr], stderr=subprocess.STDOUT ) except subprocess.CalledProcessError as exception: - if "ValueError" in exception.output and "math domain error" in exception.output: + if ( + "ValueError" in exception.output + and "math domain error" in exception.output + ): raise exception - + def test_get_int_not_iterable_perturb_expression(self): - expr, is_err_present = self.error_factory.add_err("not_iterable_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "not_iterable_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr while not is_err_present: @@ -128,7 +165,9 @@ def test_get_int_not_iterable_perturb_expression(self): ["python", "-c", expr], stderr=subprocess.PIPE ) assert subprocess_call == 0 - expr, is_err_present = self.error_factory.add_err("not_iterable_err", self.test_var_name, self.test_var_val) + expr, is_err_present = self.error_factory.add_err( + "not_iterable_err", self.test_var_name, self.test_var_val + ) expr = self.test_var_assign + expr with pytest.raises(subprocess.CalledProcessError) as exc: @@ -137,11 +176,14 @@ def test_get_int_not_iterable_perturb_expression(self): ["python", "-c", expr], stderr=subprocess.STDOUT ) except subprocess.CalledProcessError as exception: - if "TypeError" in exception.output and "object is not iterable" in exception.output: + if ( + "TypeError" in exception.output + and "object is not iterable" in exception.output + ): raise exception def test_get_operand_type_mismatch_perturb_expression(self): pass def test_get_undef_name_perturb_expression(self): - pass \ No newline at end of file + pass