-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
55 lines (45 loc) · 1.76 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import gym
import numpy as np
import random
import torch
import rollout
def evaluate_agent(agent, env, num_episodes):
    """Evaluate an agent over a number of episodes.

    Args:
        agent: the agent to evaluate
        env: the environment to evaluate on
        num_episodes: the number of episodes to evaluate for

    Returns:
        reward_avg: average cumulative reward seen in the episodes
    """
    agent.eval()
    if hasattr(env, "rollout"):
        # Environment supports batched rollouts: evaluate all episodes in one call.
        episode_returns = env.rollout(agent, None, num_episodes, env.env._max_episode_steps)
    else:
        # Fall back to rolling out one episode at a time; rollout_episode's
        # third element is the episode's cumulative return.
        episode_returns = np.asarray(
            [rollout.rollout_episode(env, agent)[2] for _ in range(num_episodes)]
        )
    return np.mean(episode_returns)
def evaluate_model(model, dataset, fns_eval, device):
    """Evaluate a model on an entire dataset in one batch.

    Args:
        model: the model to evaluate
        dataset: the dataset to evaluate on
        fns_eval: list of evaluation functions
        device: the device to use

    Returns:
        scores: list of the evaluation scores (one for every evaluation function)
    """
    model.eval()
    # One batch spanning the whole dataset: a single forward pass covers everything.
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=len(dataset))
    with torch.no_grad():
        state, action, reward, state_next, _ = next(iter(loader))
        # NOTE(review): action[0] — presumably the action field collates to a
        # list of batched tensors and only the first component feeds the model;
        # confirm against the dataset's __getitem__ layout.
        inputs = torch.cat((state, action[0]), dim=-1).to(device)
        pred_means, pred_stds = model(inputs)
        # Target is reward alongside the state delta, scaled like training targets.
        targets = torch.cat((reward.unsqueeze(dim=1), state_next - state), dim=-1).to(device)
        targets = model.scaler_y.transform(targets)
        scores = []
        for fn_eval in fns_eval:
            scores.append(
                fn_eval(
                    pred_means,
                    pred_stds,
                    targets,
                    state=state,
                    action=action,
                    temperature=model.temperature,
                )
            )
    return scores