Merge pull request #57 from muupan/pcl-prioritized-replay
Use Prioritized Replay for PCL
muupan authored Mar 23, 2017
2 parents f4cff28 + 5c83182 commit 67901fd
Showing 6 changed files with 233 additions and 65 deletions.
8 changes: 7 additions & 1 deletion chainerrl/agents/pcl.py
@@ -270,6 +270,11 @@ def update_from_replay(self):

episodes = self.replay_buffer.sample_episodes(
self.batchsize, max_len=self.t_max)
if isinstance(episodes, tuple):
# Prioritized replay
episodes, weights = episodes
else:
weights = [1] * len(episodes)
sorted_episodes = list(reversed(sorted(episodes, key=len)))
max_epi_len = len(sorted_episodes[0])

@@ -326,7 +331,8 @@ def update_from_replay(self):
values=e_values,
next_values=e_next_values,
log_probs=e_log_probs))
loss = chainerrl.functions.sum_arrays(losses) / self.batchsize
loss = chainerrl.functions.weighted_sum_arrays(
losses, weights) / self.batchsize
self.update(loss)

def update_on_policy(self, statevar):
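To make the change above concrete: sample_episodes may now return an (episodes, weights) tuple when prioritized replay is in use, and the per-episode losses are combined with weighted_sum_arrays instead of sum_arrays. Below is a minimal sketch of that combination step with made-up scalar losses and weights, assuming a ChainerRL build that exposes chainerrl.functions.weighted_sum_arrays as called in the diff.

```python
import numpy as np
import chainer
import chainerrl

# Toy stand-ins for the per-episode losses built in update_from_replay.
losses = [chainer.Variable(np.asarray(0.5, dtype=np.float32)),
          chainer.Variable(np.asarray(1.5, dtype=np.float32))]
weights = [0.8, 1.2]        # importance weights from prioritized replay
batchsize = len(losses)

# weighted_sum_arrays computes sum_i weights[i] * losses[i]; with all
# weights equal to 1 this reduces to the previous sum_arrays behaviour.
loss = chainerrl.functions.weighted_sum_arrays(losses, weights) / batchsize
print(loss.data)  # (0.8 * 0.5 + 1.2 * 1.5) / 2 = 1.1
```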
115 changes: 91 additions & 24 deletions chainerrl/misc/prioritized.py
@@ -1,29 +1,46 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases()
import random

import numpy as np


class PrioritizedBuffer (object):
def __init__(self, capacity=None):

def __init__(self, capacity=None, wait_priority_after_sampling=True):
self.capacity = capacity
self.data = []
self.priority_tree = SumTree()
self.data_inf = []
self.wait_priority_after_sampling = wait_priority_after_sampling
self.flag_wait_priority = False

def __len__(self):
return len(self.data) + len(self.data_inf)

def append(self, value):
# new values are the most prioritized
self.data_inf.append(value)
if self.capacity is not None and len(self) > self.capacity:
def append(self, value, priority=None):
if self.capacity is not None and len(self) == self.capacity:
self.pop()
if priority is not None:
# Append with a given priority
i = len(self.data)
self.data.append(value)
self.priority_tree[i] = priority
else:
# Append with the highest priority
self.data_inf.append(value)

def _pop_random_data_inf(self):
assert self.data_inf
n = len(self.data_inf)
i = random.randrange(n)
ret = self.data_inf[i]
self.data_inf[i] = self.data_inf[n-1]
self.data_inf[i] = self.data_inf[n - 1]
self.data_inf.pop()
return ret
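A short usage sketch of the extended append: a value passed with an explicit priority goes straight into the prioritized storage and its priority is written to the sum tree, while a value appended without one is held in data_inf and treated as highest priority until it is first sampled. The appended values below are placeholders.

```python
from chainerrl.misc.prioritized import PrioritizedBuffer

buf = PrioritizedBuffer(capacity=1000)
buf.append('episode-a')                 # no priority yet: queued in data_inf
buf.append('episode-b', priority=2.5)   # known priority: stored via the sum tree
print(len(buf))  # 2
```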

@@ -33,47 +50,96 @@ def pop(self):
Not prioritized.
"""
assert len(self) > 0
assert not self.flag_wait_priority
assert (not self.wait_priority_after_sampling or
not self.flag_wait_priority)
n = len(self.data)
if n == 0:
return self._pop_random_data_inf()
i = random.randrange(0, n)
# remove i-th
self.priority_tree[i] = self.priority_tree[n-1]
del self.priority_tree[n-1]
self.priority_tree[i] = self.priority_tree[n - 1]
del self.priority_tree[n - 1]
ret = self.data[i]
self.data[i] = self.data[n-1]
del self.data[n-1]
self.data[i] = self.data[n - 1]
del self.data[n - 1]
return ret

def sample(self, n):
"""Sample n distinct elements"""
def _prioritized_sample_indices_and_probabilities(self, n):
assert 0 <= n <= len(self)
assert not self.flag_wait_priority
indices, probabilities = self.priority_tree.prioritized_sample(
max(0, n - len(self.data_inf)), remove=True)
sampled = []
for i in indices:
sampled.append(self.data[i])
while len(sampled) < n and len(self.data_inf) > 0:
max(0, n - len(self.data_inf)),
remove=self.wait_priority_after_sampling)
while len(indices) < n:
i = len(self.data)
e = self._pop_random_data_inf()
self.data.append(e)
del self.priority_tree[i]
indices.append(i)
probabilities.append(None)
sampled.append(self.data[i])
return indices, probabilities

def _sample_indices_and_probabilities(self, n, uniform_ratio):
if uniform_ratio > 0:
# Mix uniform samples and prioritized samples
n_uniform = np.random.binomial(n, uniform_ratio)
n_prioritized = n - n_uniform
pr_indices, pr_probs = \
self._prioritized_sample_indices_and_probabilities(
n_prioritized)
un_indices, un_probs = \
self._uniform_sample_indices_and_probabilities(
n_uniform)
indices = pr_indices + un_indices
# Note: when uniform samples and prioritized samples are mixed,
# resulting probabilities are not the true probabilities for each
# entry to be sampled.
probabilities = pr_probs + un_probs
return indices, probabilities
else:
# Only prioritized samples
return self._prioritized_sample_indices_and_probabilities(n)

def sample(self, n, uniform_ratio=0):
"""Sample data along with their corresponding probabilities.
Args:
n (int): Number of data to sample.
uniform_ratio (float): Ratio of uniformly sampled data.
Returns:
sampled data (list)
probabilities (list)
"""
assert (not self.wait_priority_after_sampling or
not self.flag_wait_priority)
indices, probabilities = self._sample_indices_and_probabilities(
n, uniform_ratio=uniform_ratio)
sampled = [self.data[i] for i in indices]
self.sampled_indices = indices
self.flag_wait_priority = True
return sampled, probabilities

def set_last_priority(self, priority):
assert self.flag_wait_priority
assert (not self.wait_priority_after_sampling or
self.flag_wait_priority)
assert all([p > 0.0 for p in priority])
assert len(self.sampled_indices) == len(priority)
for i, p in zip(self.sampled_indices, priority):
self.priority_tree[i] = p
self.flag_wait_priority = False
self.sampled_indices = []

def _uniform_sample_indices_and_probabilities(self, n):
indices = random.sample(range(len(self.data)),
max(0, n - len(self.data_inf)))
probabilities = [1 / len(self)] * len(indices)
while len(indices) < n:
i = len(self.data)
e = self._pop_random_data_inf()
self.data.append(e)
del self.priority_tree[i]
indices.append(i)
probabilities.append(None)
return indices, probabilities
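Putting the pieces together, here is a sketch of the sampling protocol under the default wait_priority_after_sampling=True: after sample() the buffer expects set_last_priority() before the next sample() or pop(), and uniform_ratio mixes uniformly drawn entries into the prioritized draw. The stored values and the priorities passed back are placeholders.

```python
from chainerrl.misc.prioritized import PrioritizedBuffer

buf = PrioritizedBuffer(capacity=100)
for i in range(10):
    buf.append(i)   # all start in data_inf, i.e. highest priority

# Draw 4 entries; on average 25% of the draws are uniform rather than prioritized.
sampled, probs = buf.sample(4, uniform_ratio=0.25)

# New priorities (e.g. derived from TD errors) must be assigned before the
# next call to sample(); placeholder values are used here.
buf.set_last_priority([1.0] * len(sampled))
```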


class SumTree (object):
@@ -119,17 +185,18 @@ def _center(self):

def _allocindex(self, ix):
if self.bd is None:
self.bd = (ix, ix+1)
self.bd = (ix, ix + 1)
while ix >= self.bd[1]:
r_bd = (self.bd[1], self.bd[1]*2 - self.bd[0])
r_bd = (self.bd[1], self.bd[1] * 2 - self.bd[0])
l = SumTree(self.bd, self.l, self.r, self.s)

r = SumTree(bd=r_bd)._initdescendant()
self.bd = (l.bd[0], r.bd[1])
self.l = l
self.r = r
# no need to update self.s because self.r.s == 0
while ix < self.bd[0]:
l_bd = (self.bd[0]*2 - self.bd[1], self.bd[0])
l_bd = (self.bd[0] * 2 - self.bd[1], self.bd[0])
l = SumTree(bd=l_bd)._initdescendant()
r = SumTree(self.bd, self.l, self.r, self.s)
self.bd = (l.bd[0], r.bd[1])
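The changes to SumTree above are whitespace-only, but the lines they touch are the bound-doubling logic in _allocindex, which grows the tree's half-open index range until it contains the requested index. A standalone sketch of just that arithmetic, not the actual class:

```python
def expand_bounds(bd, ix):
    # Mirrors the bound arithmetic in SumTree._allocindex: grow the
    # half-open range [bd[0], bd[1]) by doubling until it contains ix.
    if bd is None:
        bd = (ix, ix + 1)
    while ix >= bd[1]:
        bd = (bd[0], bd[1] * 2 - bd[0])   # extend to the right
    while ix < bd[0]:
        bd = (bd[0] * 2 - bd[1], bd[1])   # extend to the left
    return bd

print(expand_bounds(None, 0))    # (0, 1)
print(expand_bounds((0, 1), 5))  # (0, 8)
```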
35 changes: 28 additions & 7 deletions chainerrl/replay_buffer.py
@@ -75,7 +75,10 @@ def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max):
assert 0.0 <= beta0 <= 1.0
self.alpha = alpha
self.beta = beta0
self.beta_add = (1.0 - beta0) / betasteps
if betasteps is None:
self.beta_add = 0
else:
self.beta_add = (1.0 - beta0) / betasteps
self.eps = eps
self.normalize_by_max = normalize_by_max
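The guard above makes betasteps=None mean "no annealing". A sketch of the implied beta schedule follows; the clamp at 1.0 is an assumption for illustration, not taken from the diff.

```python
def beta_at(beta0, betasteps, step):
    # Linear annealing of the importance-sampling exponent from beta0
    # towards 1.0 over betasteps updates; fixed at beta0 when betasteps is None.
    beta_add = 0 if betasteps is None else (1.0 - beta0) / betasteps
    return min(1.0, beta0 + beta_add * step)

print(beta_at(0.4, 2e5, 0))      # 0.4
print(beta_at(0.4, 2e5, 1e5))    # 0.7
print(beta_at(0.4, None, 1e5))   # 0.4 (no annealing)
```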

@@ -206,31 +209,49 @@ class PrioritizedEpisodicReplayBuffer (

def __init__(self, capacity=None,
alpha=0.6, beta0=0.4, betasteps=2e5, eps=1e-8,
normalize_by_max=True):
normalize_by_max=True,
default_priority_func=None,
uniform_ratio=0,
wait_priority_after_sampling=True,
return_sample_weights=True):
self.current_episode = []
self.episodic_memory = PrioritizedBuffer(capacity=None)
self.episodic_memory = PrioritizedBuffer(
capacity=None,
wait_priority_after_sampling=wait_priority_after_sampling)
self.memory = deque(maxlen=capacity)
self.capacity_left = capacity
self.default_priority_func = default_priority_func
self.uniform_ratio = uniform_ratio
self.return_sample_weights = return_sample_weights
PriorityWeightError.__init__(
self, alpha, beta0, betasteps, eps, normalize_by_max)

def sample_episodes(self, n_episodes, max_len=None):
"""Sample n unique samples from this replay buffer"""
assert len(self.episodic_memory) >= n_episodes
episodes, probabilities = self.episodic_memory.sample(n_episodes)
weights = self.weights_from_probabilities(probabilities)
episodes, probabilities = self.episodic_memory.sample(
n_episodes, uniform_ratio=self.uniform_ratio)
if max_len is not None:
episodes = [random_subseq(ep, max_len) for ep in episodes]
return episodes, weights
if self.return_sample_weights:
weights = self.weights_from_probabilities(probabilities)
return episodes, weights
else:
return episodes

def update_errors(self, errors):
self.episodic_memory.set_last_priority(
self.priority_from_errors(errors))

def stop_current_episode(self):
if self.current_episode:
if self.default_priority_func is not None:
priority = self.default_priority_func(self.current_episode)
else:
priority = None
self.memory.extend(self.current_episode)
self.episodic_memory.append(self.current_episode)
self.episodic_memory.append(self.current_episode,
priority=priority)
if self.capacity_left is not None:
self.capacity_left -= len(self.current_episode)
self.current_episode = []
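A construction sketch for the new options: with default_priority_func each finished episode receives its priority as soon as stop_current_episode() is called, wait_priority_after_sampling=False means no priority update is expected after sampling, and return_sample_weights=False makes sample_episodes return only the episodes. The capacity and ratio below mirror the example script further down and are illustrative.

```python
import numpy as np
from chainerrl.replay_buffer import PrioritizedEpisodicReplayBuffer

def exp_return_of_episode(episode):
    # Episode priority = exp(total reward), as in the example script below.
    return np.exp(sum(x['reward'] for x in episode))

rbuf = PrioritizedEpisodicReplayBuffer(
    capacity=5 * 10 ** 3,
    uniform_ratio=0.1,
    default_priority_func=exp_return_of_episode,
    wait_priority_after_sampling=False,
    return_sample_weights=False)

# Once enough episodes are stored, sampling returns just a list of episodes
# (no importance weights), e.g.:
#   episodes = rbuf.sample_episodes(16, max_len=50)
```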
20 changes: 19 additions & 1 deletion examples/gym/train_pcl_gym.py
@@ -29,6 +29,10 @@
from chainerrl.optimizers import rmsprop_async


def exp_return_of_episode(episode):
return np.exp(sum(x['reward'] for x in episode))


def main():
import logging

@@ -58,6 +62,8 @@ def main():
parser.add_argument('--logger-level', type=int, default=logging.DEBUG)
parser.add_argument('--monitor', action='store_true')
parser.add_argument('--train-async', action='store_true', default=False)
parser.add_argument('--prioritized-replay', action='store_true',
default=False)
parser.add_argument('--disable-online-update', action='store_true',
default=False)
parser.add_argument('--backprop-future-values', action='store_true',
@@ -126,6 +132,7 @@ def make_env(process_idx, test):
)

if not args.train_async and args.gpu >= 0:
chainer.cuda.get_device(args.gpu).use()
model.to_gpu(args.gpu)

if args.train_async:
@@ -134,7 +141,18 @@ def make_env(process_idx, test):
opt = chainer.optimizers.Adam(alpha=args.lr)
opt.setup(model)

replay_buffer = chainerrl.replay_buffer.EpisodicReplayBuffer(10 ** 5)
if args.prioritized_replay:
replay_buffer = \
chainerrl.replay_buffer.PrioritizedEpisodicReplayBuffer(
capacity=5 * 10 ** 3,
uniform_ratio=0.1,
default_priority_func=exp_return_of_episode,
wait_priority_after_sampling=False,
return_sample_weights=False)
else:
replay_buffer = chainerrl.replay_buffer.EpisodicReplayBuffer(
capacity=5 * 10 ** 3)

agent = chainerrl.agents.PCL(
model, opt, replay_buffer=replay_buffer,
t_max=args.t_max, gamma=0.99,
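With --prioritized-replay enabled, an episode's priority is its exponentiated return via exp_return_of_episode, so high-return episodes are replayed more often. A toy check of that value with made-up rewards:

```python
import numpy as np

episode = [{'reward': 1.0}, {'reward': 0.5}, {'reward': -0.2}]
priority = np.exp(sum(x['reward'] for x in episode))
print(priority)  # exp(1.3) ≈ 3.669
```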
