diff --git a/examples/sb3_imitation.py b/examples/sb3_imitation.py
index 0da9c921..0411321d 100644
--- a/examples/sb3_imitation.py
+++ b/examples/sb3_imitation.py
@@ -221,27 +221,28 @@ def close_env():
     print("Starting RL Training:")
     learner.learn(args.rl_timesteps, progress_bar=True)
 
-except KeyboardInterrupt:
+except (KeyboardInterrupt, ConnectionError):
+    # ConnectionError also covers its subclasses, e.g. ConnectionResetError.
     print(
-        """Training interrupted by user. Will save if --save_model_path was
+        """Training interrupted by user or a ConnectionError. Will save if --save_model_path was
         used and/or export if --onnx_export_path was used."""
     )
-
-close_env()
-
-if args.eval_episode_count:
-    print("Evaluating:")
-    env = SBGSingleObsEnv(
-        env_path=args.env_path,
-        show_window=True,
-        seed=args.seed,
-        n_parallel=1,
-        speedup=args.speedup,
-    )
-    env = VecMonitor(env)
-    mean_reward, _ = evaluate_policy(learner, env, n_eval_episodes=args.eval_episode_count)
+finally:
     close_env()
-    print(f"Mean reward after evaluation: {mean_reward}")
 
-handle_onnx_export()
-handle_model_save()
+    if args.eval_episode_count:
+        print("Evaluating:")
+        env = SBGSingleObsEnv(
+            env_path=args.env_path,
+            show_window=True,
+            seed=args.seed,
+            n_parallel=1,
+            speedup=args.speedup,
+        )
+        env = VecMonitor(env)
+        mean_reward, _ = evaluate_policy(learner, env, n_eval_episodes=args.eval_episode_count)
+        close_env()
+        print(f"Mean reward after evaluation: {mean_reward}")
+
+    handle_onnx_export()
+    handle_model_save()