# -*- coding: utf-8 -*-
"""
Created on Wed Jul 25 16:10:00 2018
@author: Ash & Markus
"""
import numpy as np
import sys
import os
import distree as dst
import distree.schedulers as SL
import anytree as atr
import yaml
import logging
import datetime
import argparse
class Distree_Demo(dst.Distree):
def get_taskdata_path(self, task_id):
return self.data_path + '%s.yaml' % task_id
def save_task_data(self, taskdata_path, data, task_id, parent_id):
data.update(task_id=task_id, parent_id=parent_id)
with open(taskdata_path, 'w') as f:
yaml.dump(data, f, default_flow_style=False)
def load_task_data(self, taskdata_path):
        with open(taskdata_path, 'r') as f:
            # safe_load avoids the unsafe, deprecated default of calling
            # yaml.load without an explicit Loader.
            taskdata = yaml.safe_load(f)
task_id = taskdata.get("task_id", None)
parent_id = taskdata.get("parent_id", None)
return taskdata, task_id, parent_id
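    # For reference, a taskdata file written by save_task_data looks roughly
    # like the sketch below (keys as in init_task_data at the bottom of this
    # file; the task_id format depends on the scheduler, so all values here
    # are illustrative):
    #
    #   branch_num: 0
    #   checkpoint_frequency: 4
    #   coeff: 1.0
    #   measurement_frequency: 2
    #   parent_id: null
    #   parent_treepath: ''
    #   state_paths:
    #     0.0: ./data/0_t0.0.npy
    #   t_max: 7
    #   task_id: 0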
@staticmethod
def branch_treepath(parent_treepath, branch_num):
return parent_treepath + '/%u' % branch_num
    @staticmethod
    def measure_data(state, taskdata):
        # TODO Actually measure meaningful things about an MPS.
        # Cast to a plain Python float so that the value round-trips through
        # yaml.dump/safe_load without numpy-specific tags.
        data = float(state[0])
        return data
@staticmethod
def generate_state_filename(task_id, t):
filename = "{}_t{}.npy".format(task_id, t)
return filename
def store_state(self, state, filename=None, **kwargs):
if not filename:
filename = self.generate_state_filename(**kwargs)
path = self.data_path + filename
np.save(path, state)
return path
def load_state(self, path):
state = np.load(path)
return state
    def should_branch(self, state, t, t_0, taskdata):
        # TODO Actually determine based on the state and elapsed time whether
        # we should try branching or not.
        # In this demo the state grows by 1.0 per step from 0.0, so the
        # first branching happens once the state exceeds 4.
        return state > 4
def evolve_state(self, state, taskdata):
# TODO Actually do time-evolution of the state.
state += 1.0
time_increment = 1.0
return state, time_increment
def run_task(self, taskdata_path):
# The taskdata file contains everything needed to run the task.
# Initial values, parameters, and so on.
# It also contains the task_id, generated by the scheduler when the
# task was scheduled, and the parent_id, which may be `None`.
taskdata, task_id, parent_id = self.load_task_data(taskdata_path)
# Put some data into local variables for convenience
state_paths = taskdata["state_paths"]
# TODO Change the value of measurements to not be a dictionary, but
# perhaps a path to a single data file, maybe in HDF5 or in plaintext.
measurements = taskdata.get("measurements", {})
has_children = "children" in taskdata and taskdata["children"]
if has_children:
# TODO Decide what to do here. Maybe ask the user to set some
# command line parameter saying "Yes, please recreate children."
msg = "This task already has children."
raise NotImplementedError(msg)
        t_0 = min(state_paths)  # Time of the earliest stored state.
        prev_checkpoint = max(state_paths)  # Time of the latest checkpoint.
        t = prev_checkpoint  # The current time in the evolution.
state_path = state_paths[t]
state = self.load_state(state_path)
try:
prev_measurement_time = max(measurements)
except ValueError:
prev_measurement_time = -np.inf
while t < taskdata["t_max"]:
state, t_increment = self.evolve_state(state, taskdata)
t += t_increment
if t - prev_measurement_time >= taskdata["measurement_frequency"]:
data_entry = self.measure_data(state, taskdata)
measurements[t] = data_entry
prev_measurement_time = t
if t - prev_checkpoint >= taskdata["checkpoint_frequency"]:
state_paths[t] = self.store_state(state, t=t, task_id=task_id)
prev_checkpoint = t
if self.should_branch(state, t, t_0, taskdata):
taskdata = self.branch(state, t, taskdata, task_id)
break
# Always store the state at the end of the simulation.
# TODO Fix the fact that this gets run even if t > t_max from the
# start.
state_paths[t] = self.store_state(state, t=t, task_id=task_id)
        # Save the final taskdata, overwriting the initial data file(s).
        # Note that any values in taskdata that were modified above were
        # modified in place.
self.save_task_data(taskdata_path, taskdata, task_id, parent_id)
def branch(self, state, t, taskdata, task_id):
parent_treepath = taskdata['parent_treepath']
branch_num = taskdata["branch_num"]
treepath = self.branch_treepath(parent_treepath, branch_num)
# TODO Do actual branching of MPSes.
if state % 2 == 0:
children = [state/2-1, state/2+1]
coeffs = [0.7, 0.3]
else:
children = [state]
coeffs = [1.0]
        num_children = len(children)
        taskdata['num_children'] = num_children
# Create taskdata files for, and schedule, children
for i, (child, child_coeff) in enumerate(zip(children, coeffs)):
# child_id = self.sched.get_id() # Instead of this, see below.
child_id = "{}_c{}".format(task_id, i)
            child_state_path = self.store_state(child, t=t, task_id=child_id)
child_taskdata = {
'parent_id': task_id,
'parent_treepath': treepath,
'branch_num': i,
't_max': taskdata["t_max"],
'state_paths': {t: child_state_path},
'coeff': child_coeff*taskdata["coeff"],
'measurement_frequency': taskdata["measurement_frequency"],
'checkpoint_frequency': taskdata["checkpoint_frequency"]
}
            # This adds each child task to the log and schedules it to be
            # run. How tasks are run is up to the scheduler.
child_id, child_path = self.schedule_task(
task_id, child_taskdata, task_id=child_id
)
# NOTE: We could add more child info to the parent taskdata here
return taskdata
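    # As a concrete example of the naming scheme above (with an illustrative
    # root id "0"): branching task "0" at treepath "/0" into two children
    # produces ids "0_c0" and "0_c1" with treepaths "/0/0" and "/0/1", and
    # each child's coeff is its branch coefficient times the parent's coeff,
    # i.e. its weight relative to the root state.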
# End of custom Distree subclass #
# Build an anytree from saved data by parsing the log file.
def build_tree(dtree):
top = None
r = atr.Resolver('name')
with open(dtree.log_path, "r") as f:
for line in f:
task_id1, parent_id1, taskdata_path = line.strip().split("\t")
taskdata, task_id2, parent_id2 = dtree.load_task_data(taskdata_path)
assert task_id1 == str(task_id2)
assert parent_id1 == str(parent_id2)
parent_treepath = taskdata.get('parent_treepath', '')
branch_num = taskdata.get('branch_num', 0)
num_children = taskdata.get('num_children', 0)
if top is None:
# Check that this really is a top-level node
assert parent_id2 == "" or parent_id2 is None
top = atr.Node('%u' % branch_num, task_id=task_id2,
parent_id=parent_id2,
num_children=num_children,
data=taskdata)
else:
                # An alternative that should never fail, but is not optimal:
                #   pnode = atr.search.find_by_attr(top, parent_id2,
                #                                   name='task_id')
                # Resolving by treepath should be efficient. (Alternatively,
                # keep a dictionary of nodes with task_ids as keys.)
                pnode = r.get(top, parent_treepath)
atr.Node('%u' % branch_num, parent=pnode, task_id=task_id2,
parent_id=parent_id2, num_children=num_children,
data=taskdata)
return top
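# Each line of the Distree log file is assumed to be tab-separated, as the
# parsing above implies:
#   <task_id>\t<parent_id>\t<taskdata_path>
# The file itself is written by the distree library.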
def setup_logging():
# Set up logging, both to stdout and to a file.
# First get the logger and set the level to INFO
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Figure out a name for the log file.
filename = os.path.basename(__file__).replace(".py", "")
datetime_str = datetime.datetime.strftime(datetime.datetime.now(),
'%Y-%m-%d_%H-%M-%S')
title_str = ('{}_{}'.format(filename, datetime_str))
logfilename = "log/{}.log".format(title_str)
i = 0
while os.path.exists(logfilename):
i += 1
logfilename = "log/{}_{}.log".format(title_str, i)
# Create a handler for the file.
os.makedirs(os.path.dirname(logfilename), exist_ok=True)
filehandler = logging.FileHandler(logfilename, mode='w')
filehandler.setLevel(logging.INFO)
# Create a handler for stdout.
consolehandler = logging.StreamHandler()
consolehandler.setLevel(logging.INFO)
# Create a formatter object, to determine output format.
fmt = "%(asctime)s %(levelname).1s: %(message)s"
datefmt = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
# Make both stdout and file logging use this formatter.
consolehandler.setFormatter(formatter)
filehandler.setFormatter(formatter)
# Set the logger to use both handlers.
logger.addHandler(filehandler)
logger.addHandler(consolehandler)
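    # With the formatter above, output lines look like the following (the
    # message text is illustrative):
    #   2018-07-25 16:10:00 I: Scheduled task 0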
def parse_args():
# Parse command line arguments.
parser = argparse.ArgumentParser()
    parser.add_argument('taskfile', type=str, nargs="?", default="",
                        help='path to a taskdata file (optional)')
    parser.add_argument('--child', dest='child', default=False,
                        action='store_true',
                        help='run the given taskfile in this process')
    parser.add_argument('--show', dest='show', default=False,
                        action='store_true',
                        help='render the task tree from the log file')
args = parser.parse_args()
return args
if __name__ == "__main__":
# Parse command line arguments.
args = parse_args()
    # setup_logging configures the logging module, which is concerned with
    # text output (think print statements). It has nothing to do with the
    # log file of the Distree.
setup_logging()
# This log file keeps track of the tree.
logfile = "./log/distreelog.txt"
datafolder = "./data/"
# Create a scheduler and tell it what script to run to schedule tasks.
sched = SL.Sched_Local(sys.argv[0], scriptargs=['--child'])
# Create the tree object, telling it where the logfile lives, where the
# taskdata files are to be stored, and giving it the scheduler to use.
dtree = Distree_Demo(logfile, datafolder, sched)
# NOTE: This script is designed so that it can schedule the root job and
# also child jobs, depending on the supplied command-line arguments.
if args.show:
# Print the tree from saved data
top = build_tree(dtree)
logging.info(atr.RenderTree(top))
elif args.child:
# Assume the first argument is a taskdata file for a child job.
# This means the task should be run in the current process,
# rather than be scheduled for later.
dtree.run_task(args.taskfile)
elif args.taskfile:
# Assume the argument is a taskdata file to be used for a root job
dtree.schedule_task_from_file(args.taskfile)
else:
        # Create the initial state. (A real run would build an MPS; the
        # demo uses a single scalar.)
root_state = np.array([0.0])
root_task_id = dtree.sched.get_id()
root_state_path = dtree.store_state(
root_state, t=0., task_id=root_task_id
)
# Save a simple initial taskdata file and schedule a root job.
init_task_data = {'parent_id': None,
'parent_treepath': '',
'branch_num': 0,
't_max': 7,
'state_paths': {0.: root_state_path},
'coeff': 1.0,
'measurement_frequency': 2,
'checkpoint_frequency': 4
}
# The following schedules a job (it will be run in a different process)
dtree.schedule_task(None, init_task_data, task_id=root_task_id)
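    # Overall flow, as sketched by this demo: the root job is scheduled
    # here, the scheduler re-invokes this script with '--child' and the
    # taskdata path, run_task evolves the state until t_max or a branching
    # event, and branch() schedules the children in the same way.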