Commit
Showing 12 changed files with 334 additions and 8 deletions.
File renamed without changes.
File renamed without changes.
memory_replay.py: @@ -0,0 +1,37 @@
import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'time'))


class ReplayMemory(object):
    """Cyclic replay buffer holding transitions for the task-specific
    Q-updates (memory) and for the distilled-policy updates (policy_memory)."""

    def __init__(self, capacity, policy_capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

        self.policy_capacity = policy_capacity
        self.policy_memory = []
        self.policy_position = 0

    def push(self, *args):
        """Saves a transition to both buffers, overwriting the oldest entry once full."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

        if len(self.policy_memory) < self.policy_capacity:
            self.policy_memory.append(None)
        self.policy_memory[self.policy_position] = Transition(*args)
        self.policy_position = (self.policy_position + 1) % self.policy_capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def policy_sample(self, batch_size):
        return random.sample(self.policy_memory, batch_size)

    def __len__(self):
        return len(self.memory)
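
A minimal usage sketch of the buffer above (not part of the commit); the capacities, the 25-dimensional dummy state, and the zero-valued tensors are illustrative only:

import torch
from memory_replay import ReplayMemory, Transition

memory = ReplayMemory(capacity=10000, policy_capacity=1000)  # illustrative sizes

# Push one dummy transition: (state, action, next_state, reward, time)
state = torch.zeros(1, 25)            # e.g. a flattened 5x5 grid observation (made up)
action = torch.LongTensor([[0]])
next_state = torch.zeros(1, 25)
reward = torch.FloatTensor([1.0])
time = torch.FloatTensor([1.0])
memory.push(state, action, next_state, reward, time)

if len(memory) >= 1:
    batch = Transition(*zip(*memory.sample(1)))    # sample for the task-specific Q update
    policy_batch = memory.policy_sample(1)         # sample for the distilled-policy update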
network.py: @@ -0,0 +1,153 @@
import math
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from memory_replay import Transition
from itertools import count
from torch.distributions import Categorical


use_cuda = torch.cuda.is_available()

FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor


class DQN(nn.Module):
    """
    Deep neural network which represents an agent.
    """
    def __init__(self, input_size, num_actions):
        super(DQN, self).__init__()
        self.l1 = nn.Linear(input_size, 100)  # To play with different NN architectures
        self.l2 = nn.Linear(100, num_actions)

    def forward(self, x):
        return self.l2(F.relu(self.l1(x)))


class PolicyNetwork(nn.Module):
    """
    Deep neural network which represents the policy network.
    """
    def __init__(self, input_size, num_actions):
        super(PolicyNetwork, self).__init__()
        self.l1 = nn.Linear(input_size, 100)  # To play with different NN architectures
        self.l2 = nn.Linear(100, num_actions)

    def forward(self, x):
        return F.softmax(self.l2(F.relu(self.l1(x))))


def select_action(state, policy, model, num_actions,
                  EPS_START, EPS_END, EPS_DECAY, steps_done, alpha, beta):
    """
    Selects the next action by sampling from the task policy pi_i induced by
    the distilled policy pi0 and the task Q-values (the commented-out block
    below is the usual epsilon-greedy alternative).
    """
    # sample = random.random()
    # eps_threshold = EPS_END + (EPS_START - EPS_END) * \
    #     math.exp(-1. * steps_done / EPS_DECAY)
    # .data.max(1)[1].view(1, 1)
    # if sample <= eps_threshold:
    #     return LongTensor([[random.randrange(num_actions)]])

    Q = model(Variable(state, volatile=True).type(FloatTensor))
    pi0 = policy(Variable(state, volatile=True).type(FloatTensor))
    # print(pi0.data.numpy())
    V = torch.log((torch.pow(pi0, alpha) * torch.exp(beta * Q)).sum(1)) / beta

    # FOUND ERROR: Q returns a tensor of nan at some point
    if np.isnan(Q.sum(1).data[0]):
        print(Q)
        print(state)

    pi_i = torch.pow(pi0, alpha) * torch.exp(beta * (Q - V))

    # if sum(pi_i.data.numpy()[0] < 0) > 0:
    #     print("Warning!!!: pi_i has negative values: pi_i", pi_i.data.numpy()[0])
    #     pi_i = torch.max(torch.zeros_like(pi_i), pi_i)
    # probabilities = pi_i.data.numpy()[0]
    m = Categorical(pi_i)
    action = m.sample().data.view(1, 1)
    return action
    # numpy.random.choice(numpy.arange(0, num_actions), p=probabilities)


def optimize_policy(policy, optimizer, memories, batch_size,
                    num_envs, gamma):
    loss = 0
    for i_env in range(num_envs):
        size_to_sample = np.minimum(batch_size, len(memories[i_env]))
        transitions = memories[i_env].policy_sample(size_to_sample)
        batch = Transition(*zip(*transitions))

        state_batch = Variable(torch.cat(batch.state))
        # print(batch.action)
        time_batch = Variable(torch.cat(batch.time))
        actions = np.array([action.numpy()[0][0] for action in batch.action])

        cur_loss = (torch.pow(Variable(Tensor([gamma])), time_batch) *
                    torch.log(policy(state_batch)[:, actions])).sum()
        loss -= cur_loss
        # loss = cur_loss if i_env == 0 else loss + cur_loss

    optimizer.zero_grad()
    loss.backward()

    # for param in policy.parameters():
    #     param.grad.data.clamp_(-1, 1)
    optimizer.step()


def optimize_model(policy, model, optimizer, memory, batch_size,
                   alpha, beta, gamma):
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
    # detailed explanation).
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))
    # We don't want to backprop through the expected action values and volatile
    # will save us on temporarily changing the model parameters'
    # requires_grad to False!
    non_final_next_states = Variable(torch.cat([s for s in batch.next_state
                                                if s is not None]),
                                     volatile=True)

    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken
    state_action_values = model(state_batch).gather(1, action_batch)

    # Compute the soft value V(s_{t+1}) for all non-final next states
    next_state_values = Variable(torch.zeros(batch_size).type(Tensor))
    next_state_values[non_final_mask] = torch.log(
        (torch.pow(policy(non_final_next_states), alpha)
         * torch.exp(beta * model(non_final_next_states))).sum(1)) / beta
    # Now, we don't want to mess up the loss with a volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
    # requires_grad=False
    next_state_values.volatile = False
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute MSE loss between predicted and expected Q values
    loss = F.mse_loss(state_action_values, expected_state_action_values)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # for param in model.parameters():
    #     param.grad.data.clamp_(-1, 1)
    optimizer.step()
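
For reference (not part of the commit), the distribution sampled in select_action above is equivalent to a softmax over alpha * log(pi0) + beta * Q, since V is its log-normalizer. A small sketch with made-up values for pi0, Q, alpha and beta, written against the current PyTorch tensor API:

import torch

alpha, beta = 0.8, 5.0                              # illustrative hyperparameters
pi0 = torch.tensor([[0.25, 0.25, 0.25, 0.25]])      # distilled-policy probabilities (toy)
Q = torch.tensor([[0.1, 0.4, 0.2, 0.0]])            # task Q-values (toy)

# Same formula as in select_action
V = torch.log((pi0.pow(alpha) * torch.exp(beta * Q)).sum(1)) / beta
pi_i = pi0.pow(alpha) * torch.exp(beta * (Q - V))

# Equivalent formulation: softmax of alpha*log(pi0) + beta*Q over actions
pi_i_alt = torch.softmax(alpha * torch.log(pi0) + beta * Q, dim=1)

print(pi_i.sum().item())                # ~1.0, so pi_i is a proper distribution
print(torch.allclose(pi_i, pi_i_alt))   # True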
trainingDistral2col0.py: @@ -0,0 +1,127 @@
import matplotlib
import matplotlib.pyplot as plt
from itertools import count
import torch.optim as optim
import torch
import math
import numpy as np
from memory_replay import ReplayMemory, Transition
from network import DQN, select_action, optimize_model, Tensor, optimize_policy, PolicyNetwork
import sys
sys.path.append('../')
from envs.gridworld_env import GridworldEnv
from utils import plot_rewards, plot_durations, plot_state, get_screen


def trainD(file_name="Distral_2col", list_of_envs=[GridworldEnv(5),
           GridworldEnv(4), GridworldEnv(6)], batch_size=128, gamma=0.999, alpha=0.8,
           beta=5, eps_start=0.9, eps_end=0.05, eps_decay=5,
           is_plot=False, num_episodes=200,
           max_num_steps_per_episode=1000, learning_rate=0.001,
           memory_replay_size=10000, memory_policy_size=1000):
    """
    Soft Q-learning training routine. Returns the task models, the distilled
    policy, and the reward and duration logs; optionally plots rewards during
    training (is_plot).
    """
    num_actions = list_of_envs[0].action_space.n
    input_size = list_of_envs[0].observation_space.shape[0]
    num_envs = len(list_of_envs)
    policy = PolicyNetwork(input_size, num_actions)
    models = [DQN(input_size, num_actions) for _ in range(0, num_envs)]  # Add torch.nn.ModuleList (?)
    memories = [ReplayMemory(memory_replay_size, memory_policy_size) for _ in range(0, num_envs)]

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        policy.cuda()
        for model in models:
            model.cuda()

    optimizers = [optim.Adam(model.parameters(), lr=learning_rate)
                  for model in models]
    policy_optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
    # optimizer = optim.RMSprop(model.parameters(), )

    episode_durations = [[] for _ in range(num_envs)]
    episode_rewards = [[] for _ in range(num_envs)]

    steps_done = np.zeros(num_envs)
    episodes_done = np.zeros(num_envs)
    current_time = np.zeros(num_envs)

    # Initialize environments
    states = []
    for env in list_of_envs:
        states.append(torch.from_numpy(env.reset()).type(torch.FloatTensor).view(-1, input_size))

    while np.min(episodes_done) < num_episodes:
        # TODO: add max_num_steps_per_episode

        # Optimization is given by an alternating minimization scheme:
        # 1. do the step for each env
        # 2. do one optimization step for each env using "soft-q-learning".
        # 3. do one optimization step for the policy

        for i_env, env in enumerate(list_of_envs):
            # print("Cur episode:", i_episode, "steps done:", steps_done,
            #       "exploration factor:", eps_end + (eps_start - eps_end) * \
            #       math.exp(-1. * steps_done / eps_decay))

            # Select and perform an action
            # print(states[i_env])

            action = select_action(states[i_env], policy, models[i_env], num_actions,
                                   eps_start, eps_end, eps_decay,
                                   episodes_done[i_env], alpha, beta)

            steps_done[i_env] += 1
            current_time[i_env] += 1
            next_state_tmp, reward, done, _ = env.step(action[0, 0])
            reward = Tensor([reward])

            # Observe new state
            next_state = torch.from_numpy(next_state_tmp).type(torch.FloatTensor).view(-1, input_size)

            if done:
                next_state = None

            # Store the transition in memory
            time = Tensor([current_time[i_env]])
            memories[i_env].push(states[i_env], action, next_state, reward, time)

            # Perform one step of the optimization (on the target network)
            optimize_model(policy, models[i_env], optimizers[i_env],
                           memories[i_env], batch_size, alpha, beta, gamma)

            # Update state
            states[i_env] = next_state

            # Check if agent reached target
            if done:
                print("ENV:", i_env, "iter:", episodes_done[i_env],
                      "\treward:", env.episode_total_reward,
                      "\tit:", current_time[i_env], "\texp_factor:", eps_end +
                      (eps_start - eps_end) * math.exp(-1. * episodes_done[i_env] / eps_decay))
                states[i_env] = torch.from_numpy(env.reset()).type(torch.FloatTensor).view(-1, input_size)
                episodes_done[i_env] += 1
                episode_durations[i_env].append(current_time[i_env])
                current_time[i_env] = 0
                episode_rewards[i_env].append(env.episode_total_reward)
                if is_plot:
                    plot_rewards(episode_rewards, i_env)

        optimize_policy(policy, policy_optimizer, memories, batch_size,
                        num_envs, gamma)

    print('Complete')
    env.render(close=True)
    env.close()
    if is_plot:
        plt.ioff()
        plt.show()

    # Store results
    np.save(file_name + '-distral-1col-rewards', episode_rewards)
    np.save(file_name + '-distral-1col-durations', episode_durations)

    return models, policy, episode_rewards, episode_durations
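
A sketch (not part of the commit) of how the saved logs could be reloaded and plotted afterwards, assuming the default file_name of trainD; np.save appends the .npy extension, and allow_pickle is used because the per-environment logs may have different lengths and then get stored as an object array:

import numpy as np
import matplotlib.pyplot as plt

# Default file_name of trainD(); the '-distral-1col-*' suffix comes from the np.save calls above
rewards = np.load('Distral_2col-distral-1col-rewards.npy', allow_pickle=True)
durations = np.load('Distral_2col-distral-1col-durations.npy', allow_pickle=True)

for i_env, env_rewards in enumerate(rewards):
    plt.plot(env_rewards, label='env %d' % i_env)
plt.xlabel('episode')
plt.ylabel('total reward')
plt.legend()
plt.show()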
@@ -0,0 +1,8 @@
import sys
sys.path.append('../')
from envs.gridworld_env import GridworldEnv
from utils import play_game
sys.path.append('../distral_2col0')
import trainingDistral2col0

models, policy, episode_rewards, episode_durations = trainingDistral2col0.trainD()
Binary file not shown.
Binary file not shown.
Binary file not shown.