-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluation.py
29 lines (21 loc) · 860 Bytes
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import Dict
import flax.linen as nn
import gym
import numpy as np
def _env_family(env: 'gym.Env') -> str:
    """Return the leading dash-separated token of the environment's name.

    For example a name of ``'antmaze-umaze-v0'`` (or ``'antmaze-umaze'``)
    yields ``'antmaze'``. ``spec._env_name`` is a private attribute that not
    every gym release provides, so this falls back to ``spec.name`` and then
    ``spec.id``. Returns ``''`` when the env has no spec at all.
    """
    spec = getattr(env, 'spec', None)
    if spec is None:
        return ''
    name = (getattr(spec, '_env_name', None)
            or getattr(spec, 'name', None)
            or getattr(spec, 'id', None)
            or '')
    return name.split('-')[0]


def evaluate(agent: 'nn.Module', env: 'gym.Env',
             num_episodes: int, offline: bool, verbose: bool = False) -> Dict[str, float]:
    """Roll out ``agent`` in ``env`` and return averaged episode statistics.

    Args:
        agent: Policy object exposing
            ``sample_actions(observation, temperature=..., offline=...)``.
        env: Environment using the 4-tuple ``step`` API
            ``(observation, reward, done, info)``. The ``info`` returned on the
            final step of each episode must contain
            ``info['episode']['return']`` and ``info['episode']['length']``
            (presumably populated by a monitoring wrapper -- verify against
            the caller).
        num_episodes: Number of episodes to roll out.
        offline: Forwarded unchanged to ``agent.sample_actions``.
        verbose: If True, print each episode's return and length.

    Returns:
        ``{'return': mean_return, 'length': mean_length}``; both values are
        NaN when ``num_episodes`` is 0.
    """
    stats = {'return': [], 'length': []}
    # Antmaze environments are evaluated with temperature 1.0, all others 0.0.
    temperature = 1. if _env_family(env) == 'antmaze' else 0.

    for _ in range(num_episodes):
        observation, done = env.reset(), False
        while not done:
            action = agent.sample_actions(observation, temperature=temperature, offline=offline)
            observation, _, done, info = env.step(action)

        # The loop exits on the step where done became True; read the
        # episode totals from that step's info.
        episode = info['episode']
        for k in stats:
            stats[k].append(episode[k])
            if verbose:
                print(f'{k}:{episode[k]}')

    # np.mean([]) warns and yields NaN; produce NaN explicitly instead.
    return {k: (np.mean(v) if v else float('nan')) for k, v in stats.items()}