diff --git a/ilurl/core/experiment.py b/ilurl/core/experiment.py old mode 100755 new mode 100644 index f85ba68..3b9dc48 --- a/ilurl/core/experiment.py +++ b/ilurl/core/experiment.py @@ -80,18 +80,16 @@ class can generate csv files from emission files produced by sumo. These the environment object the simulator will run """ - def __init__(self, env, dir_path=EMISSION_PATH, train=True, policies=None): + def __init__(self, env, dir_path=EMISSION_PATH, train=True): """Instantiate Experiment.""" - if not train and policies is None: - raise ValueError( - f"In validation mode an array of policies must be provided" - ) - sim_step = env.sim_params.sim_step + # guarantees that the environment has stopped + if not train: + env.stop = True + self.env = env self.train = train self.dir_path = dir_path - self.Qs = policies # fails gracifully if an environment with no cycle time # is provided self.cycle = getattr(env, 'cycle_time', None) @@ -194,7 +192,7 @@ def rl_actions(*_): break # for every 100 decisions -- save Q - if self._is_save_q_table(): + if self._is_save_q_table_step(): n = int(j / self.save_step) + 1 filename = \ f'{self.env.network.name}.Q.{i + 1}-{n}.pickle' @@ -203,10 +201,6 @@ def rl_actions(*_): filename, attr_name='Q') - elif self._is_swap_q_table(): - if i < len(self.Qs): - self.env.Q = self.Qs[i] - vels.append(vel_list) vehs.append(veh_list) observation_spaces.append(obs_list) @@ -257,12 +251,7 @@ def _is_save_step(self): return self.env.duration == 0.0 return self.step_counter % self.save_step == 0 - def _is_save_q_table(self): + def _is_save_q_table_step(self): if self.env.step_counter % (100 * self.save_step) == 0: return self.train and hasattr(self.env, 'dump') and self.dir_path return False - - def _is_swap_q_table(self): - if self.env.step_counter % (100 * self.save_step) == 0: - return not self.train - return False diff --git a/models/evaluate.py b/models/evaluate.py index 619d477..64f549d 100644 --- a/models/evaluate.py +++ b/models/evaluate.py @@ -8,12 
+8,15 @@ __author__ = 'Guilherme Varela' __date__ = '2019-09-24' +import multiprocessing as mp +from copy import deepcopy from collections import defaultdict, OrderedDict import re import os from glob import glob import json import argparse +import pdb import dill @@ -32,6 +35,7 @@ TIMESTAMP_PROG = re.compile(r'[0-9]{8}\-[0-9]{8,}\.[0-9]+') + def get_arguments(): parser = argparse.ArgumentParser( description=""" @@ -52,13 +56,6 @@ def get_arguments(): Q-table. This argument if provided takes the precedence over time parameter''') - parser.add_argument('--time', '-t', dest='time', type=int, - default=None, nargs='?', - help= - '''Simulation\'s real world time in seconds - for a single rollout of a Q-table. Is ignored - unless `cycles` parameter is not provided.''') - parser.add_argument('--switch', '-w', dest='switch', type=str2bool, default=False, nargs='?', help= @@ -67,8 +64,7 @@ def get_arguments(): parser.add_argument('--render', '-r', dest='render', type=str2bool, default=False, nargs='?', - help='Renders the simulation') - + help='Renders the simulation') return parser.parse_args() @@ -83,31 +79,76 @@ def str2bool(v): raise argparse.ArgumentTypeError('Boolean value expected.') -def evaluate(env, code, q_tables, horizon, path, render=False): - # load all networks with config -- else build - results = [] - for i, qtb in enumerate(q_tables): - #TODO: save network objects rather then routes - env1 = TrafficLightQLEnv(env.env_params, - env.sim_params, - env.ql_params, - env.network) - env1.sim_params.render = render - env1.stop = True - num_iterations = 1 - eval = Experiment( - env1, - dir_path=path, - train=False, - q_tables=q_tables - ) - - print(f"Running evaluation {i + 1}") - result = eval.run(num_iterations, horizon) - results.append(result) - - # TODO: concatenate results along axis=1 - return results +def evaluate(env_params, sim_params, ql_params, network, horizon, qtb): + """Evaluate + + Params: + ------- + * env_params: ilurl.core.params.EnvParams + 
objects parameters + + * sim_params: ilurl.core.params.SumoParams + objects parameters + + * ql_params: ilurl.core.params.QLParams + objects parameters + + * network: ilurl.networks.Network + + Returns: + -------- + * info: dict + + """ + env1 = TrafficLightQLEnv( + env_params, + sim_params, + ql_params, + network + ) + env1.Q = qtb + env1.stop = True + eval = Experiment( + env1, + dir_path=None, + train=False, + ) + result = eval.run(1, horizon) + return result + + +def concat(evaluations, horizon): + """Receives an experiments' json and merges its contents + + Params: + ------- + * evaluations: list + list of rollout evaluations + + Returns: + -------- + * result: dict + where `id` key certifies that experiments are the same + `list` params are united + `numeric` params are appended + + """ + result = defaultdict(list) + for evl in evaluations: + id = evl.pop('id') + if 'id' not in result: + result['id'] = id + elif result['id'] != id: + raise ValueError( + 'Tried to concatenate evaluations from different experiments') + for k, v in evl.items(): + if k != 'id': + if isinstance(v, list): + result[k] = result[k] + v + else: + result[k].append(v) + result['horizon'] = horizon + return result def parse_all(paths): @@ -116,7 +157,6 @@ data = parse(path) # store argument condition if data is not None: - if len(data) == 4: key = data[1:] key1 = None # nested key @@ -127,7 +167,7 @@ raise ValueError(f'{data} not recognized') if key1 is None: - parsed_dict[key]['source'] = path + parsed_dict[key]['env'] = path else: parsed_dict[key][key1] = path return parsed_dict @@ -168,9 +208,9 @@ Usage: ------ - > x = \ + > x = \ 'data/experiments/0x04/6030/' + - 'intersection_20200227-1131341582803094.6042109' + + 'intersection_20200227-1131341582803094.6042109' + '.Q.1-8.pickle' > y = parse(x) > y[0] @@ -205,38 +245,137 @@ return source_dir, network_id, timestamp, iter, cycles, ext +def sort_all(data): + 
"""Performs sort across multiple experiments, within + each experiment + + Params: + ------ + * qtbs: dict + keys are tuples (,) + values are Q-tables + + Returns: + ------- + * OrderedDict + + """ + result = defaultdict(OrderedDict) + for exid, qtbs in data.items(): + result[exid]['env'] = qtbs['env'] + tmp = {k: v for k, v in qtbs.items() if k != 'env'} + tmp = sort_tables(tmp.items()) + for trial, path in tmp.items(): + result[exid][trial] = path + return result + + def sort_tables(qtbs): + """Sort Q-tables dictionary + + Params: + ------ + * qtbs: dict + keys are tuples (,) + values are Q-tables + + Returns: + ------- + * OrderedDict + + """ qtbs = sorted(qtbs, key=lambda x: x[0][1]) qtbs = sorted(qtbs, key=lambda x: x[0][0]) return OrderedDict(qtbs) +def load_all(data): + """Converts path variable into objects + + Params: + ------ + * qtbs: dict + keys are tuples (,) + values are Q-tables + + Returns: + ------- + * OrderedDict + + """ + result = defaultdict(OrderedDict) + + for exid, keys_paths in data.items(): + for key, path in keys_paths.items(): + # traffic light object + if key == 'env': + result[exid]['env'] = TrafficLightQLEnv.load(path) + else: + with open(path, 'rb') as f: + result[exid][key] = dill.load(f) + return result + + if __name__ == '__main__': args = get_arguments() dir_pickle = args.dir_pickle x = 'w' if args.switch else 'l' render = args.render - time = args.time cycles = args.cycles paths = glob(f"{dir_pickle}/*.pickle") if not any(paths): raise Exception("Environment pickles must be saved on root") + # process data + # converts paths into dictionary data = parse_all(paths) + # sort experiments by instances and cycles + data = sort_all(data) + # converts paths into objects + data = load_all(data) + num_experiments = len(data) + # one element will be the environment + num_rollouts = sum([len(d) - 1 for d in data.values()]) + i = 1 + results = [] for exid, qtbs in data.items(): - # for each experiment perform rowouts - # lexografical sort 
print(f""" + experiment:\t{i}/{num_experiments} network_id:\t{exid[0]} timestamp:\t{exid[1]} - rollouts:\t{len(qtbs)} + rollouts:\t{len(qtbs) - 1}/{num_rollouts} """) - env = TrafficLightQLEnv.load(qtbs.pop('source')) - qtbs = sort_tables(qtbs.items()) - results = \ - evaluate(env, x, qtbs, cycles, dir_pickle, render=render) - # TODO: concatenate axis=1 - # TODO: save + env = qtbs.pop('env') + cycle_time = getattr(env, 'cycle_time', 1) + env_params = env.env_params + sim_params = env.sim_params + sim_params.render = render + + horizon = int((cycle_time * cycles) / sim_params.sim_step) + ql_params = env.ql_params + network = env.network + + def fn(qtb): + ret = evaluate( + env_params, + sim_params, + ql_params, + network, + horizon, + qtb) + return ret + + # TODO: enable multiprocessing + # pool = mp.Pool(3) + # results = pool.map(gn, qtbs.items()) + # pool.close() + + results = [fn(qtb) for qtb in qtbs.values()] + info = concat(results, horizon) + filename = '_'.join(exid[:2]) + file_path = f'{dir_pickle}/{filename}.eval.json' + with open(file_path, 'w') as f: + json.dump(info, f)