diff --git a/tensor2tensor/data_generators/gym.py b/tensor2tensor/data_generators/gym.py
index d1a1cdc37..06b5ad0f3 100644
--- a/tensor2tensor/data_generators/gym.py
+++ b/tensor2tensor/data_generators/gym.py
@@ -19,25 +19,28 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
+import functools
 
 # Dependency imports
 
+import gym
+import numpy as np
+
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
+from tensor2tensor.models.research import rl
+from tensor2tensor.rl.envs import atari_wrappers
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
 
-def gym_lib():
-  """Access to gym to allow for import of this file without a gym install."""
-  try:
-    import gym  # pylint: disable=g-import-not-at-top
-  except ImportError:
-    raise ImportError("pip install gym to use gym-based Problems")
-  return gym
+
+flags = tf.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_path", "", "File with the trained model for Pong.")
 
 
 class GymDiscreteProblem(problem.Problem):
@@ -55,7 +58,7 @@ def env_name(self):
   @property
   def env(self):
     if self._env is None:
-      self._env = gym_lib().make(self.env_name)
+      self._env = gym.make(self.env_name)
     return self._env
 
   @property
@@ -143,3 +146,68 @@ def num_rewards(self):
   @property
   def num_steps(self):
     return 5000
+
+
+@registry.register_problem
+class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
+  """Pong game, loaded actions."""
+
+  def __init__(self, event_dir, *args, **kwargs):
+    super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
+    self._env = None
+    self._event_dir = event_dir
+    env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
+        gym.make("PongNoFrameskip-v4"),
+        warp=False,
+        frame_skip=4,
+        frame_stack=False)
+    hparams = rl.atari_base()
+    with tf.variable_scope("train"):
+      policy_lambda = hparams.network
+      policy_factory = tf.make_template(
+          "network",
+          functools.partial(policy_lambda, env_spec().action_space, hparams))
+      self._max_frame_pl = tf.placeholder(
+          tf.float32, self.env.observation_space.shape)
+      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
+          self._max_frame_pl, 0), 0))
+      policy = actor_critic.policy
+      self._last_policy_op = policy.mode()
+      self._last_action = self.env.action_space.sample()
+      self._skip = 4
+      self._skip_step = 0
+      self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
+                                  dtype=np.uint8)
+      self._sess = tf.Session()
+      model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
+      model_saver.restore(self._sess, FLAGS.model_path)
+
+  # TODO(blazej0): Wrappers are usually used when training Atari agents.
+  # The code below is a hacky workaround meant to be used together with
+  # atari_wrappers.MaxAndSkipEnv.
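+  # get_action buffers the two most recent raw observations and, once every
+  # `self._skip` calls, feeds their pixel-wise maximum to the loaded policy;
+  # in between it keeps returning the previously chosen action, mirroring
+  # what MaxAndSkipEnv does on the environment side.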
+  def get_action(self, observation=None):
+    if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
+    if self._skip_step == self._skip - 1: self._obs_buffer[1] = observation
+    self._skip_step = (self._skip_step + 1) % self._skip
+    if self._skip_step == 0:
+      max_frame = self._obs_buffer.max(axis=0)
+      self._last_action = int(self._sess.run(
+          self._last_policy_op,
+          feed_dict={self._max_frame_pl: max_frame})[0, 0])
+    return self._last_action
+
+  @property
+  def env_name(self):
+    return "PongNoFrameskip-v4"
+
+  @property
+  def num_actions(self):
+    return 4
+
+  @property
+  def num_rewards(self):
+    return 2
+
+  @property
+  def num_steps(self):
+    return 5000
diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py
index c78d1f52a..32ef49901 100644
--- a/tensor2tensor/models/__init__.py
+++ b/tensor2tensor/models/__init__.py
@@ -44,6 +44,7 @@
 from tensor2tensor.models.research import cycle_gan
 from tensor2tensor.models.research import gene_expression
 from tensor2tensor.models.research import multimodel
+from tensor2tensor.models.research import rl
 from tensor2tensor.models.research import super_lm
 from tensor2tensor.models.research import transformer_moe
 from tensor2tensor.models.research import transformer_revnet
diff --git a/tensor2tensor/models/research/rl.py b/tensor2tensor/models/research/rl.py
index 7433026b0..858d6964e 100644
--- a/tensor2tensor/models/research/rl.py
+++ b/tensor2tensor/models/research/rl.py
@@ -49,6 +49,7 @@ def ppo_base_v1():
   hparams.add_hparam("eval_every_epochs", 10)
   hparams.add_hparam("num_eval_agents", 3)
   hparams.add_hparam("video_during_eval", True)
+  hparams.add_hparam("save_models_every_epochs", 30)
   return hparams
 
 
@@ -66,7 +67,23 @@ def discrete_action_base():
   return hparams
 
 
-# Neural networks for actor-critic algorithms
+@registry.register_hparams
+def atari_base():
+  """Atari base parameters."""
+  hparams = discrete_action_base()
+  hparams.learning_rate = 16e-5
+  hparams.num_agents = 5
+  hparams.epoch_length = 200
+  hparams.gae_gamma = 0.985
+  hparams.gae_lambda = 0.985
+  hparams.entropy_loss_coef = 0.002
+  hparams.value_loss_coef = 0.025
+  hparams.optimization_epochs = 10
+  hparams.epochs_num = 10000
+  hparams.num_eval_agents = 1
+  hparams.network = feed_forward_cnn_small_categorical_fun
+  return hparams
+
 
 NetworkOutput = collections.namedtuple(
     "NetworkOutput", "policy, value, action_postprocessing")
@@ -85,23 +102,24 @@ def feed_forward_gaussian_fun(action_space, config, observations):
   flat_observations = tf.reshape(observations, [
       tf.shape(observations)[0], tf.shape(observations)[1],
       functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
-  with tf.variable_scope("policy"):
-    x = flat_observations
-    for size in config.policy_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    mean = tf.contrib.layers.fully_connected(
-        x, action_space.shape[0], tf.tanh,
-        weights_initializer=mean_weights_initializer)
-    logstd = tf.get_variable(
-        "logstd", mean.shape[2:], tf.float32, logstd_initializer)
-    logstd = tf.tile(
-        logstd[None, None],
-        [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
-  with tf.variable_scope("value"):
-    x = flat_observations
-    for size in config.value_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
+  with tf.variable_scope("network_parameters"):
+    with tf.variable_scope("policy"):
+      x = flat_observations
+      for size in config.policy_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      mean = tf.contrib.layers.fully_connected(
+          x, action_space.shape[0], tf.tanh,
+          weights_initializer=mean_weights_initializer)
+      logstd = tf.get_variable(
+          "logstd", mean.shape[2:], tf.float32, logstd_initializer)
+      logstd = tf.tile(
+          logstd[None, None],
+          [tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
+    with tf.variable_scope("value"):
+      x = flat_observations
+      for size in config.value_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
   mean = tf.check_numerics(mean, "mean")
   logstd = tf.check_numerics(logstd, "logstd")
   value = tf.check_numerics(value, "value")
@@ -119,17 +137,18 @@ def feed_forward_categorical_fun(action_space, config, observations):
   flat_observations = tf.reshape(observations, [
       tf.shape(observations)[0], tf.shape(observations)[1],
       functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
-  with tf.variable_scope("policy"):
-    x = flat_observations
-    for size in config.policy_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                               activation_fn=None)
-  with tf.variable_scope("value"):
-    x = flat_observations
-    for size in config.value_layers:
-      x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-    value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
+  with tf.variable_scope("network_parameters"):
+    with tf.variable_scope("policy"):
+      x = flat_observations
+      for size in config.policy_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      logits = tf.contrib.layers.fully_connected(x, action_space.n,
+                                                 activation_fn=None)
+    with tf.variable_scope("value"):
+      x = flat_observations
+      for size in config.value_layers:
+        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
+      value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
   policy = tf.contrib.distributions.Categorical(logits=logits)
   return NetworkOutput(policy, value, lambda a: a)
 
@@ -141,7 +160,7 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
   obs_shape = observations.shape.as_list()
   x = tf.reshape(observations, [-1] + obs_shape[2:])
 
-  with tf.variable_scope("policy"):
+  with tf.variable_scope("network_parameters"):
     x = tf.to_float(x) / 255.0
     x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
                                  activation_fn=tf.nn.relu, padding="SAME")
diff --git a/tensor2tensor/rl/README.md b/tensor2tensor/rl/README.md
index 8581ae21a..46e40403f 100644
--- a/tensor2tensor/rl/README.md
+++ b/tensor2tensor/rl/README.md
@@ -7,6 +7,16 @@ for now and under heavy development.
 
 Currently the only supported algorithm is Proximy Policy Optimization - PPO.
 
-## Sample usage - training in Pendulum-v0 environment.
+## Sample usage - training in the Pendulum-v0 environment.
 
 ```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set continuous_action_base [--output_dir dir_location]```
+
+## Sample usage - training in the PongNoFrameskip-v4 environment.
+
+```python tensor2tensor/rl/t2t_rl_trainer.py --problem stacked_pong --hparams_set atari_base --hparams num_agents=5 --output_dir /tmp/pong`date +%Y%m%d_%H%M%S````
+
+## Sample usage - generation of trajectories from a trained agent
+
+```python tensor2tensor/bin/t2t-trainer --generate_data --data_dir=~/t2t_data --problems=gym_pong_trajectories_from_policy --hparams_set=atari_base --model_path [model]```
+
+```python tensor2tensor/bin/t2t-datagen --data_dir=~/t2t_data --tmp_dir=~/t2t_data/tmp --problem=gym_pong_trajectories_from_policy --model_path [model]```
diff --git a/tensor2tensor/rl/envs/atari_wrappers.py b/tensor2tensor/rl/envs/atari_wrappers.py
new file mode 100644
index 000000000..b8dd425ec
--- /dev/null
+++ b/tensor2tensor/rl/envs/atari_wrappers.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Various wrappers copied from OpenAI Baselines."""
+
+from collections import deque
+
+import gym
+import numpy as np
+
+
+# Adapted from the link below.
+# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+
+
+class WarpFrame(gym.ObservationWrapper):
+  """Wrap a frame."""
+
+  def __init__(self, env):
+    """Warp frames to 84x84 as done in the Nature paper and later work."""
+    gym.ObservationWrapper.__init__(self, env)
+    self.width = 84
+    self.height = 84
+    self.observation_space = gym.spaces.Box(
+        low=0, high=255,
+        shape=(self.height, self.width, 1), dtype=np.uint8)
+
+  def observation(self, frame):
+    import cv2  # pylint: disable=g-import-not-at-top
+    cv2.ocl.setUseOpenCL(False)
+    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+    frame = cv2.resize(frame, (self.width, self.height),
+                       interpolation=cv2.INTER_AREA)
+    return frame[:, :, None]
+
+
+class LazyFrames(object):
+  """Lazy frame storage."""
+
+  def __init__(self, frames):
+    """Lazy frame storage.
+
+    This object ensures that common frames between the observations
+    are only stored once. It exists purely to optimize memory usage
+    which can be huge for DQN's 1M frames replay buffers.
+    This object should only be converted to numpy array before being passed
+    to the model.
+
+    Args:
+      frames: the frames.
+    """
+    self._frames = frames
+
+  def __array__(self, dtype=None):
+    out = np.concatenate(self._frames, axis=2)
+    if dtype is not None:
+      out = out.astype(dtype)
+    return out
+
+
+class FrameStack(gym.Wrapper):
+  """Stack frames."""
+
+  def __init__(self, env, k):
+    """Stack k last frames.
+
+    Returns a lazy array, which is memory efficient.
+    """
+    gym.Wrapper.__init__(self, env)
+    self.k = k
+    self.frames = deque([], maxlen=k)
+    shp = env.observation_space.shape
+    self.observation_space = gym.spaces.Box(
+        low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
+
+  def reset(self):
+    ob = self.env.reset()
+    for _ in range(self.k):
+      self.frames.append(ob)
+    return self._get_ob()
+
+  def step(self, action):
+    ob, reward, done, info = self.env.step(action)
+    self.frames.append(ob)
+    return self._get_ob(), reward, done, info
+
+  def _get_ob(self):
+    assert len(self.frames) == self.k
+    return LazyFrames(list(self.frames))
+
+
+class MaxAndSkipEnv(gym.Wrapper):
+  """Max and skip env."""
+
+  def __init__(self, env, skip=4):
+    """Return only every `skip`-th frame."""
+    gym.Wrapper.__init__(self, env)
+    # Most recent raw observations (for max pooling across time steps).
+    self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
+                                dtype=np.uint8)
+    self._skip = skip
+
+  def reset(self, **kwargs):
+    return self.env.reset(**kwargs)
+
+  def step(self, action):
+    """Repeat action, sum reward, and max over last observations."""
+    total_reward = 0.0
+    done = None
+    for i in range(self._skip):
+      obs, reward, done, info = self.env.step(action)
+      if i == self._skip - 2: self._obs_buffer[0] = obs
+      if i == self._skip - 1: self._obs_buffer[1] = obs
+      total_reward += reward
+      if done:
+        break
+    # Note that the observation on the done=True frame doesn't matter.
+    max_frame = self._obs_buffer.max(axis=0)
+
+    return max_frame, total_reward, done, info
+
+
+def wrap_atari(env, warp=False, frame_skip=False, frame_stack=False):
+  if warp:
+    env = WarpFrame(env)
+  if frame_skip:
+    env = MaxAndSkipEnv(env, frame_skip)
+  if frame_stack:
+    env = FrameStack(env, frame_stack)
+  return env
diff --git a/tensor2tensor/rl/rl_trainer_lib.py b/tensor2tensor/rl/rl_trainer_lib.py
index 28ec7e22c..3193b7044 100644
--- a/tensor2tensor/rl/rl_trainer_lib.py
+++ b/tensor2tensor/rl/rl_trainer_lib.py
@@ -18,6 +18,7 @@
 from __future__ import absolute_import
 
 import functools
+import os
 
 # Dependency imports
 
@@ -28,6 +29,7 @@
 from tensor2tensor.models.research import rl  # pylint: disable=unused-import
 from tensor2tensor.rl import collect
 from tensor2tensor.rl import ppo
+from tensor2tensor.rl.envs import atari_wrappers
 from tensor2tensor.rl.envs import utils
 
 import tensorflow as tf
@@ -69,19 +71,24 @@ def define_train(hparams, environment_spec, event_dir):
       utils.define_batch_env(wrapped_eval_env_lambda, hparams.num_eval_agents,
                              xvfb=hparams.video_during_eval),
       hparams, eval_phase=True)
-  return summary, eval_summary
+  return summary, eval_summary, policy_factory
 
 
 def train(hparams, environment_spec, event_dir=None):
   """Train."""
-  train_summary_op, eval_summary_op = define_train(hparams, environment_spec,
-                                                   event_dir)
-
+  if environment_spec == "stacked_pong":
+    environment_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
+        gym.make("PongNoFrameskip-v4"),
+        warp=False, frame_skip=4, frame_stack=False)
+  train_summary_op, eval_summary_op, _ = define_train(hparams, environment_spec,
+                                                      event_dir)
   if event_dir:
     summary_writer = tf.summary.FileWriter(
         event_dir, graph=tf.get_default_graph(), flush_secs=60)
+    model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
   else:
     summary_writer = None
+    model_saver = None
 
   with tf.Session() as sess:
     sess.run(tf.global_variables_initializer())
@@ -94,3 +101,7 @@ def train(hparams, environment_spec, event_dir=None):
         summary = sess.run(eval_summary_op)
         if summary_writer:
           summary_writer.add_summary(summary, epoch_index)
+      if (model_saver and hparams.save_models_every_epochs and
+          epoch_index % hparams.save_models_every_epochs == 0):
+        model_saver.save(sess, os.path.join(event_dir,
+                                            "model{}.ckpt".format(epoch_index)))
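The `wrap_atari` helper introduced above is used identically by `GymPongTrajectoriesFromPolicy` in `gym.py` and by the `stacked_pong` branch in `rl_trainer_lib.py`. Below is a minimal usage sketch, assuming `gym` (with the Atari environments) and the new `tensor2tensor.rl.envs.atari_wrappers` module are importable; the environment name and flag values are the ones used in the patch.

```python
import gym

from tensor2tensor.rl.envs import atari_wrappers

# Same configuration as the env_spec lambda in gym.py and the "stacked_pong"
# branch in rl_trainer_lib.py: raw PongNoFrameskip-v4 frames, no warping,
# no frame stacking, and MaxAndSkipEnv with a skip of 4.
env = atari_wrappers.wrap_atari(
    gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False)

observation = env.reset()
# MaxAndSkipEnv repeats the sampled action for 4 raw frames and returns the
# pixel-wise max of the last two frames together with the summed reward.
observation, reward, done, info = env.step(env.action_space.sample())
```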