Skip to content

Commit

Permalink
Evaluate 1 rollout per experiment call (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsavarela committed Mar 1, 2020
1 parent 83ea365 commit 7940650
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 66 deletions.
25 changes: 7 additions & 18 deletions ilurl/core/experiment.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,16 @@ class can generate csv files from emission files produced by sumo. These
the environment object the simulator will run
"""

def __init__(self, env, dir_path=EMISSION_PATH, train=True, policies=None):
def __init__(self, env, dir_path=EMISSION_PATH, train=True):
"""Instantiate Experiment."""
if not train and policies is None:
raise ValueError(
f"In validation mode an array of policies must be provided"
)

sim_step = env.sim_params.sim_step
# guarantees that the environment has stopped
if not train:
env.stop = True

self.env = env
self.train = train
self.dir_path = dir_path
self.Qs = policies
# fails gracefully if an environment with no cycle time
# is provided
self.cycle = getattr(env, 'cycle_time', None)
Expand Down Expand Up @@ -194,7 +192,7 @@ def rl_actions(*_):
break

# for every 100 decisions -- save Q
if self._is_save_q_table():
if self._is_save_q_table_step():
n = int(j / self.save_step) + 1
filename = \
f'{self.env.network.name}.Q.{i + 1}-{n}.pickle'
Expand All @@ -203,10 +201,6 @@ def rl_actions(*_):
filename,
attr_name='Q')

elif self._is_swap_q_table():
if i < len(self.Qs):
self.env.Q = self.Qs[i]

vels.append(vel_list)
vehs.append(veh_list)
observation_spaces.append(obs_list)
Expand Down Expand Up @@ -257,12 +251,7 @@ def _is_save_step(self):
return self.env.duration == 0.0
return self.step_counter % self.save_step == 0

def _is_save_q_table(self):
def _is_save_q_table_step(self):
if self.env.step_counter % (100 * self.save_step) == 0:
return self.train and hasattr(self.env, 'dump') and self.dir_path
return False

def _is_swap_q_table(self):
if self.env.step_counter % (100 * self.save_step) == 0:
return not self.train
return False
235 changes: 187 additions & 48 deletions models/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
__author__ = 'Guilherme Varela'
__date__ = '2019-09-24'

import multiprocessing as mp
from copy import deepcopy
from collections import defaultdict, OrderedDict
import re
import os
from glob import glob
import json
import argparse
import pdb

import dill

Expand All @@ -32,6 +35,7 @@

TIMESTAMP_PROG = re.compile(r'[0-9]{8}\-[0-9]{8,}\.[0-9]+')


def get_arguments():
parser = argparse.ArgumentParser(
description="""
Expand All @@ -52,13 +56,6 @@ def get_arguments():
Q-table. This argument if provided takes the
precedence over time parameter''')

parser.add_argument('--time', '-t', dest='time', type=int,
default=None, nargs='?',
help=
'''Simulation\'s real world time in seconds
for a single rollout of a Q-table. Is ignored
unless `cycles` parameter is not provided.''')

parser.add_argument('--switch', '-w', dest='switch', type=str2bool,
default=False, nargs='?',
help=
Expand All @@ -67,8 +64,7 @@ def get_arguments():

parser.add_argument('--render', '-r', dest='render', type=str2bool,
default=False, nargs='?',
help='Renders the simulation')

help='Renders the simulation')
return parser.parse_args()


Expand All @@ -83,31 +79,76 @@ def str2bool(v):
raise argparse.ArgumentTypeError('Boolean value expected.')


def evaluate(env, code, q_tables, horizon, path, render=False):
# load all networks with config -- else build
results = []
for i, qtb in enumerate(q_tables):
#TODO: save network objects rather then routes
env1 = TrafficLightQLEnv(env.env_params,
env.sim_params,
env.ql_params,
env.network)
env1.sim_params.render = render
env1.stop = True
num_iterations = 1
eval = Experiment(
env1,
dir_path=path,
train=False,
q_tables=q_tables
)

print(f"Running evaluation {i + 1}")
result = eval.run(num_iterations, horizon)
results.append(result)

# TODO: concatenate results along axis=1
return results
def evaluate(env_params, sim_params, ql_params, network, horizon, qtb):
    """Run a single rollout of one Q-table on a freshly built environment.

    Params:
    -------
    * env_params: ilurl.core.params.EnvParams
        environment parameters forwarded to TrafficLightQLEnv
    * sim_params: ilurl.core.params.SumoParams
        simulator parameters forwarded to TrafficLightQLEnv
    * ql_params: ilurl.core.params.QLParams
        Q-learning parameters forwarded to TrafficLightQLEnv
    * network: ilurl.networks.Network
        network forwarded to TrafficLightQLEnv
    * horizon: int
        second argument forwarded to `Experiment.run`
        (presumably the number of simulation steps -- confirm)
    * qtb: object
        Q-table assigned to the environment's `Q` attribute

    Returns:
    --------
    * result: whatever `Experiment.run(1, horizon)` returns
        (presumably an info dict -- confirm against Experiment)
    """
    rollout_env = TrafficLightQLEnv(env_params, sim_params, ql_params, network)
    # install the Q-table under evaluation and flag the environment as stopped
    rollout_env.Q = qtb
    rollout_env.stop = True

    # validation mode; no output directory is handed to the experiment
    experiment = Experiment(rollout_env, dir_path=None, train=False)
    return experiment.run(1, horizon)


def concat(evaluations, horizon):
    """Merge the evaluations of several rollouts into a single record.

    The input dictionaries are only read, never modified, so the same
    evaluations can safely be concatenated more than once.

    Params:
    -------
    * evaluations: list
        list of rollout evaluations (dicts); every one of them must
        hold the same value under the `id` key
    * horizon: object
        stored unchanged under the `horizon` key of the result

    Returns:
    --------
    * result: dict (a `defaultdict(list)`)
        where `id` key certifies that experiments are the same
        `list` values are joined, in order, into one list
        any other value is appended to a list, one entry per evaluation
        `horizon` holds the `horizon` argument

    Raises:
    -------
    * KeyError: an evaluation has no `id` key
    * ValueError: evaluations carry different `id` values
    """
    result = defaultdict(list)
    for evl in evaluations:
        # look the id up instead of popping it: the caller keeps its data
        exp_id = evl['id']
        if 'id' not in result:
            result['id'] = exp_id
        elif result['id'] != exp_id:
            raise ValueError(
                'Tried to concatenate evaluations from different experiments')

        for k, v in evl.items():
            if k == 'id':
                continue
            if isinstance(v, list):
                # builds a new list, so input lists are never aliased
                result[k] = result[k] + v
            else:
                result[k].append(v)

    # set last so that it always reflects the argument
    result['horizon'] = horizon
    return result


def parse_all(paths):
Expand All @@ -116,7 +157,6 @@ def parse_all(paths):
data = parse(path)
# store argument condition
if data is not None:

if len(data) == 4:
key = data[1:]
key1 = None # nested key
Expand All @@ -127,7 +167,7 @@ def parse_all(paths):
raise ValueError(f'{data} not recognized')

if key1 is None:
parsed_dict[key]['source'] = path
parsed_dict[key]['env'] = path
else:
parsed_dict[key][key1] = path
return parsed_dict
Expand Down Expand Up @@ -168,9 +208,9 @@ def parse(x):
Usage:
------
> x = \
> x = \
'data/experiments/0x04/6030/' +
'intersection_20200227-1131341582803094.6042109' +
'intersection_20200227-1131341582803094.6042109' +
'.Q.1-8.pickle'
> y = parse(x)
> y[0]
Expand Down Expand Up @@ -205,38 +245,137 @@ def parse(x):
return source_dir, network_id, timestamp, iter, cycles, ext


def sort_all(data):
    """Sort the Q-table entries of every experiment.

    Params:
    ------
    * data: dict
        keys are experiment ids
        values are dicts holding an `env` entry plus the Q-table
        entries, keyed by tuples (<iter_num>,<cycles_num>)

    Returns:
    -------
    * defaultdict of OrderedDict
        one OrderedDict per experiment id: `env` comes first, followed
        by the remaining entries in the order given by `sort_tables`
    """
    ordered = defaultdict(OrderedDict)
    for exid, entries in data.items():
        env_entry = entries['env']
        tables = [(key, path) for key, path in entries.items() if key != 'env']

        experiment = ordered[exid]
        experiment['env'] = env_entry
        # `env` is already in place; the sorted tables are appended after it
        experiment.update(sort_tables(tables))
    return ordered


def sort_tables(qtbs):
    """Sort Q-tables by iteration number, then by cycles number.

    Params:
    ------
    * qtbs: dict or iterable of (key, value) pairs
        keys are tuples (<iter_num>,<cycles_num>)
        values are Q-tables

    Returns:
    -------
    * OrderedDict
        same entries ordered by <iter_num> first and <cycles_num>
        second; entries with equal keys keep their input order
    """
    # a mapping is iterated through its items, pairs are used as given
    pairs = qtbs.items() if hasattr(qtbs, 'items') else qtbs
    # one stable sort on (iter, cycles) replaces two successive sorts
    return OrderedDict(
        sorted(pairs, key=lambda pair: (pair[0][0], pair[0][1])))


def load_all(data):
    """Replace every path in `data` by the object stored at that path.

    Params:
    ------
    * data: dict
        keys are experiment ids
        values are dicts mapping `env` or a Q-table key to a file path

    Returns:
    -------
    * defaultdict of OrderedDict
        same layout as `data`: the `env` path is loaded with
        `TrafficLightQLEnv.load`, every other path is read with
        `dill.load`
    """
    loaded = defaultdict(OrderedDict)

    for exid, entries in data.items():
        for key, path in entries.items():
            if key == 'env':
                # traffic light object
                obj = TrafficLightQLEnv.load(path)
            else:
                # NOTE(review): unpickling can execute arbitrary code;
                # only point this at pickles from a trusted source
                with open(path, 'rb') as handle:
                    obj = dill.load(handle)
            loaded[exid][key] = obj
    return loaded


if __name__ == '__main__':
args = get_arguments()

dir_pickle = args.dir_pickle
x = 'w' if args.switch else 'l'
render = args.render
time = args.time
cycles = args.cycles

paths = glob(f"{dir_pickle}/*.pickle")
if not any(paths):
raise Exception("Environment pickles must be saved on root")

# process data
# converts paths into dictionary
data = parse_all(paths)
# sort experiments by instances and cycles
data = sort_all(data)
# converts paths into objects
data = load_all(data)
num_experiments = len(data)
# one element will be the environment
num_rollouts = sum([len(d) - 1 for d in data.values()])
i = 1
results = []
for exid, qtbs in data.items():
# for each experiment perform rollouts
# lexicographical sort
print(f"""
experiment:\t{i}/{num_experiments}
network_id:\t{exid[0]}
timestamp:\t{exid[1]}
rollouts:\t{len(qtbs)}
rollouts:\t{len(qtbs) - 1}/{num_rollouts}
""")
env = TrafficLightQLEnv.load(qtbs.pop('source'))
qtbs = sort_tables(qtbs.items())
results = \
evaluate(env, x, qtbs, cycles, dir_pickle, render=render)

# TODO: concatenate axis=1
# TODO: save
env = qtbs.pop('env')
cycle_time = getattr(env, 'cycle_time', 1)
env_params = env.env_params
sim_params = env.sim_params
sim_params.render = render

horizon = int((cycle_time * cycles) / sim_params.sim_step)
ql_params = env.ql_params
network = env.network

def fn(qtb):
ret = evaluate(
env_params,
sim_params,
ql_params,
network,
horizon,
qtb)
return ret

# TODO: to do enable multiprocessing
# pool = mp.Pool(3)
# results = pool.map(gn, qtbs.items())
# pool.close()

results = [fn(qtb) for qtb in qtbs.values()]
info = concat(results, horizon)
filename = '_'.join(exid[:2])
file_path = f'{dir_pickle}/{filename}.eval.json'
with open(file_path, 'w') as f:
json.dump(info, f)

0 comments on commit 7940650

Please sign in to comment.