agent.py
import json
import os
import random
from typing import Callable, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from constants import STATE_N_VARS, STATE_SHAPE_CHANNELS_FIRST, Direction
from model import GridNet
from replay import MultiInputReplayBuffer, ReplayBuffer
from colorama import Fore, Style
StateType = tuple[np.ndarray, np.ndarray]
class Agent:
model_path = "./models"
def __init__(
self,
batch_size: int,
gamma: float,
learning_rate: float,
eps_decay_steps: int,
update_target_rate: int,
epsilon_min: float,
device: torch.device,
replay: ReplayBuffer,
num_actions: int,
policy_net: nn.Module,
target_net: nn.Module,
        eval: bool,
        model_path: Optional[str] = None,
        load_model: Optional[str] = None
):
self.num_actions = num_actions
# hyperparameters
self.gamma = gamma
self.learning_rate = learning_rate
self.batch_size = batch_size
self.eps_decay_steps = eps_decay_steps
self.update_target_rate = update_target_rate
self.train_cntr = 0
self.epsilon = 1.0
self.epsilon_min = epsilon_min
self.replay = replay
self.device = device
        if model_path is None:
            if not os.path.exists(self.model_path):
                os.mkdir(self.model_path)
                print(f"Created `{self.model_path}` directory")
else:
self.model_path = model_path
if not os.path.exists(self.model_path):
raise Exception(f"Model directory `{self.model_path}` does not exist")
self.policy_net = policy_net.to(device)
self.target_net = target_net.to(device)
self.update_target_network()
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
self.loss_fn = nn.MSELoss().to(device)
self.clip_grad = 80.0 # clip gradient value
self.eval = eval # inference only mode
if load_model:
self.load(load_model, verbose=True)
if self.eval:
if load_model is None:
raise Exception("Must load a model in evaluation mode.")
else:
self.load(load_model)
self.eval_msg(load_model)
else:
self.summary_msg()
print()
def eval_msg(self, model_name: str):
print()
print(Style.BRIGHT + Fore.GREEN + f"Inference only mode! All actions use the loaded model: {model_name}" + Style.RESET_ALL)
def summary_msg(self):
print()
print(Style.BRIGHT + Fore.GREEN + "Hyperparameters" + Style.NORMAL)
print(f"\tbatch_size = {self.batch_size}")
print(f"\tlearning_rate = {self.learning_rate}")
print(f"\tgamma = {self.gamma}")
print(f"\tmemory_size = {self.replay.size:,}")
print(f"\ttrain_steps = {self.eps_decay_steps:,}")
print(f"\ttarget_update = {self.update_target_rate:,}")
total_params = 0
for parameter in self.policy_net.parameters():
if not parameter.requires_grad: continue
params = parameter.numel()
total_params+=params
print(Style.BRIGHT + Fore.WHITE + "Other Parameters" + Style.NORMAL)
print(f"\ttotal_model_parameters = {total_params:,}")
print(f"\tepsilon_min = {self.epsilon_min}")
print(f"\tdevice = {self.device}")
print(f"\tmodel_path = {self.model_path}")
print(f"\toptimizer = {self.optimizer.__class__.__name__}")
print(f"\tloss_fn = {self.loss_fn.__class__.__name__}")
print()
print(f"[Ctrl+C to exit]" + Style.RESET_ALL)
@classmethod
def from_config(cls, config_path: str, load_model: Optional[str] = None, eval=False):
with open(config_path, "r") as f:
config = json.load(f)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = GridNet(num_channels=STATE_SHAPE_CHANNELS_FIRST[0], num_scalars=STATE_N_VARS, num_actions=len(Direction)).to(device)
target_net = GridNet(num_channels=STATE_SHAPE_CHANNELS_FIRST[0], num_scalars=STATE_N_VARS, num_actions=len(Direction)).to(device)
replay = MultiInputReplayBuffer(
size=config.pop("memory_size"),
grid_shape=STATE_SHAPE_CHANNELS_FIRST,
num_scalars=STATE_N_VARS
)
return Agent(
replay=replay,
policy_net=policy_net,
target_net=target_net,
device=device,
num_actions=policy_net.num_actions,
eval=eval,
load_model=load_model,
**config
)
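    # Example (sketch): constructing an agent from a JSON config. The file name and
    # values below are hypothetical; the keys must match the remaining __init__
    # arguments plus `memory_size`, e.g.
    #   {"batch_size": 64, "gamma": 0.99, "learning_rate": 1e-4,
    #    "eps_decay_steps": 100000, "update_target_rate": 1000,
    #    "epsilon_min": 0.05, "memory_size": 100000}
    #
    #   agent = Agent.from_config("config.json")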
@classmethod
def eval_pipeline(cls, model_name: str) -> Callable[[StateType], Direction]:
"""Load a model, returns the max Q value function.
This function takes the state as a tuple `(grid, scalar)` and returns the max Q action, an integer."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = GridNet(num_channels=STATE_SHAPE_CHANNELS_FIRST[0], num_scalars=STATE_N_VARS, num_actions=len(Direction)).to(device)
        checkpoint = torch.load(os.path.join(cls.model_path, model_name), map_location=device)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
def act(state):
state_tensor = cls.game_state_to_tensor_batch(state, device)
q_values = policy_net.forward(state_tensor)[0]
action = torch.argmax(q_values)
return action.item()
return act
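    # Example (sketch, hypothetical model file name): greedy inference without
    # constructing a full Agent:
    #   act = Agent.eval_pipeline("snake-dqn.pt")
    #   action = act((grid, scalars))  # integer index into Direction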
def update_target_network(self):
        # copy the weights of the policy network into the target network
# SOFT update
# target_net_state_dict = self.target_net.state_dict()
# policy_net_state_dict = self.policy_net.state_dict()
# for key in policy_net_state_dict:
# target_net_state_dict[key] = policy_net_state_dict[key]*self.tau + target_net_state_dict[key]*(1-self.tau)
# self.target_net.load_state_dict(target_net_state_dict)
# HARD update
self.target_net.load_state_dict(self.policy_net.state_dict(), strict=True)
@staticmethod
def game_state_to_tensor_batch(state, device):
grid, scalar = state
with torch.no_grad():
grid_t = torch.tensor(grid[np.newaxis, :], device=device)
scalar_t = torch.tensor(scalar[np.newaxis, :], device=device, dtype=torch.float32)
return grid_t, scalar_t
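    # Note: this adds a leading batch dimension of 1, so (assuming the grid has shape
    # STATE_SHAPE_CHANNELS_FIRST = (C, H, W)) the grid tensor becomes (1, C, H, W) and
    # the scalar vector of length STATE_N_VARS becomes (1, STATE_N_VARS), matching the
    # batched input the networks expect for a single-state forward pass.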
def get_eps(self):
return max(1 - (self.train_cntr/self.eps_decay_steps), self.epsilon_min)
def set_eps(self, value: float):
"""Set train counter such that `epsilon decay steps * epsilon = train counter`
Returns new train counter value for progress bars."""
if value > 1.0 or value < 0.0:
raise Exception("Expected epsilon value in interval [0, 1]")
self.train_cntr = int(value * self.eps_decay_steps)
return self.train_cntr
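    # Worked example: with eps_decay_steps = 100_000 the decay is linear, so after
    # 25_000 training steps get_eps() returns max(1 - 25_000 / 100_000, epsilon_min) = 0.75
    # (assuming epsilon_min < 0.75), and set_eps(0.25) puts the train counter back at 25_000.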
def act(self, state, is_grid=True):
"""Predicts the max Q action of a game state.
Uses epsilon greedy strategy for exploration.
In evaluation mode, all actions use the policy net. See `self.eval()`.
Returns tuple of `[action, q_value]`."""
if np.random.rand() <= self.get_eps() and not self.eval:
# random action has no q-value
return random.randrange(0, self.num_actions), None
else:
# convert to batch tensor
if is_grid:
state_t = self.game_state_to_tensor_batch(state, self.device)
else:
state_t = torch.tensor(state[np.newaxis, :], device=self.device)
# return action index and its q-value
pred = self.policy_net.forward(state_t)[0]
idx = torch.argmax(pred)
return idx.item(), float(pred[idx])
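    # Example (sketch): one environment step with epsilon-greedy action selection,
    # where `state` is a `(grid, scalars)` tuple:
    #   action, q_value = agent.act(state)
    #   # q_value is None when the action was chosen at random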
def train(self):
if self.replay.cntr < self.batch_size:
return float('inf')
state, action, reward, new_state, done = self.replay.sample(self.batch_size, self.device)
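        # Double DQN target: next-state actions are selected with the policy network
        # and evaluated with the target network,
        #   y = r + gamma * (1 - done) * Q_target(s', argmax_a Q_policy(s', a))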
with torch.no_grad():
q_next = self.policy_net.forward(new_state)
max_q_actions = q_next.argmax(1)
q_next_target = self.target_net.forward(new_state)
max_q_values = q_next_target.gather(1, max_q_actions.unsqueeze(1)).squeeze(1)
q_eval = self.policy_net.forward(state).gather(1, action.unsqueeze(1)).squeeze(1)
# zero the future reward if terminal state
q_target = reward + self.gamma * max_q_values * (1 - done)
loss = self.loss_fn(q_eval, q_target)
# start ORIGINAL
# with T.no_grad():
# q_next = self.policy_net.forward(new_state)
# max_q_actions = q_next.argmax(1)
# q_next_target = self.target_net.forward(new_state)
# max_q_values = q_next_target[:, max_q_actions]
# q_eval = self.policy_net.forward(state)[:, action]
# # zero the future reward if terminal state
# q_target = reward + self.gamma * max_q_values * (1 - done)
# loss = self.loss_fn(q_eval, q_target)
# end ORIGINAL
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad.clip_grad_value_(self.policy_net.parameters(), self.clip_grad)
self.optimizer.step()
self.train_cntr += 1
if self.train_cntr % self.update_target_rate == 0:
self.update_target_network()
return loss.item()
def train_steps_completed(self):
"""No. times `self.train()` has been called."""
return self.train_cntr
def save(self, name, verbose=False):
torch.save({
"policy_state_dict": self.policy_net.state_dict(),
"target_state_dict": self.target_net.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
}, os.path.join(self.model_path, name))
if verbose:
print(Style.BRIGHT + Fore.GREEN + f"Saved model {name}" + Style.RESET_ALL)
def load(self, name, verbose=False):
        checkpoint = torch.load(os.path.join(self.model_path, name), map_location=self.device)
self.policy_net.load_state_dict(checkpoint["policy_state_dict"])
self.target_net.load_state_dict(checkpoint["target_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if verbose:
print(Style.BRIGHT + Fore.GREEN + f"Loaded model {name}" + Style.RESET_ALL)
@classmethod
def convert_train_model(cls, model_name: str):
"""Opens a pickled trained model dict with keys for policy network, target network and the optimizer.
Creates a new piclked model as the policy network's state dict.
Must be called before a model can be used with `online` command.
Must be called with cuda enabled since tensors from trained model
expect to be loaded on a cuda device"""
        print(Fore.YELLOW + "Warning: requires CUDA enabled device to open trained model!" + Fore.RESET)
# load trained model and move to cpu
trained_model = torch.load(os.path.join(cls.model_path, model_name))
print(Fore.GREEN + f"Opened trained model {model_name}" + Fore.RESET)
temp_net = GridNet(num_channels=STATE_SHAPE_CHANNELS_FIRST[0], num_scalars=STATE_N_VARS, num_actions=len(Direction))
temp_net.load_state_dict(trained_model["policy_state_dict"])
temp_net.to("cpu")
print(Fore.GREEN + f"Moved tensors to CPU" + Fore.RESET)
# only save policy state dict for inference
new_model_name = model_name.split(".pt")[0] + "-eval.pt"
torch.save(temp_net.state_dict(), os.path.join(cls.model_path, new_model_name))
print(Fore.GREEN + f"Saved new model {new_model_name}" + Fore.RESET)
    def set_eval(self):
        """Enables evaluation-only mode (all actions use the policy network)."""
        self.eval = True