From 281ba22bea8503c87a418b76e667ab817c1718fb Mon Sep 17 00:00:00 2001
From: Karishnu Poddar
Date: Fri, 20 Dec 2019 20:25:41 +0530
Subject: [PATCH] Change imports in tf-agents example to support updated
 version

---
 tf-agents-example/gridworld.py | 4 ++--
 tf-agents-example/simulate.py  | 9 +++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tf-agents-example/gridworld.py b/tf-agents-example/gridworld.py
index e3c954a..decf4c5 100644
--- a/tf-agents-example/gridworld.py
+++ b/tf-agents-example/gridworld.py
@@ -6,9 +6,9 @@
 from tf_agents.environments import tf_environment
 from tf_agents.environments import tf_py_environment
 from tf_agents.environments import utils
-from tf_agents.specs import array_spec
-from tf_agents.environments import time_step as ts
 from tf_agents.environments import wrappers
+from tf_agents.specs import array_spec
+from tf_agents.trajectories import time_step as ts
 
 tf.compat.v1.enable_v2_behavior()
 
diff --git a/tf-agents-example/simulate.py b/tf-agents-example/simulate.py
index a9dd713..76573d2 100644
--- a/tf-agents-example/simulate.py
+++ b/tf-agents-example/simulate.py
@@ -1,20 +1,21 @@
 import tensorflow as tf
 
 from tf_agents.agents.dqn import dqn_agent
-from tf_agents.agents.dqn import q_network
+from tf_agents.networks import q_network
 from tf_agents.drivers import dynamic_step_driver
 from tf_agents.environments import tf_py_environment
-from tf_agents.environments import trajectory
 from tf_agents.environments import wrappers
-from tf_agents.metrics import metric_utils
+from tf_agents.eval import metric_utils
 from tf_agents.metrics import tf_metrics
 from tf_agents.policies import random_tf_policy
 from tf_agents.replay_buffers import tf_uniform_replay_buffer
+from tf_agents.trajectories import trajectory
 from tf_agents.utils import common
 from tf_agents.metrics import py_metrics
 from tf_agents.metrics import tf_metrics
 from tf_agents.drivers import py_driver
 from tf_agents.drivers import dynamic_episode_driver
+from tf_agents.utils import common
 
 from gridworld import GridWorldEnv
 import matplotlib.pyplot as plt
@@ -71,7 +72,7 @@ def compute_avg_return(environment, policy, num_episodes=10):
     train_env.action_spec(),
     q_network=q_net,
     optimizer=optimizer,
-    td_errors_loss_fn = dqn_agent.element_wise_squared_loss,
+    td_errors_loss_fn = common.element_wise_squared_loss,
     train_step_counter=train_step_counter)
 
 tf_agent.initialize()