environment.py
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from svgpath2mpl import parse_path

'''
The environment models a wildfire with agents (and optionally satellites) in it.
The environment state is an array of shape (N_agents + N_sats + 3, 2): the first
N_agents + N_sats rows are the (x, y) locations, the next row is the mean mu, and
the last two rows are the 2x2 covariance sigma of the estimated temperature distribution.
'''
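# Illustrative state layout (assuming N_agents=2 and N_sats=0; these values are
# examples for this sketch, not defaults):
#
#   state = np.array([[x0, y0],       # agent 0 location
#                     [x1, y1],       # agent 1 location
#                     [mu_x, mu_y],   # estimated mean (row -3)
#                     [s_00, s_01],   # estimated covariance, first row
#                     [s_10, s_11]])  # estimated covariance, second row (rows -2:)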
class WildFireEnv:

    def __init__(self, width: int, height: int, init_state, action_range: int, p_move: float,
                 max_temp, N_agents: int, N_sats=0, max_steps=100, tol=1e-8):
        self.width = width
        self.height = height
        self.step_count = 0
        # True (hidden) temperature distribution: a Gaussian centred on the grid
        self.mean = np.array([width / 2, height / 2])
        self.cov = np.array([[width / 2, 0], [0, height / 2]])
        self.temperature_dist = multivariate_normal(self.mean, self.cov)
        self.init_state = init_state
        self.state = init_state
        self.action_range = action_range
        self.p_move = p_move  # penalty factor for moving
        self.max_temp = max_temp
        self.N_agents = N_agents
        self.N_sats = N_sats
        self.max_steps = max_steps
        self.tol = tol
        self.done = False
    def reset(self):
        # Set each agent to a random (discrete) location; sample x and y
        # separately so that y stays within the grid height when width != height
        self.state[:self.N_agents, 0] = np.random.randint(0, self.width, size=self.N_agents)
        self.state[:self.N_agents, 1] = np.random.randint(0, self.height, size=self.N_agents)
        # self.state = self.init_state
        self.step_count = 0
        self.done = False
    def reposition(self, locations):
        '''Reposition the agents in the environment.'''
        self.state[:self.N_agents] = locations
        updated_mu, updated_sigma = self.fit_distribution(locations)
        self.state[-3] = updated_mu
        self.state[-2:] = updated_sigma
    def print_state(self):
        '''Print the current state of the environment.'''
        print('-' * 40)
        print('Agent locations: ', self.state[:self.N_agents])
        print('Satellite locations: ', self.state[self.N_agents:(self.N_agents + self.N_sats)])
        print('Mean: ', self.state[-3])
        print('Covariance: ', self.state[-2:].flatten())
        print('-' * 40)
    def check_single_inbound(self, location):
        '''Check if a single location is within the bounds of the environment.'''
        return np.all(location >= 0) and np.all(location < np.array([self.width, self.height]))

    def check_all_inbound(self, state):
        '''Check if all agent and satellite locations are within the bounds of the environment.'''
        locations = state[:self.N_agents + self.N_sats]
        return np.all(locations >= 0) and np.all(locations < np.array([self.width, self.height]))
    def get_temperatures(self, locations):
        '''Get the temperatures measured at the given agent and satellite locations.'''
        temperatures = np.array([self.temperature_dist.pdf(locations[loc_idx])
                                 for loc_idx in range(self.N_agents + self.N_sats)])
        # Scale so that the peak of the true distribution corresponds to max_temp
        temperatures = temperatures * self.max_temp / self.temperature_dist.pdf(self.mean)
        return temperatures
    def get_divergence(self, mu, sigma):
        '''Distance of the estimated distribution N(mu, sigma) from the true
        distribution, using the closed-form KL divergence between two
        multivariate Gaussians.'''
        k = 2  # dimension of the distribution
        cov_inv = np.linalg.inv(self.cov)
        diff = mu - self.mean
        D_kl = 0.5 * (np.log(np.linalg.det(self.cov) / np.linalg.det(sigma)) - k
                      + diff.T @ cov_inv @ diff
                      + np.trace(cov_inv @ sigma))
        return D_kl
    def fit_distribution(self, new_locations):
        '''Update the estimated distribution given new locations of the agents and satellites.'''
        temperatures = self.get_temperatures(new_locations)
        temperatures_sum = np.sum(temperatures)
        # Temperature-weighted average of the measurement locations
        fitted_mu = np.sum(temperatures.reshape(-1, 1) * new_locations, axis=0) / temperatures_sum
        # new_mu = (self.state[-3] + fitted_mu) / 2  # update mean using a weighted average
        new_mu = 0.75 * self.state[-3] + 0.25 * fitted_mu  # update mean using a weighted average
        new_sigma = self.state[-2:, :]  # do not update covariance for now
        return new_mu, new_sigma
    def move_cost(self, new_state):
        '''Approximate cost of moving from the current state to the new state:
        the average Euclidean distance travelled per agent.'''
        move_cost = 0
        for loc_idx in range(self.N_agents):
            move_cost += np.linalg.norm(new_state[loc_idx] - self.state[loc_idx])
        move_cost /= self.N_agents
        return move_cost
    def get_reward(self, new_state):
        '''Reward of the new state: negative KL divergence of the estimated
        distribution, minus the penalised cost of moving.'''
        new_mu = new_state[-3]  # estimated mean
        new_sigma = new_state[-2:, :]  # estimated covariance
        # print("KL Divergence cost: ", self.get_divergence(new_mu, new_sigma))
        # print("Move cost: ", self.move_cost(new_state))
        return -self.get_divergence(new_mu, new_sigma) - self.p_move * self.move_cost(new_state)
    def act(self, action: int):
        '''Take an action in the environment. The action is a single integer: a linear
        index into the joint action space that encodes one relative move per agent
        within an action_range x action_range neighbourhood.'''
        # Decode the joint action into one per-agent action index
        action_agents = np.unravel_index(action, (self.action_range**2,) * self.N_agents)
        moves = np.zeros((self.N_agents, 2))
        for agent_idx in range(self.N_agents):
            # Decode each agent's action into a 2D offset, shifted by half the
            # action range so that moves are relative to the current location
            moves[agent_idx, :] = (np.array(np.unravel_index(action_agents[agent_idx], (self.action_range,) * 2))
                                   - self.action_range // 2)
        new_locations = self.state[:self.N_agents] + moves  # add relative moves to current locations
        # Check if new locations are inbound, otherwise keep the old location
        for loc_idx in range(self.N_agents):
            if not self.check_single_inbound(new_locations[loc_idx]):
                new_locations[loc_idx] = self.state[loc_idx]
        new_mu, new_sigma = self.fit_distribution(new_locations)
        # Set done to true if the divergence of the new distribution is small enough
        if self.get_divergence(new_mu, new_sigma) < self.tol:
            self.done = True
        return np.vstack((new_locations, new_mu, new_sigma))
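    # Worked example of the action encoding above (illustrative, assuming
    # action_range=3 and N_agents=2): the joint action space has
    # (3**2)**2 = 81 actions. For action=31, np.unravel_index(31, (9, 9))
    # gives (3, 4); then np.unravel_index(3, (3, 3)) = (1, 0) and
    # np.unravel_index(4, (3, 3)) = (1, 1), so after subtracting
    # action_range // 2 = 1 the moves are (0, -1) for agent 0 and (0, 0)
    # for agent 1.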
    def flatten_state(self, state):
        return state.flatten()

    def step(self, action: int):
        '''Take a step in the environment given an action.'''
        self.step_count += 1
        self.done = self.step_count >= self.max_steps
        new_state = self.act(action)  # act() may also set done if the estimate has converged
        reward = self.get_reward(new_state)
        self.state = new_state
        return self.flatten_state(new_state), reward, self.done
    def render(self):
        window = 5  # window buffer size for plotting
        mu_pred = self.state[-3]
        sigma_pred = self.state[-2:]
        pred_dist = multivariate_normal(mu_pred, sigma_pred).pdf
        true_dist = self.temperature_dist.pdf
        x_pred, y_pred, z_pred = self.plotVal(pred_dist)
        x_true, y_true, z_true = self.plotVal(true_dist)
        fig, ax = plt.subplots(layout='constrained')
        pred_contour = ax.contourf(x_pred, y_pred, z_pred, 30, cmap='Blues')
        true_contour = ax.contourf(x_true, y_true, z_true, 30, cmap='Reds', alpha=0.3)
        ax.set_title('Predicted Distribution Relative to True Distribution')
        ax.set_xlabel(r'$x_1$ [km]')
        ax.set_ylabel(r'$x_2$ [km]')
        ax.set_xlim([-window, self.width + window])
        ax.set_ylim([-window, self.height + window])
        cbar1 = fig.colorbar(pred_contour)
        cbar2 = fig.colorbar(true_contour)
        cbar1.set_label('Predicted Temps')
        cbar2.set_label('True Temps')
        # Drone marker built from an SVG path
        drone = parse_path("M915 5009 c-218 -18 -426 -125 -567 -293 -65 -78 -147 -235 -173 -331 -26 -95 -31 -273 -10 -380 32 -171 116 -327 241 -446 123 -116 243 -182 401 -220 73 -17 111 -20 220 -17 73 2 137 8 143 14 6 6 -19 38 -78 96 l-87 86 -80 4 c-250 12 -506 240 -570 506 -35 144 -6 320 77 460 42 72 158 188 230 230 194 114 435 123 636 22 189 -95 322 -282 351 -494 6 -44 8 -101 5 -129 l-7 -49 79 -79 c43 -43 81 -79 85 -79 16 0 34 136 34 260 -1 190 -39 318 -141 471 -171 258 -463 394 -789 368z M4030 5014 c-14 -2 -52 -9 -85 -15 -205 -37 -414 -179 -534 -364 -106 -161 -147 -313 -138 -509 5 -117 21 -216 35 -216 4 0 43 33 87 74 l80 73 1 99 c0 122 15 185 65 292 28 57 60 104 105 152 81 86 152 133 257 172 105 40 262 49 372 24 212 -50 388 -211 462 -423 24 -69 28 -93 28 -203 0 -105 -4 -136 -23 -195 -86 -258 -331 -455 -569 -459 l-73 -1 -85 -85 c-55 -55 -81 -88 -75 -94 15 -15 239 -20 315 -7 328 57 599 309 686 636 30 112 30 298 0 410 -98 370 -411 623 -790 639 -53 3 -107 3 -121 0z M984 4241 c-89 -40 -137 -144 -114 -243 12 -51 22 -62 284 -323 347 -346 439 -461 509 -636 57 -141 62 -182 62 -469 -1 -243 -3 -272 -23 -344 -42 -150 -141 -334 -275 -511 -81 -106 -307 -336 -430 -438 -106 -87 -124 -112 -133 -179 -17 -123 83 -238 206 -238 81 0 90 7 475 389 403 400 475 459 635 520 152 59 191 66 365 66 159 0 161 -1 270 -38 173 -60 339 -157 533 -313 161 -129 381 -353 536 -548 85 -106 219 -115 312 -22 59 58 76 127 53 211 -9 33 -56 85 -258 290 -136 138 -275 282 -308 322 -132 154 -212 298 -255 454 -22 79 -23 102 -23 369 0 331 8 378 95 555 99 200 282 414 586 681 158 139 184 174 184 247 0 118 -96 217 -210 217 -81 0 -90 -8 -460 -374 -371 -368 -474 -456 -615 -526 -138 -68 -217 -85 -390 -85 -126 0 -163 4 -230 23 -100 28 -294 123 -405 197 -200 134 -451 365 -656 601 -122 142 -155 164 -237 164 -23 0 -60 -9 -83 -19z m1681 -1276 c167 -44 278 -167 308 -341 30 -173 -63 -359 -218 -435 -262 -129 -559 19 -608 305 -50 296 225 547 518 471z M2509 2811 c-118 -24 -202 -128 -202 -251 0 -191 188 -309 364 -228 90 41 142 124 142 228 0 162 -145 282 -304 251z M885 1799 c-343 -50 -624 -306 -707 -647 -29 -119 -29 -294 1 -407 81 -306 310 -534 616 -616 96 -25 287 -31 380 -10 311 68 554 296 646 607 30 102 37 290 14 400 -10 46 -20 84 -24 84 -3 0 -43 -34 -88 -75 l-82 -76 5 -77 c20 -274 -169 -543 -453 -643 -82 -29 -255 -37 -346 -15 -330 78 -546 406 -488 741 49 280 330 535 590 535 l66 0 89 89 c48 49 84 91 78 95 -15 9 -142 26 -192 25 -25 -1 -72 -5 -105 -10z M3985 1794 c-11 -3 -26 -7 -34 -10 -10 -3 14 -34 74 -94 l90 -90 65 0 c197 0 430 -160 530 -364 59 -119 74 -190 67 -316 -9 -178 -76 -322 -208 -443 -119 -111 -263 -167 -429 -167 -181 0 -318 58 -451 190 -141 139 -200 272 -202 457 l-2 103 -75 75 c-41 41 -78 75 -81 75 -4 0 -15 -33 -25 -72 -28 -109 -26 -307 4 -411 84 -294 313 -519 607 -598 113 -30 307 -30 420 0 292 78 517 292 604 574 119 385 -65 814 -429 1001 -139 72 -204 88 -365 92 -77 1 -149 1 -160 -2z")
        drone.vertices -= drone.vertices.mean(axis=0)
        ax.scatter(self.state[:self.N_agents, 0], self.state[:self.N_agents, 1],
                   c='black', marker=drone, s=250, label='Agent Locations')
        ax.legend()
        fig.savefig('./renderings/step_{}.png'.format(self.step_count))
        plt.close(fig)
    def plotVal(self, distrib):
        '''Evaluate a distribution on a grid covering the environment, for plotting.'''
        k = 0.05  # adjusts coarseness of the plot
        window = 5  # window buffer size for plotting
        x, y = np.meshgrid(np.arange(-window, self.width + window, k),
                           np.arange(-window, self.height + window, k))
        xy = np.vstack((x.flatten(), y.flatten())).T
        z = distrib(xy).reshape(x.shape)
        return x, y, z
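
# Minimal usage sketch (illustrative; the grid size, agent count, and other
# parameter values below are assumptions for this example, not defaults).
if __name__ == '__main__':
    N_agents = 2
    # State: N_agents location rows, one mean row, two covariance rows
    init_state = np.array([[10.0, 10.0],   # agent 0 location
                           [40.0, 40.0],   # agent 1 location
                           [25.0, 25.0],   # initial mean estimate
                           [25.0, 0.0],    # initial covariance, row 0
                           [0.0, 25.0]])   # initial covariance, row 1
    env = WildFireEnv(width=50, height=50, init_state=init_state, action_range=3,
                      p_move=0.1, max_temp=100, N_agents=N_agents)
    env.reset()
    n_actions = (env.action_range ** 2) ** env.N_agents  # joint action space size
    while not env.done:
        action = np.random.randint(n_actions)  # random policy, for illustration only
        obs, reward, done = env.step(action)
        print('step {}: reward = {:.3f}'.format(env.step_count, reward))
    env.print_state()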