@@ -30,6 +30,9 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
30
30
parser = argparse .ArgumentParser ("Gather results, plot them and create table" )
31
31
parser .add_argument ("-i" , "--input" , help = "Input filename (numpy archive)" , type = str )
32
32
parser .add_argument ("-skip" , "--skip-envs" , help = "Environments to skip" , nargs = "+" , default = [], type = str )
33
+ parser .add_argument ("--keep-envs" , help = "Envs to keep" , nargs = "+" , default = [], type = str )
34
+ parser .add_argument ("--skip-keys" , help = "Keys to skip" , nargs = "+" , default = [], type = str )
35
+ parser .add_argument ("--keep-keys" , help = "Keys to keep" , nargs = "+" , default = [], type = str )
33
36
parser .add_argument ("--no-million" , action = "store_true" , default = False , help = "Do not convert x-axis to million" )
34
37
parser .add_argument ("--skip-timesteps" , action = "store_true" , default = False , help = "Do not display learning curves" )
35
38
parser .add_argument ("-o" , "--output" , help = "Output filename (image)" , type = str )
@@ -40,6 +43,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
40
43
parser .add_argument ("-l" , "--labels" , help = "Custom labels" , type = str , nargs = "+" )
41
44
parser .add_argument ("-b" , "--boxplot" , help = "Enable boxplot" , action = "store_true" , default = False )
42
45
parser .add_argument ("-latex" , "--latex" , help = "Enable latex support" , action = "store_true" , default = False )
46
+ parser .add_argument ("--merge" , help = "Merge with other results files" , nargs = "+" , default = [], type = str )
43
47
44
48
args = parser .parse_args ()
45
49
@@ -69,8 +73,26 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
69
73
70
74
del results ["results_table" ]
71
75
72
- keys = [key for key in results [list (results .keys ())[0 ]].keys ()]
76
+ for filename in args .merge :
77
+ # Merge other files
78
+ with open (filename , "rb" ) as file_handler :
79
+ results_2 = pickle .load (file_handler )
80
+ del results_2 ["results_table" ]
81
+ for key in results .keys ():
82
+ if key in results_2 :
83
+ for new_key in results_2 [key ].keys ():
84
+ results [key ][new_key ] = results_2 [key ][new_key ]
85
+
86
+
87
+ keys = [key for key in results [list (results .keys ())[0 ]].keys () if key not in args .skip_keys ]
88
+ print (f"keys: { keys } " )
89
+ if len (args .keep_keys ) > 0 :
90
+ keys = [key for key in keys if key in args .keep_keys ]
73
91
envs = [env for env in results .keys () if env not in args .skip_envs ]
92
+
93
+ if len (args .keep_envs ) > 0 :
94
+ envs = [env for env in envs if env in args .keep_envs ]
95
+
74
96
labels = {key : key for key in keys }
75
97
if args .labels is not None :
76
98
for key , label in zip (keys , args .labels ):
@@ -129,9 +151,10 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
129
151
# plt.title('Influence of the time feature', fontsize=args.fontsize)
130
152
# plt.title('Influence of the network architecture', fontsize=args.fontsize)
131
153
# plt.title('Influence of the exploration variance $log \sigma$', fontsize=args.fontsize)
132
- plt .title ("Influence of the sampling frequency" , fontsize = args .fontsize )
154
+ # plt.title("Influence of the sampling frequency", fontsize=args.fontsize)
133
155
# plt.title('Parallel vs No Parallel Sampling', fontsize=args.fontsize)
134
156
# plt.title('Influence of the exploration function input', fontsize=args.fontsize)
157
+ plt .title ("PyBullet envs" , fontsize = args .fontsize )
135
158
plt .xticks (fontsize = 13 )
136
159
plt .xlabel ("Environment" , fontsize = args .fontsize )
137
160
plt .ylabel ("Score" , fontsize = args .fontsize )
0 commit comments