-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexploration.py
108 lines (86 loc) · 2.95 KB
/
exploration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
def exploration(
        environment,
        agent,
        T,
        evaluation,
        T_random=0,
        reset=True,
        plot=False,
        animate=None,
        save_models=None):
    """Roll the agent out in the environment for ``T`` steps while it learns.

    At each step the current state is copied from ``environment.x``, a control
    is chosen, the environment is stepped, and the agent is updated with the
    finite-difference derivative ``dx / environment.dt``.

    :param environment: system to explore; this function reads ``d``, ``m``,
        ``dt``, ``x`` and ``plot_system`` and calls ``reset()`` and
        ``step(u, t)``
    :type environment: environment object
    :param agent: learner; this function reads ``model`` and calls
        ``draw_random_control()``, ``policy(x, t)`` and
        ``learning_step(x, u, dx_dt)``
    :type agent: agent object
    :param T: time horizon (number of steps)
    :type T: int
    :param evaluation: object whose ``evaluate(model, t)`` result is recorded
        at every step; when ``None`` the returned errors stay at zero
    :type evaluation: evaluation object
    :param T_random: the random control is kept for every step with
        ``t <= T_random``; the agent's policy is used once ``t > T_random``.
        With the default of 0 the very first step is therefore still random.
    :type T_random: int, optional
    :param reset: reset the environment before running, defaults to True
    :type reset: bool, optional
    :param plot: forwarded to ``illustrate``, defaults to False
    :type plot: bool, optional
    :param animate: animation callback forwarded to ``illustrate``,
        defaults to None
    :type animate: function, optional
    :param save_models: path prefix forwarded to ``illustrate`` for saving the
        model at every step, defaults to None
    :type save_models: str, optional
    :return: state-action values, evaluation values
    :rtype: array of shape T x (d+m), array of size T
    """
    n_state = environment.d
    n_input = environment.m
    state_cols = slice(None, n_state)
    input_cols = slice(n_state, None)

    z_values = np.zeros((T, n_state + n_input))
    error_values = np.zeros(T)

    if reset:
        environment.reset()

    for t in range(T):
        x = environment.x.copy()

        # A random control is drawn on every step, even when the policy
        # overrides it below, so the call sequence on the agent is the same
        # regardless of T_random.
        random_u = agent.draw_random_control()
        u = agent.policy(x, t) if t > T_random else random_u

        dx = environment.step(u, t)
        agent.learning_step(x, u, dx / environment.dt)

        # Rows t..T-1 are all set to the current sample; later iterations
        # overwrite the tail. Presumably this keeps not-yet-visited rows from
        # showing up as zeros in intermediate plots -- TODO confirm.
        z_values[t:, state_cols] = x.copy()
        z_values[t:, input_cols] = u.copy()

        if evaluation is not None:
            error_values[t] = evaluation.evaluate(agent.model, t)

        illustrate(
            x, u, t, agent.model, z_values, error_values,
            environment.plot_system, plot, animate, save_models)

    return z_values, error_values
def illustrate(x, u, t, model, z_values, error_values, plot_system, plot, animate, save_models):
    """Optionally save the model and visualise the current step.

    :param x: state passed to ``plot_system``
    :param u: control passed to ``animate`` or ``plot_system``
    :param t: time step; also used in the saved file name
    :param model: object written with ``torch.save`` and passed to ``animate``
    :param z_values: state-action values passed to ``animate``
    :param error_values: evaluation values passed to ``animate``
    :param plot_system: callable ``plot_system(x, u, t)`` used when no
        animation callback is given and ``plot`` is true
    :param plot: enables ``plot_system``; also forwarded to ``animate`` as the
        ``plot`` keyword
    :type plot: bool
    :param animate: animation callback, or None; when given it takes
        precedence and ``plot_system`` is not called
    :type animate: function, optional
    :param save_models: path prefix, or None; when given the model is written
        to ``<save_models>_<t>.dat``
    :type save_models: str, optional
    :return: None
    """
    if save_models is not None:
        with open(f'{save_models}_{t}.dat', 'wb') as checkpoint:
            torch.save(model, checkpoint)

    if animate is None:
        if plot:
            plot_system(x, u, t)
            # Show the figure briefly, then close it before the next step.
            plt.pause(0.1)
            plt.close()
    else:
        animate(model, u, t, z_values, error_values, plot=plot)
if __name__=='__main__':
    # Demo run: explore a damped pendulum with a linear model and plot the
    # recorded evaluation values on a log scale.
    from environments.pendulum import DampedPendulum
    from models.pendulum import Linear
    from evaluation.pendulum import ParameterNorm
    from policies import Random, Flex
    from exploration import exploration

    horizon = 100
    time_step = 5e-2

    pendulum = DampedPendulum(time_step)
    linear_model = Linear(pendulum)
    evaluator = ParameterNorm(pendulum)

    # Random is imported above as well; presumably it can be swapped in here
    # in place of Flex -- verify its constructor arguments first.
    explorer = Flex(
        linear_model,
        pendulum.d,
        pendulum.m,
        pendulum.gamma,
        dt=pendulum.dt,
    )

    z_values, error_values = exploration(
        pendulum, explorer, horizon, evaluator, plot=True)

    plt.plot(error_values)
    plt.yscale('log')
    plt.show()