@@ -73,9 +73,12 @@ def init_argparse() -> argparse.ArgumentParser:
73
73
parser .add_argument ("-p" , "--plot" , action = argparse .BooleanOptionalAction )
74
74
parser .add_argument ('-w' , '--weights' , type = str , help = 'a trained model weights' )
75
75
parser .add_argument ("-d" , "--days" , type = int , default = 365 )
76
+ parser .add_argument ("-ev" , "--evaluate" , action = argparse .BooleanOptionalAction )
77
+ parser .add_argument ("-ep" , "--epochs" , type = int , default = 20 )
76
78
return parser
77
79
78
80
81
+
79
82
def main ():
80
83
parser = init_argparse ()
81
84
args = parser .parse_args ()
@@ -90,7 +93,7 @@ def main():
90
93
agent = obs .DQNAgent (action_size )
91
94
92
95
logdir = "tensorboard_logs/scalars/" + datetime .now ().strftime ("%Y%m%d-%H%M%S" )
93
- tensorboard_callback = TensorBoard (log_dir = logdir )
96
+ tensorboard_callback = TensorBoard (log_dir = logdir , histogram_freq = 1 , write_images = False , batch_size = args . batch_size )
94
97
95
98
if args .weights :
96
99
print (f"Loading model { args .weights } ..." )
@@ -107,7 +110,9 @@ def main():
107
110
108
111
if args .train and len (agent .memory ) > args .batch_size :
109
112
print ("Starting replay..." )
110
- agent .replay (args .batch_size , tensorboard_callback )
113
+ score = agent .replay (args .batch_size , args .epochs , args .evaluate , tensorboard_callback )
114
+ if args .evaluate :
115
+ print (f"Score = { score } " )
111
116
112
117
if args .train and (episode + 1 ) % 10 == 0 : # checkpoint weights
113
118
print ("Saving..." )
0 commit comments