diff --git a/.coveragerc b/.coveragerc index f13922b..7b289b0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,5 @@ [report] -omit = *tests*, *examples*, setup.py +omit = *tests*, *examples*, setup.py, *safe-pilco-extension* exclude_lines = pragma: no cover raise AssertionError diff --git a/README.md b/README.md index aba8837..3556639 100644 --- a/README.md +++ b/README.md @@ -2,39 +2,37 @@ [![Build Status](https://travis-ci.org/nrontsis/PILCO.svg?branch=master)](https://travis-ci.org/nrontsis/PILCO) [![codecov](https://codecov.io/gh/nrontsis/PILCO/branch/master/graph/badge.svg)](https://codecov.io/gh/nrontsis/PILCO) -A modern \& clean implementation of the [PILCO](https://ieeexplore.ieee.org/abstract/document/6654139/) Algorithm in `TensorFlow`. +A modern \& clean implementation of the [PILCO](https://ieeexplore.ieee.org/abstract/document/6654139/) Algorithm in `TensorFlow v2`. Unlike PILCO's [original implementation](http://mlg.eng.cam.ac.uk/pilco/) which was written as a self-contained package of `MATLAB`, this repository aims to provide a clean implementation by heavy use of modern machine learning libraries. -In particular, we use `TensorFlow` to avoid the need for hardcoded gradients and scale to GPU architectures. Moreover, we use [`GPflow`](https://github.com/GPflow/GPflow) for Gaussian Process Regression. +In particular, we use `TensorFlow v2` to avoid the need for hardcoded gradients and scale to GPU architectures. Moreover, we use [`GPflow v2`](https://github.com/GPflow/GPflow) for Gaussian Process Regression. The core functionality is tested against the original `MATLAB` implementation. ## Example of usage -Before using, or installing, PILCO, you need to have `Tensorflow 1.13.1` installed (either the gpu or the cpu version). It is recommended to install everything in a fresh `conda` environment with `python>=3.7`. Given `Tensorflow`, PILCO can be installed as follows +Before using `PILCO`, you have to install it by running: ``` git clone https://github.com/nrontsis/PILCO && cd PILCO python setup.py develop ``` +It is recommended to install everything in a fresh conda environment with `python>=3.7`. -The examples included in this repo use [`OpenAI gym 0.15.3`](https://github.com/openai/gym#installation) and [`mujoco-py 2.0.2.7`](https://github.com/openai/mujoco-py#install-mujoco). Once these dependencies are installed, you can run one of the examples as follows +The examples included in this repo use [`OpenAI gym 0.15.3`](https://github.com/openai/gym#installation) and [`mujoco-py 2.0.2.7`](https://github.com/openai/mujoco-py#install-mujoco). These dependencies should be installed manually. Then, you can run one of the examples as follows ``` python examples/inverted_pendulum.py ``` -While running an example, `Tensorflow` might print a lot of warnings, [some of which are deprecated](https://github.com/tensorflow/tensorflow/issues/25996). If necessary, you can suppress them by running -```python -tf.logging.set_verbosity(tf.logging.ERROR) -``` -right after including `TensorFlow` in Python. + +## Example Extension: Safe PILCO +As an example of the extensibility of the framework, we include in the folder `safe_pilco_extension` an extension of the standard PILCO algorithm that takes safety constraints (defined on the environment's state space) into account as in [https://arxiv.org/abs/1712.05556](https://arxiv.org/pdf/1712.05556.pdf). The `safe_swimmer_run.py` and `safe_cars_run.py` in the `examples` folder demonstrate the use of this extension.
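The entry points of the safe extension can be seen in `examples/safe_cars_run.py` further down this diff; the following is a minimal, condensed sketch of that wiring. Class names and keyword arguments are taken from that script; the placeholder data and the collision bounds are illustrative only (in the script they come from environment rollouts and the normalised car positions).

```python
import numpy as np
from pilco.controllers import RbfController
from pilco.rewards import LinearReward
from safe_pilco_extension.rewards_safe import RiskOfCollision
from safe_pilco_extension.safe_pilco import SafePILCO

state_dim, control_dim, T = 4, 1, 25
# Placeholder dataset; in safe_cars_run.py X and Y come from random rollouts of the environment
X = np.random.randn(50, state_dim + control_dim)
Y = np.random.randn(50, state_dim)
m_init = np.transpose(X[0, :-1, None])   # initial state mean, as in the example script
S_init = 0.1 * np.eye(state_dim)         # initial state covariance

controller = RbfController(state_dim=state_dim, control_dim=control_dim,
                           num_basis_functions=40, max_action=0.2)
reward = LinearReward(state_dim, W=np.array([1.0, 0.0, 0.0, 0.0]))
# Probability of entering the "collision" box on the two position coordinates
# (illustrative bounds, not the normalised ones used in the script)
risk = RiskOfCollision(2, [-1.0, -1.0], [1.0, 1.0])

pilco = SafePILCO((X, Y), controller=controller, mu=-300.0, reward_add=reward,
                  reward_mult=risk, horizon=T, m_init=m_init, S_init=S_init)
pilco.optimize_models(maxiter=100)
pilco.optimize_policy(maxiter=20, restarts=2)
```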
## Credits: The following people have been involved in the development of this package: * [Nikitas Rontsis](https://github.com/nrontsis) -* [Kyriakos Polymenakos](https://github.com/kyr-pol) +* [Kyriakos Polymenakos](https://github.com/kyr-pol/) ## References -See the following publications for a description of the algorithm: [1](https://ieeexplore.ieee.org/abstract/document/6654139/), [2](http://mlg.eng.cam.ac.uk/pub/pdf/DeiRas11.pdf), -[3](https://pdfs.semanticscholar.org/c9f2/1b84149991f4d547b3f0f625f710750ad8d9.pdf) - +See the following publications for a description of the algorithm: [1](https://ieeexplore.ieee.org/abstract/document/6654139/), [2](http://mlg.eng.cam.ac.uk/pub/pdf/DeiRas11.pdf), +[3](https://pdfs.semanticscholar.org/c9f2/1b84149991f4d547b3f0f625f710750ad8d9.pdf) \ No newline at end of file diff --git a/examples/inv_double_pendulum.py b/examples/inv_double_pendulum.py index f1bb3a1..2800b9f 100644 --- a/examples/inv_double_pendulum.py +++ b/examples/inv_double_pendulum.py @@ -5,8 +5,8 @@ from pilco.controllers import RbfController, LinearController from pilco.rewards import ExponentialReward import tensorflow as tf -from tensorflow import logging from utils import rollout, policy +from gpflow import set_trainable np.random.seed(0) # Introduces a simple wrapper for the gym environment @@ -29,7 +29,7 @@ def state_trans(self, s): def step(self, action): ob, r, done, _ = self.env.step(action) - if np.abs(ob[0])> 0.98 or np.abs(ob[-3]) > 0.1 or np.abs(ob[-2]) > 0.1 or np.abs(ob[-1]) > 0.1: + if np.abs(ob[0])> 0.90 or np.abs(ob[-3]) > 0.15 or np.abs(ob[-2]) > 0.15 or np.abs(ob[-1]) > 0.15: done = True return self.state_trans(ob), r, done, {} @@ -41,32 +41,32 @@ def render(self): self.env.render() -SUBS = 1 -bf = 40 -maxiter=80 -state_dim = 6 -control_dim = 1 -max_action=1.0 # actions for these environments are discrete -target = np.zeros(state_dim) -weights = 3.0 * np.eye(state_dim) -weights[0,0] = 0.5 -weights[3,3] = 0.5 -m_init = np.zeros(state_dim)[None, :] -S_init = 0.01 * np.eye(state_dim) -T = 40 -J = 1 -N = 12 -T_sim = 130 -restarts=True -lens = [] - -with tf.Session() as sess: +if __name__=='__main__': + SUBS = 1 + bf = 40 + maxiter=10 + state_dim = 6 + control_dim = 1 + max_action=1.0 # actions for these environments are discrete + target = np.zeros(state_dim) + weights = 5.0 * np.eye(state_dim) + weights[0,0] = 1.0 + weights[3,3] = 1.0 + m_init = np.zeros(state_dim)[None, :] + S_init = 0.005 * np.eye(state_dim) + T = 40 + J = 5 + N = 12 + T_sim = 130 + restarts=True + lens = [] + env = DoublePendWrapper() # Initial random rollouts to generate a dataset - X,Y = rollout(env, None, timesteps=T, random=True, SUBS=SUBS) + X, Y, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, render=True) for i in range(1,J): - X_, Y_ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True) + X_, Y_, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True, render=True) X = np.vstack((X, X_)) Y = np.vstack((Y, Y_)) @@ -77,21 +77,19 @@ def render(self): R = ExponentialReward(state_dim=state_dim, t=target, W=weights) - pilco = PILCO(X, Y, controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) + pilco = PILCO((X, Y), controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) # for numerical stability for model in pilco.mgpr.models: - # model.kern.lengthscales.prior = gpflow.priors.Gamma(1,10) priors have to be included before - # model.kern.variance.prior = gpflow.priors.Gamma(1.5,2) before the model gets 
compiled - model.likelihood.variance = 0.001 - model.likelihood.variance.trainable = False + model.likelihood.variance.assign(0.001) + set_trainable(model.likelihood.variance, False) for rollouts in range(N): print("**** ITERATION no", rollouts, " ****") pilco.optimize_models(maxiter=maxiter, restarts=2) pilco.optimize_policy(maxiter=maxiter, restarts=2) - X_new, Y_new = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS) + X_new, Y_new, _, _ = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS, render=True) # Since we had decide on the various parameters of the reward function # we might want to verify that it behaves as expected by inspection @@ -102,7 +100,7 @@ def render(self): # Update dataset X = np.vstack((X, X_new[:T, :])); Y = np.vstack((Y, Y_new[:T, :])) - pilco.mgpr.set_XY(X, Y) + pilco.mgpr.set_data((X, Y)) lens.append(len(X_new)) print(len(X_new)) diff --git a/examples/inverted_pendulum.py b/examples/inverted_pendulum.py index 487671c..18aaa38 100644 --- a/examples/inverted_pendulum.py +++ b/examples/inverted_pendulum.py @@ -4,41 +4,36 @@ from pilco.controllers import RbfController, LinearController from pilco.rewards import ExponentialReward import tensorflow as tf -from tensorflow import logging +from gpflow import set_trainable +# from tensorflow import logging np.random.seed(0) from utils import rollout, policy -with tf.Session(graph=tf.Graph()) as sess: - env = gym.make('InvertedPendulum-v2') - # Initial random rollouts to generate a dataset - X,Y = rollout(env=env, pilco=None, random=True, timesteps=40) - for i in range(1,3): - X_, Y_ = rollout(env=env, pilco=None, random=True, timesteps=40) - X = np.vstack((X, X_)) - Y = np.vstack((Y, Y_)) +env = gym.make('InvertedPendulum-v2') +# Initial random rollouts to generate a dataset +X,Y, _, _ = rollout(env=env, pilco=None, random=True, timesteps=40, render=True) +for i in range(1,5): + X_, Y_, _, _ = rollout(env=env, pilco=None, random=True, timesteps=40, render=True) + X = np.vstack((X, X_)) + Y = np.vstack((Y, Y_)) - state_dim = Y.shape[1] - control_dim = X.shape[1] - state_dim - controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=5) - #controller = LinearController(state_dim=state_dim, control_dim=control_dim) +state_dim = Y.shape[1] +control_dim = X.shape[1] - state_dim +controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=10) +# controller = LinearController(state_dim=state_dim, control_dim=control_dim) - pilco = PILCO(X, Y, controller=controller, horizon=40) - # Example of user provided reward function, setting a custom target state - # R = ExponentialReward(state_dim=state_dim, t=np.array([0.1,0,0,0])) - # pilco = PILCO(X, Y, controller=controller, horizon=40, reward=R) +pilco = PILCO((X, Y), controller=controller, horizon=40) +# Example of user provided reward function, setting a custom target state +# R = ExponentialReward(state_dim=state_dim, t=np.array([0.1,0,0,0])) +# pilco = PILCO(X, Y, controller=controller, horizon=40, reward=R) - # Example of fixing a parameter, optional, for a linear controller only - #pilco.controller.b = np.array([[0.0]]) - #pilco.controller.b.trainable = False - - for rollouts in range(3): - pilco.optimize_models() - pilco.optimize_policy() - import pdb; pdb.set_trace() - X_new, Y_new = rollout(env=env, pilco=pilco, timesteps=100) - print("No of ops:", len(tf.get_default_graph().get_operations())) - # Update dataset - X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new)) - 
pilco.mgpr.set_XY(X, Y) +for rollouts in range(3): + pilco.optimize_models() + pilco.optimize_policy() + import pdb; pdb.set_trace() + X_new, Y_new, _, _ = rollout(env=env, pilco=pilco, timesteps=100, render=True) + # Update dataset + X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new)) + pilco.mgpr.set_data((X, Y)) diff --git a/examples/linear_cars_env.py b/examples/linear_cars_env.py new file mode 100644 index 0000000..1f987f8 --- /dev/null +++ b/examples/linear_cars_env.py @@ -0,0 +1,34 @@ +import numpy as np +from gym import spaces +from gym.core import Env + +class LinearCars(Env): + def __init__(self): + self.action_space = spaces.Box(low=-0.4, high=0.4, shape=(1,)) + self.observation_space = spaces.Box(low=-100, high=100, shape=(4,)) + self.M = 1 # car mass [kg] + self.b = 0.001 # friction coef [N/m/s] + self.Dt = 0.50 # timestep [s] + + self.A = np.array([[0, self.Dt, 0, 0], + [0, -self.b*self.Dt/self.M, 0, 0], + [0, 0, 0, self.Dt], + [0, 0, 0, 0]]) + + self.B = np.array([0,self.Dt/self.M, 0, 0]).reshape((4,1)) + + self.initial_state = np.array([-6.0, 1.0, -5.0, 1.0]).reshape((4,1)) + + def step(self, action): + self.state += self.A @ self.state + self.B * action + #0.1 * np.random.normal(scale=[[1e-3], [1e-3], [1e-3], [0.001]], size=(4,1)) + + if self.state[0] < 0: + reward = -1 + else: + reward = 1 + return np.reshape(self.state[:], (4,)), reward, False, None + + def reset(self): + self.state = self.initial_state + 0.03 * np.random.normal(size=(4,1)) + return np.reshape(self.state[:], (4,)) diff --git a/examples/mountain_car.py b/examples/mountain_car.py new file mode 100644 index 0000000..1efc97f --- /dev/null +++ b/examples/mountain_car.py @@ -0,0 +1,69 @@ +import numpy as np +import gym +from pilco.models import PILCO +from pilco.controllers import RbfController, LinearController +from pilco.rewards import ExponentialReward +import tensorflow as tf +np.random.seed(0) +from utils import policy, rollout, Normalised_Env + + +SUBS = 5 +T = 25 +env = gym.make('MountainCarContinuous-v0') +# Initial random rollouts to generate a dataset +X1,Y1, _, _ = rollout(env=env, pilco=None, random=True, timesteps=T, SUBS=SUBS, render=True) +for i in range(1,5): + X1_, Y1_,_,_ = rollout(env=env, pilco=None, random=True, timesteps=T, SUBS=SUBS, render=True) + X1 = np.vstack((X1, X1_)) + Y1 = np.vstack((Y1, Y1_)) +env.close() + +env = Normalised_Env('MountainCarContinuous-v0', np.mean(X1[:,:2],0), np.std(X1[:,:2], 0)) +X = np.zeros(X1.shape) +X[:, :2] = np.divide(X1[:, :2] - np.mean(X1[:,:2],0), np.std(X1[:,:2], 0)) +X[:, 2] = X1[:,-1] # control inputs are not normalised +Y = np.divide(Y1 , np.std(X1[:,:2], 0)) + +state_dim = Y.shape[1] +control_dim = X.shape[1] - state_dim +m_init = np.transpose(X[0,:-1,None]) +S_init = 0.5 * np.eye(state_dim) +controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=25) + +R = ExponentialReward(state_dim=state_dim, + t=np.divide([0.5,0.0] - env.m, env.std), + W=np.diag([0.5,0.1]) + ) +pilco = PILCO((X, Y), controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) + +best_r = 0 +all_Rs = np.zeros((X.shape[0], 1)) +for i in range(len(all_Rs)): + all_Rs[i,0] = R.compute_reward(X[i,None,:-1], 0.001 * np.eye(state_dim))[0] + +ep_rewards = np.zeros((len(X)//T,1)) + +for i in range(len(ep_rewards)): + ep_rewards[i] = sum(all_Rs[i * T: i*T + T]) + +for model in pilco.mgpr.models: + model.likelihood.variance.assign(0.05) + set_trainable(model.likelihood.variance, False) + +r_new = np.zeros((T, 1)) +for rollouts in 
range(5): + pilco.optimize_models() + pilco.optimize_policy(maxiter=100, restarts=3) + import pdb; pdb.set_trace() + X_new, Y_new,_,_ = rollout(env=env, pilco=pilco, timesteps=T, SUBS=SUBS, render=True) + + for i in range(len(X_new)): + r_new[:, 0] = R.compute_reward(X_new[i,None,:-1], 0.001 * np.eye(state_dim))[0] + total_r = sum(r_new) + _, _, r = pilco.predict(m_init, S_init, T) + + print("Total ", total_r, " Predicted: ", r) + X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new)); + all_Rs = np.vstack((all_Rs, r_new)); ep_rewards = np.vstack((ep_rewards, np.reshape(total_r,(1,1)))) + pilco.mgpr.set_data((X, Y)) diff --git a/examples/pendulum_swing_up.py b/examples/pendulum_swing_up.py index b6a884f..fd514cd 100644 --- a/examples/pendulum_swing_up.py +++ b/examples/pendulum_swing_up.py @@ -4,8 +4,8 @@ from pilco.controllers import RbfController, LinearController from pilco.rewards import ExponentialReward import tensorflow as tf -from tensorflow import logging from utils import rollout, policy +from gpflow import set_trainable np.random.seed(0) # NEEDS a different initialisation than the one in gym (change the reset() method), @@ -35,28 +35,27 @@ def reset(self): def render(self): self.env.render() +if __name__=='__main__': + SUBS=3 + bf = 30 + maxiter=50 + max_action=2.0 + target = np.array([1.0, 0.0, 0.0]) + weights = np.diag([2.0, 2.0, 0.3]) + m_init = np.reshape([-1.0, 0, 0.0], (1,3)) + S_init = np.diag([0.01, 0.05, 0.01]) + T = 40 + T_sim = T + J = 4 + N = 8 + restarts = 2 -SUBS=3 -bf = 30 -maxiter=50 -max_action=2.0 -target = np.array([1.0, 0.0, 0.0]) -weights = np.diag([2.0, 2.0, 0.3]) -m_init = np.reshape([-1.0, 0, 0.0], (1,3)) -S_init = np.diag([0.01, 0.05, 0.01]) -T = 40 -T_sim = T -J = 4 -N = 8 -restarts = 2 - -with tf.Session() as sess: env = myPendulum() # Initial random rollouts to generate a dataset - X,Y = rollout(env, None, timesteps=T, random=True, SUBS=SUBS) + X, Y, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, render=True) for i in range(1,J): - X_, Y_ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True) + X_, Y_, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True, render=True) X = np.vstack((X, X_)) Y = np.vstack((Y, Y_)) @@ -64,30 +63,31 @@ def render(self): control_dim = X.shape[1] - state_dim controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=bf, max_action=max_action) - R = ExponentialReward(state_dim=state_dim, t=target, W=weights) - pilco = PILCO(X, Y, controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) + pilco = PILCO((X, Y), controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) - # for numerical stability + # for numerical stability, we can set the likelihood variance parameters of the GP models for model in pilco.mgpr.models: - model.likelihood.variance = 0.001 - model.likelihood.variance.trainable = False + model.likelihood.variance.assign(0.001) + set_trainable(model.likelihood.variance, False) + r_new = np.zeros((T, 1)) for rollouts in range(N): print("**** ITERATION no", rollouts, " ****") pilco.optimize_models(maxiter=maxiter, restarts=2) pilco.optimize_policy(maxiter=maxiter, restarts=2) - X_new, Y_new = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS) + X_new, Y_new, _, _ = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS, render=True) # Since we had decide on the various parameters of the reward function # we might want to verify that it behaves as expected by inspection - # cur_rew 
= 0 - # for t in range(0,len(X_new)): - # cur_rew += reward_wrapper(R, X_new[t, 0:state_dim, None].transpose(), 0.0001 * np.eye(state_dim))[0] - # print('On this episode reward was ', cur_rew) + for i in range(len(X_new)): + r_new[:, 0] = R.compute_reward(X_new[i,None,:-1], 0.001 * np.eye(state_dim))[0] + total_r = sum(r_new) + _, _, r = pilco.predict(X_new[0,None,:-1], 0.001 * S_init, T) + print("Total ", total_r, " Predicted: ", r) # Update dataset X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new)) - pilco.mgpr.set_XY(X, Y) + pilco.mgpr.set_data((X, Y)) diff --git a/examples/safe_cars_run.py b/examples/safe_cars_run.py new file mode 100644 index 0000000..c6cff95 --- /dev/null +++ b/examples/safe_cars_run.py @@ -0,0 +1,142 @@ +import numpy as np +import tensorflow as tf +import gym +from pilco.models import PILCO +from pilco.controllers import RbfController, LinearController +from pilco.rewards import ExponentialReward, LinearReward + +from linear_cars_env import LinearCars + +from safe_pilco_extension.rewards_safe import RiskOfCollision, ObjectiveFunction +from safe_pilco_extension.safe_pilco import SafePILCO +from utils import rollout, policy + +from gpflow import config +from gpflow import set_trainable +int_type = config.default_int() +float_type = config.default_float() + + +class Normalised_Env(): + def __init__(self, m, std): + self.env = LinearCars() + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + self.m = m + self.std = std + + def state_trans(self, x): + return np.divide(x-self.m, self.std) + + def step(self, action): + ob, r, done, _ = self.env.step(action) + return self.state_trans(ob), r, done, {} + + def reset(self): + ob = self.env.reset() + return self.state_trans(ob) + + def render(self): + self.env.render() + + +def safe_cars(seed=0): + T = 25 + th = 0.10 + np.random.seed(seed) + J = 5 + N = 5 + eval_runs = 5 + env = LinearCars() + # Initial random rollouts to generate a dataset + X1, Y1, _, _ = rollout(env, pilco=None, timesteps=T, verbose=True, random=True, render=False) + for i in range(1,5): + X1_, Y1_, _, _ = rollout(env, pilco=None, timesteps=T, verbose=True, random=True, render=False) + X1 = np.vstack((X1, X1_)) + Y1 = np.vstack((Y1, Y1_)) + + env = Normalised_Env(np.mean(X1[:,:4],0), np.std(X1[:,:4], 0)) + X, Y, _, _ = rollout(env, pilco=None, timesteps=T, verbose=True, random=True, render=False) + for i in range(1,J): + X_, Y_, _, _ = rollout(env, pilco=None, timesteps=T, verbose=True, random=True, render=False) + X = np.vstack((X, X_)) + Y = np.vstack((Y, Y_)) + + state_dim = Y.shape[1] + control_dim = X.shape[1] - state_dim + + m_init = np.transpose(X[0,:-1,None]) + S_init = 0.1 * np.eye(state_dim) + + controller = RbfController(state_dim=state_dim, control_dim=control_dim, + num_basis_functions=40, max_action=0.2) + + #w1 = np.diag([1.5, 0.001, 0.001, 0.001]) + #t1 = np.divide(np.array([3.0, 1.0, 3.0, 1.0]) - env.m, env.std) + #R1 = ExponentialReward(state_dim=state_dim, t=t1, W=w1) + # R1 = LinearReward(state_dim=state_dim, W=np.array([0.1, 0.0, 0.0, 0.0])) + R1 = LinearReward(state_dim=state_dim, W=np.array([1.0 * env.std[0], 0., 0., 0,])) + + bound_x1 = 1 / env.std[0] + bound_x2 = 1 / env.std[2] + B = RiskOfCollision(2, [-bound_x1-env.m[0]/env.std[0], -bound_x2 - env.m[2]/env.std[2]], + [bound_x1 - env.m[0]/env.std[0], bound_x2 - env.m[2]/env.std[2]]) + + pilco = SafePILCO((X, Y), controller=controller, mu=-300.0, reward_add=R1, reward_mult=B, horizon=T, m_init=m_init, S_init=S_init) + + for model in 
pilco.mgpr.models: + model.likelihood.variance.assign(0.001) + set_trainable(model.likelihood.variance, False) + + # define tolerance + new_data = True + # init = tf.global_variables_initializer() + evaluation_returns_full = np.zeros((N, eval_runs)) + evaluation_returns_sampled = np.zeros((N, eval_runs)) + X_eval = [] + for rollouts in range(N): + print("***ITERATION**** ", rollouts) + if new_data: + pilco.optimize_models(maxiter=100) + new_data = False + pilco.optimize_policy(maxiter=20, restarts=2) + # check safety + m_p = np.zeros((T, state_dim)) + S_p = np.zeros((T, state_dim, state_dim)) + predicted_risks = np.zeros(T) + predicted_rewards = np.zeros(T) + + for h in range(T): + m_h, S_h, _ = pilco.predict(m_init, S_init, h) + m_p[h,:], S_p[h,:,:] = m_h[:], S_h[:,:] + predicted_risks[h], _ = B.compute_reward(m_h, S_h) + predicted_rewards[h], _ = R1.compute_reward(m_h, S_h) + overall_risk = 1 - np.prod(1.0-predicted_risks) + + print("Predicted episode's return: ", sum(predicted_rewards)) + print("Overall risk ", overall_risk) + print("Mu is ", pilco.mu.numpy()) + print("bound1 ", bound_x1, " bound1 ", bound_x2) + + if overall_risk < th: + X_new, Y_new, _, _ = rollout(env, pilco=pilco, timesteps=T, verbose=True, render=False) + new_data = True + X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new)) + pilco.mgpr.set_data((X, Y)) + if overall_risk < (th/4): + pilco.mu.assign(0.75 * pilco.mu.numpy()) + + else: + X_new, Y_new,_,_ = rollout(env, pilco=pilco, timesteps=T, verbose=True, render=False) + print(m_p[:,0] - X_new[:,0]) + print(m_p[:,2] - X_new[:,2]) + print("*********CHANGING***********") + _, _, r = pilco.predict(m_init, S_init, T) + print(r) + # to verify this actually changes, run the reward wrapper before and after on the same trajectory + pilco.mu.assign(1.5 * pilco.mu.numpy()) + _, _, r = pilco.predict(m_init, S_init, T) + print(r) + +if __name__=='__main__': + safe_cars() diff --git a/examples/safe_swimmer_run.py b/examples/safe_swimmer_run.py new file mode 100644 index 0000000..a52b3c8 --- /dev/null +++ b/examples/safe_swimmer_run.py @@ -0,0 +1,133 @@ +import numpy as np +import gym +from pilco.models import PILCO +from pilco.controllers import RbfController, LinearController +from pilco.rewards import ExponentialReward, LinearReward, CombinedRewards +from safe_pilco_extension.rewards_safe import SingleConstraint +import tensorflow as tf +from utils import rollout, policy +from gpflow import set_trainable + +# Uses a wrapper for the Swimmer +# First one to use a combined reward function, that includes penalties +# Uses the video capture function + +# np.random.seed(int(sys.argv[2])) +# name = "safe_swimmer_final" + sys.argv[2] + +class SwimmerWrapper(): + def __init__(self): + self.env = gym.make('Swimmer-v2').env + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + + def state_trans(self, s): + return np.hstack([[self.x],s]) + + def step(self, action): + ob, r, done, _ = self.env.step(action) + self.x += r / 10.0 + return self.state_trans(ob), r, done, {} + + def reset(self): + ob = self.env.reset() + self.x = 0.0 + return self.state_trans(ob) + + def render(self): + self.env.render() + +def safe_swimmer_run(seed=0, logging=False): + env = SwimmerWrapper() + state_dim = 9 + control_dim = 2 + SUBS = 2 + maxiter = 60 + max_action = 1.0 + m_init = np.reshape(np.zeros(state_dim), (1,state_dim)) # initial state mean + S_init = 0.05 * np.eye(state_dim) + J = 1 + N = 12 + T = 25 + bf = 30 + T_sim=100 + + # Reward function that 
dicourages the joints from hitting their max angles + weights_l = np.zeros(state_dim) + weights_l[0] = 0.5 + max_ang = (100 / 180 * np.pi) * 0.95 + R1 = LinearReward(state_dim, weights_l) + + C1 = SingleConstraint(1, low=-max_ang, high=max_ang, inside=False) + C2 = SingleConstraint(2, low=-max_ang, high=max_ang, inside=False) + C3 = SingleConstraint(3, low=-max_ang, high=max_ang, inside=False) + R = CombinedRewards(state_dim, [R1, C1, C2, C3], coefs=[1.0, -10.0, -10.0, -10.0]) + + th=0.2 + # Initial random rollouts to generate a dataset + X,Y, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True) + for i in range(1,J): + X_, Y_ , _, _= rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True) + X = np.vstack((X, X_)) + Y = np.vstack((Y, Y_)) + + state_dim = Y.shape[1] + control_dim = X.shape[1] - state_dim + controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=bf, max_action=max_action) + + pilco = PILCO((X, Y), controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) + for model in pilco.mgpr.models: + model.likelihood.variance.assign(0.001) + set_trainable(model.likelihood.variance, False) + + new_data = True + eval_runs = T_sim + evaluation_returns_full = np.zeros((N, eval_runs)) + evaluation_returns_sampled = np.zeros((N, eval_runs)) + X_eval = [] + for rollouts in range(N): + print("**** ITERATION no", rollouts, " ****") + if new_data: pilco.optimize_models(maxiter=100); new_data = False + pilco.optimize_policy(maxiter=1, restarts=2) + + m_p = np.zeros((T, state_dim)) + S_p = np.zeros((T, state_dim, state_dim)) + predicted_risk1 = np.zeros(T) + predicted_risk2 = np.zeros(T) + predicted_risk3 = np.zeros(T) + for h in range(T): + m_h, S_h, _ = pilco.predict(m_init, S_init, h) + m_p[h,:], S_p[h,:,:] = m_h[:], S_h[:,:] + predicted_risk1[h], _ = C1.compute_reward(m_h, S_h) + predicted_risk2[h], _ = C2.compute_reward(m_h, S_h) + predicted_risk3[h], _ = C3.compute_reward(m_h, S_h) + estimate_risk1 = 1 - np.prod(1.0-predicted_risk1) + estimate_risk2 = 1 - np.prod(1.0-predicted_risk2) + estimate_risk3 = 1 - np.prod(1.0-predicted_risk3) + overall_risk = 1 - (1 - estimate_risk1) * (1 - estimate_risk2) * (1 - estimate_risk3) + if overall_risk < th: + X_new, Y_new, _, _ = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS) + new_data = True + # Update dataset + X = np.vstack((X, X_new[:T,:])); Y = np.vstack((Y, Y_new[:T,:])) + pilco.mgpr.set_data((X, Y)) + if estimate_risk1 < th/10: + R.coefs.assign(R.coefs.value() * [1.0, 0.75, 1.0, 1.0]) + if estimate_risk2 < th/10: + R.coefs.assign(R.coefs.value() * [1.0, 1.0, 0.75, 1.0]) + if estimate_risk3 < th/10: + R.coefs.assign(R.coefs.value() * [1.0, 1.0, 1.0, 0.75]) + else: + print("*********CHANGING***********") + if estimate_risk1 > th/3: + R.coefs.assign(R.coefs.value() * [1.0, 1.5, 1.0, 1.0]) + if estimate_risk2 > th/3: + R.coefs.assign(R.coefs.value() * [1.0, 1.0, 1.5, 1.0]) + if estimate_risk3 > th/3: + R.coefs.assign(R.coefs.value() * [1.0, 1.0, 1.0, 1.5]) + _, _, r = pilco.predict(m_init, S_init, T) + + + +if __name__=='__main__': + safe_swimmer_run() diff --git a/examples/swimmer.py b/examples/swimmer.py index 80f2512..e79276b 100644 --- a/examples/swimmer.py +++ b/examples/swimmer.py @@ -1,104 +1,111 @@ +import os import numpy as np import gym from pilco.models import PILCO from pilco.controllers import RbfController, LinearController from pilco.rewards import ExponentialReward, LinearReward, CombinedRewards import tensorflow as tf -from 
tensorflow import logging from utils import rollout, policy -np.random.seed(0) -# Uses a wrapper for the Swimmer -# First one to use a combined reward function, that includes penalties -# Uses the video capture function - -class SwimmerWrapper(): - def __init__(self): - self.env = gym.make('Swimmer-v2').env - self.action_space = self.env.action_space - self.observation_space = self.env.observation_space - - def state_trans(self, s): - return np.hstack([[self.x],s]) - - def step(self, action): - ob, r, done, _ = self.env.step(action) - self.x += r / 10.0 - return self.state_trans(ob), r, done, {} - - def reset(self): - ob = self.env.reset() - self.x = 0.0 - return self.state_trans(ob) - - def render(self): - self.env.render() - - -with tf.Session() as sess: - env = SwimmerWrapper() - state_dim = 9 - control_dim = 2 - SUBS = 2 - maxiter = 40 - max_action = 1.0 - m_init = np.reshape(np.zeros(state_dim), (1,state_dim)) # initial state mean - S_init = 0.05 * np.eye(state_dim) - J = 15 - N = 15 - T = 20 - bf = 20 - T_sim=100 - - # Reward function that dicourages the joints from hitting their max angles - weights_l = np.zeros(state_dim) - weights_l[0] = 0.1 - max_ang = 100 / 180 * np.pi - t1 = np.zeros(state_dim) - t1[2] = max_ang - w1 = 1e-6 * np.eye(state_dim) - w1[2,2] = 10 - t2 = np.zeros(state_dim) - t2[3] = max_ang - w2 = 1e-6 * np.eye(state_dim) - w2[3,3] = 10 - t3 = np.zeros(state_dim); t3[2] = -max_ang - t4 = np.zeros(state_dim); t4[3] = -max_ang - R2 = LinearReward(state_dim, weights_l) - R3 = ExponentialReward(state_dim, W=w1, t=t1) - R4 = ExponentialReward(state_dim, W=w2, t=t2) - R5 = ExponentialReward(state_dim, W=w1, t=t3) - R6 = ExponentialReward(state_dim, W=w2, t=t4) - R = CombinedRewards(state_dim, [R2, R3, R4, R5, R6], coefs=[1.0, -1.0, -1.0, -1.0, -1.0]) - - # Initial random rollouts to generate a dataset - X,Y = rollout(env, None, timesteps=T, random=True, SUBS=SUBS) - for i in range(1,J): - X_, Y_ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True) - X = np.vstack((X, X_)) - Y = np.vstack((Y, Y_)) - - state_dim = Y.shape[1] - control_dim = X.shape[1] - state_dim - controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=bf, max_action=max_action) - - pilco = PILCO(X, Y, controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init) - for model in pilco.mgpr.models: - model.likelihood.variance = 0.001 - model.likelihood.variance.trainable = False - - for rollouts in range(N): - print("**** ITERATION no", rollouts, " ****") - pilco.optimize_models(maxiter=maxiter) - pilco.optimize_policy(maxiter=maxiter, restarts=2) - - X_new, Y_new = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS) - - # Update dataset - X = np.vstack((X, X_new[:T,:])); Y = np.vstack((Y, Y_new[:T,:])) - pilco.mgpr.set_XY(X, Y) - - # Saving a video of a run +seed = 0 +np.random.seed(seed) +name = "swimmer_new" + str(seed) +env = gym.make('Swimmer-v2').env +state_dim = 8 +control_dim = 2 +SUBS = 5 +maxiter = 80 +max_action = 1.0 +m_init = np.reshape(np.zeros(state_dim), (1,state_dim)) # initial state mean +S_init = 0.005 * np.eye(state_dim) +J = 10 +N = 15 +T = 15 +bf = 40 +T_sim = 50 + +# Reward function that dicourages the joints from hitting their max angles +max_ang = 95 / 180 * np.pi +rewards = [] +rewards.append(LinearReward(state_dim, [0, 0, 0, 1.0, 0, 0, 0, 0])) +rewards.append(ExponentialReward( + state_dim, + W=np.diag(np.array([0, 0, 10, 0, 0, 0, 0, 0]) + 1e-6), + t=[0, 0, max_ang, 0, 0, 0, 0, 0]) +) 
+rewards.append(ExponentialReward( + state_dim, + W=np.diag(np.array([0, 0, 10, 0, 0, 0, 0, 0]) + 1e-6), + t=[0, 0, -max_ang, 0, 0, 0, 0, 0]) +) +rewards.append(ExponentialReward( + state_dim, + W=np.diag(np.array([0, 10, 0, 0, 0, 0, 0, 0]) + 1e-6), + t=[0, max_ang, 0, 0, 0, 0, 0, 0]) +) +rewards.append(ExponentialReward( + state_dim, + W=np.diag(np.array([0, 10, 0, 0, 0, 0, 0, 0]) + 1e-6), + t=[0, -max_ang, 0, 0, 0, 0, 0, 0]) +) +combined_reward = CombinedRewards(state_dim, rewards, coefs=[1.0, -1.0, -1.0, -1.0, -1.0]) +# Initial random rollouts to generate a dataset +X,Y, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True, render=False) +for i in range(1,J): + X_, Y_, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True, render=False) + X = np.vstack((X, X_)) + Y = np.vstack((Y, Y_)) + +state_dim = Y.shape[1] +control_dim = X.shape[1] - state_dim +controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=bf, max_action=max_action) + +pilco = PILCO((X, Y), controller=controller, horizon=T, reward=combined_reward, m_init=m_init, S_init=S_init) + +logging = False # To save results in .csv files turn this flag to True +eval_runs = 10 +evaluation_returns_full = np.zeros((N, eval_runs)) +evaluation_returns_sampled = np.zeros((N, eval_runs)) +eval_max_timesteps = 1000//SUBS +X_eval=False +for rollouts in range(N): + print("**** ITERATION no", rollouts, " ****") + pilco.optimize_models(maxiter=maxiter, restarts=2) + pilco.optimize_policy(maxiter=maxiter, restarts=2) + + X_new, Y_new, _, _ = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS, render=True) + + cur_rew = 0 + for t in range(0,len(X_new)): + cur_rew += Rew.compute_reward(X_new[t, 0:state_dim, None].transpose(), 0.0001 * np.eye(state_dim))[0] + if t == T: print('On this episode, on the planning horizon, PILCO reward was: ', cur_rew) + print('On this episode PILCO reward was ', cur_rew) + + gym_steps = 1000 + T_eval = gym_steps // SUBS + # Update dataset + X = np.vstack((X, X_new[:T,:])); Y = np.vstack((Y, Y_new[:T,:])) + pilco.mgpr.set_data((X, Y)) + if logging: + if eval_max_timesteps is None: + eval_max_timesteps = sim_timesteps + for k in range(0, eval_runs): + [X_eval_, _, + evaluation_returns_sampled[rollouts, k], + evaluation_returns_full[rollouts, k]] = rollout(env, pilco, + timesteps=eval_max_timesteps, + verbose=False, SUBS=SUBS, + render=False) + if not X_eval: + X_eval = X_eval_.copy() + else: + X_eval = np.vstack((X_eval, X_eval_)) + np.savetxt("X_" + name + seed + ".csv", X, delimiter=',') + np.savetxt("X_eval_" + name + seed + ".csv", X_eval, delimiter=',') + np.savetxt("evaluation_returns_sampled_" + name + seed + ".csv", evaluation_returns_sampled, delimiter=',') + np.savetxt("evaluation_returns_full_" + name + seed+ ".csv", evaluation_returns_full, delimiter=',') + + # To save a video of a run # env2 = SwimmerWrapper(monitor=True) # rollout(env2, pilco, policy=policy, timesteps=T+50, verbose=True, SUBS=SUBS) # env2.env.close() diff --git a/examples/utils.py b/examples/utils.py index 84730d9..1d3390b 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,28 +1,32 @@ import numpy as np -from gpflow import autoflow -from gpflow import settings -import pilco -float_type = settings.dtypes.float_type - - -def rollout(env, pilco, timesteps, verbose=True, random=False, SUBS=1, render=True): - X = []; Y = [] - x = env.reset() - for timestep in range(timesteps): - if render: env.render() - u = policy(env, pilco, x, random) - for i 
in range(SUBS): - x_new, _, done, _ = env.step(u) - if done: break +from gpflow import config +from gym import make +float_type = config.default_float() + + +def rollout(env, pilco, timesteps, verbose=True, random=False, SUBS=1, render=False): + X = []; Y = []; + x = env.reset() + ep_return_full = 0 + ep_return_sampled = 0 + for timestep in range(timesteps): if render: env.render() - if verbose: - print("Action: ", u) - print("State : ", x_new) - X.append(np.hstack((x, u))) - Y.append(x_new - x) - x = x_new - if done: break - return np.stack(X), np.stack(Y) + u = policy(env, pilco, x, random) + for i in range(SUBS): + x_new, r, done, _ = env.step(u) + ep_return_full += r + if done: break + if render: env.render() + if verbose: + print("Action: ", u) + print("State : ", x_new) + print("Return so far: ", ep_return_full) + X.append(np.hstack((x, u))) + Y.append(x_new - x) + ep_return_sampled += r + x = x_new + if done: break + return np.stack(X), np.stack(Y), ep_return_sampled, ep_return_full def policy(env, pilco, x, random): @@ -31,22 +35,24 @@ def policy(env, pilco, x, random): else: return pilco.compute_action(x[None, :])[0, :] +class Normalised_Env(): + def __init__(self, env_id, m, std): + self.env = make(env_id).env + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + self.m = m + self.std = std -@autoflow((float_type,[None, None]), (float_type,[None, None])) -def predict_one_step_wrapper(mgpr, m, s): - return mgpr.predict_on_noisy_inputs(m, s) - - -@autoflow((float_type,[None, None]), (float_type,[None, None]), (np.int32, [])) -def predict_trajectory_wrapper(pilco, m, s, horizon): - return pilco.predict(m, s, horizon) - + def state_trans(self, x): + return np.divide(x-self.m, self.std) -@autoflow((float_type,[None, None]), (float_type,[None, None])) -def compute_action_wrapper(pilco, m, s): - return pilco.controller.compute_action(m, s) + def step(self, action): + ob, r, done, _ = self.env.step(action) + return self.state_trans(ob), r, done, {} + def reset(self): + ob = self.env.reset() + return self.state_trans(ob) -@autoflow((float_type, [None, None]), (float_type, [None, None])) -def reward_wrapper(reward, m, s): - return reward.compute_reward(m, s) + def render(self): + self.env.render() diff --git a/pilco/__init__.py b/pilco/__init__.py index 2356919..83dd139 100644 --- a/pilco/__init__.py +++ b/pilco/__init__.py @@ -1,3 +1,3 @@ from . import models from . import controllers -from . import rewards \ No newline at end of file +from . 
import rewards diff --git a/pilco/controllers.py b/pilco/controllers.py index 0bc1173..0a5d0ab 100644 --- a/pilco/controllers.py +++ b/pilco/controllers.py @@ -1,10 +1,14 @@ import tensorflow as tf +from tensorflow_probability import distributions as tfd import numpy as np import gpflow +from gpflow import Parameter +from gpflow import set_trainable +from gpflow.utilities import positive +f64 = gpflow.utilities.to_default_float from .models import MGPR -from gpflow import settings -float_type = settings.dtypes.float_type +float_type = gpflow.config.default_float() def squash_sin(m, s, max_action=None): ''' @@ -20,26 +24,25 @@ def squash_sin(m, s, max_action=None): else: max_action = max_action * tf.ones((1,k), dtype=float_type) - M = max_action * tf.exp(-tf.diag_part(s) / 2) * tf.sin(m) + M = max_action * tf.exp(-tf.linalg.diag_part(s) / 2) * tf.sin(m) - lq = -( tf.diag_part(s)[:, None] + tf.diag_part(s)[None, :]) / 2 + lq = -( tf.linalg.diag_part(s)[:, None] + tf.linalg.diag_part(s)[None, :]) / 2 q = tf.exp(lq) S = (tf.exp(lq + s) - q) * tf.cos(tf.transpose(m) - m) \ - (tf.exp(lq - s) - q) * tf.cos(tf.transpose(m) + m) S = max_action * tf.transpose(max_action) * S / 2 - C = max_action * tf.diag( tf.exp(-tf.diag_part(s)/2) * tf.cos(m)) + C = max_action * tf.linalg.diag( tf.exp(-tf.linalg.diag_part(s)/2) * tf.cos(m)) return M, S, tf.reshape(C,shape=[k,k]) -class LinearController(gpflow.Parameterized): - def __init__(self, state_dim, control_dim, max_action=None): - gpflow.Parameterized.__init__(self) - self.W = gpflow.Param(np.random.rand(control_dim, state_dim)) - self.b = gpflow.Param(np.random.rand(1, control_dim)) +class LinearController(gpflow.Module): + def __init__(self, state_dim, control_dim, max_action=1.0): + gpflow.Module.__init__(self) + self.W = Parameter(np.random.rand(control_dim, state_dim)) + self.b = Parameter(np.random.rand(1, control_dim)) self.max_action = max_action - @gpflow.params_as_tensors def compute_action(self, m, s, squash=True): ''' Simple affine action: M <- W(m-t) - b @@ -60,13 +63,19 @@ def randomize(self): self.b.assign(mean + sigma*np.random.normal(size=self.b.shape)) -class FakeGPR(gpflow.Parameterized): - def __init__(self, X, Y, kernel): - gpflow.Parameterized.__init__(self) - self.X = gpflow.Param(X) - self.Y = gpflow.Param(Y) - self.kern = kernel +class FakeGPR(gpflow.Module): + def __init__(self, data, kernel, X=None, likelihood_variance=1e-4): + gpflow.Module.__init__(self) + if X is None: + self.X = Parameter(data[0], name="DataX", dtype=gpflow.default_float()) + else: + self.X = X + self.Y = Parameter(data[1], name="DataY", dtype=gpflow.default_float()) + self.data = [self.X, self.Y] + self.kernel = kernel self.likelihood = gpflow.likelihoods.Gaussian() + self.likelihood.variance.assign(likelihood_variance) + set_trainable(self.likelihood.variance, False) class RbfController(MGPR): ''' @@ -74,21 +83,27 @@ class RbfController(MGPR): See Deisenroth et al 2015: Gaussian Processes for Data-Efficient Learning in Robotics and Control Section 5.3.2. 
''' - def __init__(self, state_dim, control_dim, num_basis_functions, max_action=None): + def __init__(self, state_dim, control_dim, num_basis_functions, max_action=1.0): MGPR.__init__(self, - np.random.randn(num_basis_functions, state_dim), - 0.1*np.random.randn(num_basis_functions, control_dim) + [np.random.randn(num_basis_functions, state_dim), + 0.1*np.random.randn(num_basis_functions, control_dim)] ) for model in self.models: - model.kern.variance = 1.0 - model.kern.variance.trainable = False - self.max_action = max_action + model.kernel.variance.assign(1.0) + set_trainable(model.kernel.variance, False) + self.max_action = max_action - def create_models(self, X, Y): - self.models = gpflow.params.ParamList([]) + def create_models(self, data): + self.models = [] for i in range(self.num_outputs): - kern = gpflow.kernels.RBF(input_dim=X.shape[1], ARD=True) - self.models.append(FakeGPR(X, Y[:, i:i+1], kern)) + kernel = gpflow.kernels.SquaredExponential(lengthscales=tf.ones([data[0].shape[1],], dtype=float_type)) + transformed_lengthscales = Parameter(kernel.lengthscales, transform=positive(lower=1e-3)) + kernel.lengthscales = transformed_lengthscales + kernel.lengthscales.prior = tfd.Gamma(f64(1.1),f64(1/10.0)) + if i == 0: + self.models.append(FakeGPR((data[0], data[1][:,i:i+1]), kernel)) + else: + self.models.append(FakeGPR((data[0], data[1][:,i:i+1]), kernel, self.models[-1].X)) def compute_action(self, m, s, squash=True): ''' @@ -96,9 +111,10 @@ def compute_action(self, m, s, squash=True): IN: mean (m) and variance (s) of the state OUT: mean (M) and variance (S) of the action ''' - iK, beta = self.calculate_factorizations() - M, S, V = self.predict_given_factorizations(m, s, 0.0 * iK, beta) - S = S - tf.diag(self.variance - 1e-6) + with tf.name_scope("controller") as scope: + iK, beta = self.calculate_factorizations() + M, S, V = self.predict_given_factorizations(m, s, 0.0 * iK, beta) + S = S - tf.linalg.diag(self.variance - 1e-6) if squash: M, S, V2 = squash_sin(M, S, self.max_action) V = V @ V2 @@ -107,8 +123,7 @@ def compute_action(self, m, s, squash=True): def randomize(self): print("Randomising controller") for m in self.models: - mean = 0; sigma = 0.1 - m.X.assign(mean + sigma*np.random.normal(size=m.X.shape)) - m.Y.assign(mean + sigma*np.random.normal(size=m.Y.shape)) + m.X.assign(np.random.normal(size=m.data[0].shape)) + m.Y.assign(self.max_action / 10 * np.random.normal(size=m.data[1].shape)) mean = 1; sigma = 0.1 - m.kern.lengthscales.assign(mean + sigma*np.random.normal(size=m.kern.lengthscales.shape)) + m.kernel.lengthscales.assign(mean + sigma*np.random.normal(size=m.kernel.lengthscales.shape)) diff --git a/pilco/models/__init__.py b/pilco/models/__init__.py index 479004d..b684404 100644 --- a/pilco/models/__init__.py +++ b/pilco/models/__init__.py @@ -1,3 +1,3 @@ from .mgpr import MGPR from .smgpr import SMGPR -from .pilco import PILCO \ No newline at end of file +from .pilco import PILCO diff --git a/pilco/models/mgpr.py b/pilco/models/mgpr.py index b946698..34afe06 100644 --- a/pilco/models/mgpr.py +++ b/pilco/models/mgpr.py @@ -1,68 +1,78 @@ import tensorflow as tf +from tensorflow_probability import distributions as tfd import gpflow +from gpflow.utilities import to_default_float import numpy as np -float_type = gpflow.settings.dtypes.float_type +float_type = gpflow.config.default_float() -def randomize(model): - mean = 1; sigma = 0.01 - - model.kern.lengthscales.assign( - mean + sigma*np.random.normal(size=model.kern.lengthscales.shape)) - model.kern.variance.assign( 
- mean + sigma*np.random.normal(size=model.kern.variance.shape)) +def randomize(model, mean=1, sigma=0.01): + model.kernel.lengthscales.assign( + mean + sigma*np.random.normal(size=model.kernel.lengthscales.shape)) + model.kernel.variance.assign( + mean + sigma*np.random.normal(size=model.kernel.variance.shape)) if model.likelihood.variance.trainable: model.likelihood.variance.assign( mean + sigma*np.random.normal()) -class MGPR(gpflow.Parameterized): - def __init__(self, X, Y, name=None): +class MGPR(gpflow.Module): + def __init__(self, data, name=None): super(MGPR, self).__init__(name) - self.num_outputs = Y.shape[1] - self.num_dims = X.shape[1] - self.num_datapoints = X.shape[0] + self.num_outputs = data[1].shape[1] + self.num_dims = data[0].shape[1] + self.num_datapoints = data[0].shape[0] - self.create_models(X, Y) + self.create_models(data) self.optimizers = [] - def create_models(self, X, Y): + def create_models(self, data): self.models = [] for i in range(self.num_outputs): - kern = gpflow.kernels.RBF(input_dim=X.shape[1], ARD=True) + kern = gpflow.kernels.SquaredExponential(lengthscales=tf.ones([data[0].shape[1],], dtype=float_type)) #TODO: Maybe fix noise for better conditioning - kern.lengthscales.prior = gpflow.priors.Gamma(1,10) # priors have to be included before - kern.variance.prior = gpflow.priors.Gamma(1.5,2) # before the model gets compiled - self.models.append(gpflow.models.GPR(X, Y[:, i:i+1], kern)) - self.models[i].clear(); self.models[i].compile() + kern.lengthscales.prior = tfd.Gamma(to_default_float(1.1), to_default_float(1/10.0)) # priors have to be included before + kern.variance.prior = tfd.Gamma(to_default_float(1.5), to_default_float(1/2.0)) # before the model gets compiled + self.models.append(gpflow.models.GPR((data[0], data[1][:, i:i+1]), kernel=kern)) + self.models[-1].likelihood.prior = tfd.Gamma(to_default_float(1.2), to_default_float(1/0.05)) - def set_XY(self, X, Y): + def set_data(self, data): for i in range(len(self.models)): - self.models[i].X = X - self.models[i].Y = Y[:, i:i+1] + if isinstance(self.models[i].data[0], gpflow.Parameter): + self.models[i].X.assign(data[0]) + self.models[i].Y.assign(data[1][:, i:i+1]) + self.models[i].data = [self.models[i].X, self.models[i].Y] + else: + self.models[i].data = (data[0], data[1][:, i:i+1]) def optimize(self, restarts=1): if len(self.optimizers) == 0: # This is the first call to optimize(); for model in self.models: # Create an gpflow.train.ScipyOptimizer object for every model embedded in mgpr - optimizer = gpflow.train.ScipyOptimizer(method='L-BFGS-B') - optimizer.minimize(model) + optimizer = gpflow.optimizers.Scipy() + optimizer.minimize(model.training_loss, model.trainable_variables) self.optimizers.append(optimizer) - restarts -= 1 + else: + for model, optimizer in zip(self.models, self.optimizers): + optimizer.minimize(model.training_loss, model.trainable_variables) for model, optimizer in zip(self.models, self.optimizers): - session = optimizer._model.enquire_session(None) - best_parameters = model.read_values(session=session) - best_likelihood = model.compute_log_likelihood() + best_params = { + "lengthscales" : model.kernel.lengthscales.value(), + "k_variance" : model.kernel.variance.value(), + "l_variance" : model.likelihood.variance.value()} + best_loss = model.training_loss() for restart in range(restarts): randomize(model) - optimizer._optimizer.minimize(session=session, - feed_dict=optimizer._gen_feed_dict(optimizer._model, None), - step_callback=None) - likelihood = 
model.compute_log_likelihood() - if likelihood > best_likelihood: - best_parameters = model.read_values(session=session) - best_likelihood = likelihood - model.assign(best_parameters) + optimizer.minimize(model.training_loss, model.trainable_variables) + loss = model.training_loss() + if loss < best_loss: + best_params["k_lengthscales"] = model.kernel.lengthscales.value() + best_params["k_variance"] = model.kernel.variance.value() + best_params["l_variance"] = model.likelihood.variance.value() + best_loss = model.training_loss() + model.kernel.lengthscales.assign(best_params["lengthscales"]) + model.kernel.variance.assign(best_params["k_variance"]) + model.likelihood.variance.assign(best_params["l_variance"]) def predict_on_noisy_inputs(self, m, s): iK, beta = self.calculate_factorizations() @@ -71,11 +81,11 @@ def predict_on_noisy_inputs(self, m, s): def calculate_factorizations(self): K = self.K(self.X) batched_eye = tf.eye(tf.shape(self.X)[0], batch_shape=[self.num_outputs], dtype=float_type) - L = tf.cholesky(K + self.noise[:, None, None]*batched_eye) - iK = tf.cholesky_solve(L, batched_eye) + L = tf.linalg.cholesky(K + self.noise[:, None, None]*batched_eye) + iK = tf.linalg.cholesky_solve(L, batched_eye, name='chol1_calc_fact') Y_ = tf.transpose(self.Y)[:, :, None] - # Why do we transpose Y? Maybe we need to change the definition of self.Y() or beta? - beta = tf.cholesky_solve(L, Y_)[:, :, 0] + # TODO: Why do we transpose Y? Maybe we need to change the definition of self.Y() or beta? + beta = tf.linalg.cholesky_solve(L, Y_, name="chol2_calc_fact")[:, :, 0] return iK, beta def predict_given_factorizations(self, m, s, iK, beta): @@ -90,14 +100,14 @@ def predict_given_factorizations(self, m, s, iK, beta): inp = tf.tile(self.centralized_input(m)[None, :, :], [self.num_outputs, 1, 1]) # Calculate M and V: mean and inv(s) times input-output covariance - iL = tf.matrix_diag(1/self.lengthscales) + iL = tf.linalg.diag(1/self.lengthscales) iN = inp @ iL B = iL @ s[0, ...] @ iL + tf.eye(self.num_dims, dtype=float_type) # Redefine iN as in^T and t --> t^T # B is symmetric so its the same - t = tf.linalg.transpose( - tf.matrix_solve(B, tf.linalg.transpose(iN), adjoint=True), + t = tf.linalg.matrix_transpose( + tf.linalg.solve(B, tf.linalg.matrix_transpose(iN), adjoint=True, name='predict_gf_t_calc'), ) lb = tf.exp(-tf.reduce_sum(iN * t, -1)/2) * beta @@ -108,7 +118,7 @@ def predict_given_factorizations(self, m, s, iK, beta): V = tf.matmul(tiL, lb[:, :, None], adjoint_a=True)[..., 0] * c[:, None] # Calculate S: Predictive Covariance - R = s @ tf.matrix_diag( + R = s @ tf.linalg.diag( 1/tf.square(self.lengthscales[None, :, :]) + 1/tf.square(self.lengthscales[:, None, :]) ) + tf.eye(self.num_dims, dtype=float_type) @@ -116,13 +126,13 @@ def predict_given_factorizations(self, m, s, iK, beta): # TODO: change this block according to the PR of tensorflow. Maybe move it into a function? 
X = inp[None, :, :, :]/tf.square(self.lengthscales[:, None, None, :]) X2 = -inp[:, None, :, :]/tf.square(self.lengthscales[None, :, None, :]) - Q = tf.matrix_solve(R, s)/2 + Q = tf.linalg.solve(R, s, name='Q_solve')/2 Xs = tf.reduce_sum(X @ Q * X, -1) X2s = tf.reduce_sum(X2 @ Q * X2, -1) maha = -2 * tf.matmul(X @ Q, X2, adjoint_b=True) + \ Xs[:, :, :, None] + X2s[:, :, None, :] - # - k = tf.log(self.variance)[:, None] - \ + + k = tf.math.log(self.variance)[:, None] - \ tf.reduce_sum(tf.square(iN), -1)/2 L = tf.exp(k[:, None, :, None] + k[None, :, None, :] + maha) S = (tf.tile(beta[:, None, None, :], [1, self.num_outputs, 1, 1]) @@ -131,9 +141,9 @@ def predict_given_factorizations(self, m, s, iK, beta): )[:, :, 0, 0] diagL = tf.transpose(tf.linalg.diag_part(tf.transpose(L))) - S = S - tf.diag(tf.reduce_sum(tf.multiply(iK, diagL), [1, 2])) + S = S - tf.linalg.diag(tf.reduce_sum(tf.multiply(iK, diagL), [1, 2])) S = S / tf.sqrt(tf.linalg.det(R)) - S = S + tf.diag(self.variance) + S = S + tf.linalg.diag(self.variance) S = S - M @ tf.transpose(M) return tf.transpose(M), S, tf.transpose(V) @@ -143,34 +153,38 @@ def centralized_input(self, m): def K(self, X1, X2=None): return tf.stack( - [model.kern.K(X1, X2) for model in self.models] + [model.kernel.K(X1, X2) for model in self.models] ) @property def Y(self): return tf.concat( - [model.Y.parameter_tensor for model in self.models], + [model.data[1] for model in self.models], axis = 1 ) @property def X(self): - return self.models[0].X.parameter_tensor + return self.models[0].data[0] @property def lengthscales(self): return tf.stack( - [model.kern.lengthscales.constrained_tensor for model in self.models] + [model.kernel.lengthscales.value() for model in self.models] ) @property def variance(self): return tf.stack( - [model.kern.variance.constrained_tensor for model in self.models] + [model.kernel.variance.value() for model in self.models] ) @property def noise(self): return tf.stack( - [model.likelihood.variance.constrained_tensor for model in self.models] + [model.likelihood.variance.value() for model in self.models] ) + + @property + def data(self): + return (self.X, self.Y) diff --git a/pilco/models/pilco.py b/pilco/models/pilco.py index 5944201..b09e7cb 100644 --- a/pilco/models/pilco.py +++ b/pilco/models/pilco.py @@ -9,19 +9,19 @@ from .. import controllers from .. import rewards -float_type = gpflow.settings.dtypes.float_type +float_type = gpflow.config.default_float() +from gpflow import set_trainable - -class PILCO(gpflow.models.Model): - def __init__(self, X, Y, num_induced_points=None, horizon=30, controller=None, +class PILCO(gpflow.models.BayesianModel): + def __init__(self, data, num_induced_points=None, horizon=30, controller=None, reward=None, m_init=None, S_init=None, name=None): super(PILCO, self).__init__(name) - if not num_induced_points: - self.mgpr = MGPR(X, Y) + if num_induced_points is None: + self.mgpr = MGPR(data) else: - self.mgpr = SMGPR(X, Y, num_induced_points) - self.state_dim = Y.shape[1] - self.control_dim = X.shape[1] - Y.shape[1] + self.mgpr = SMGPR(data, num_induced_points) + self.state_dim = data[1].shape[1] + self.control_dim = data[0].shape[1] - data[1].shape[1] self.horizon = horizon if controller is None: @@ -37,18 +37,17 @@ def __init__(self, X, Y, num_induced_points=None, horizon=30, controller=None, if m_init is None or S_init is None: # If the user has not provided an initial state for the rollouts, # then define it as the first state in the dataset. 
- self.m_init = X[0:1, 0:self.state_dim] + self.m_init = data[0][0:1, 0:self.state_dim] self.S_init = np.diag(np.ones(self.state_dim) * 0.1) else: self.m_init = m_init self.S_init = S_init self.optimizer = None - @gpflow.name_scope('likelihood') - def _build_likelihood(self): + def training_loss(self): # This is for tuning controller's parameters reward = self.predict(self.m_init, self.S_init, self.horizon)[2] - return reward + return -reward def optimize_models(self, maxiter=200, restarts=1): ''' @@ -60,9 +59,9 @@ def optimize_models(self, maxiter=200, restarts=1): lengthscales = {}; variances = {}; noises = {}; i = 0 for model in self.mgpr.models: - lengthscales['GP' + str(i)] = model.kern.lengthscales.value - variances['GP' + str(i)] = np.array([model.kern.variance.value]) - noises['GP' + str(i)] = np.array([model.likelihood.variance.value]) + lengthscales['GP' + str(i)] = model.kernel.lengthscales.numpy() + variances['GP' + str(i)] = np.array([model.kernel.variance.numpy()]) + noises['GP' + str(i)] = np.array([model.likelihood.variance.numpy()]) i += 1 print('-----Learned models------') pd.set_option('precision', 3) @@ -78,42 +77,41 @@ def optimize_policy(self, maxiter=50, restarts=1): Optimize controller's parameter's ''' start = time.time() + mgpr_trainable_params = self.mgpr.trainable_parameters + for param in mgpr_trainable_params: + set_trainable(param, False) + if not self.optimizer: - self.optimizer = gpflow.train.ScipyOptimizer(method="L-BFGS-B") - start = time.time() - self.optimizer.minimize(self, maxiter=maxiter) - end = time.time() - print("Controller's optimization: done in %.1f seconds with reward=%.3f." % (end - start, self.compute_reward())) - restarts -= 1 + self.optimizer = gpflow.optimizers.Scipy() + # self.optimizer = tf.optimizers.Adam() + self.optimizer.minimize(self.training_loss, self.trainable_variables, options=dict(maxiter=maxiter)) + # self.optimizer.minimize(self.training_loss, self.trainable_variables) + else: + self.optimizer.minimize(self.training_loss, self.trainable_variables, options=dict(maxiter=maxiter)) + # self.optimizer.minimize(self.training_loss, self.trainable_variables) + end = time.time() + print("Controller's optimization: done in %.1f seconds with reward=%.3f." % (end - start, self.compute_reward())) + restarts -= 1 - session = self.optimizer._model.enquire_session(None) - if restarts > 0: - start = time.time() - self.optimizer._optimizer.minimize(session=session, - feed_dict=self.optimizer._gen_feed_dict(self.optimizer._model, None), - step_callback=None) - end = time.time() - restarts -= 1 - best_parameters = self.read_values(session=session) + best_parameter_values = [param.numpy() for param in self.trainable_parameters] best_reward = self.compute_reward() for restart in range(restarts): self.controller.randomize() start = time.time() - self.optimizer._optimizer.minimize(session=session, - feed_dict=self.optimizer._gen_feed_dict(self.optimizer._model, None), - step_callback=None) + self.optimizer.minimize(self.training_loss, self.trainable_variables, options=dict(maxiter=maxiter)) end = time.time() reward = self.compute_reward() print("Controller's optimization: done in %.1f seconds with reward=%.3f." 
% (end - start, self.compute_reward())) if reward > best_reward: - best_parameters = self.read_values(session=session) + best_parameter_values = [param.numpy() for param in self.trainable_parameters] best_reward = reward - self.assign(best_parameters) - + for i,param in enumerate(self.trainable_parameters): + param.assign(best_parameter_values[i]) end = time.time() + for param in mgpr_trainable_params: + set_trainable(param, True) - @gpflow.autoflow((float_type,[None, None])) def compute_action(self, x_m): return self.controller.compute_action(x_m, tf.zeros([self.state_dim, self.state_dim], float_type))[0] @@ -135,7 +133,6 @@ def predict(self, m_x, s_x, n): tf.add(reward, self.reward.compute_reward(m_x, s_x)[0]) ), loop_vars ) - return m_x, s_x, reward def propagate(self, m_x, s_x): @@ -155,6 +152,9 @@ def propagate(self, m_x, s_x): M_x.set_shape([1, self.state_dim]); S_x.set_shape([self.state_dim, self.state_dim]) return M_x, S_x - @gpflow.autoflow() def compute_reward(self): - return self._build_likelihood() + return -self.training_loss() + + @property + def maximum_log_likelihood_objective(self): + return -self.training_loss() diff --git a/pilco/models/smgpr.py b/pilco/models/smgpr.py index 28ab5af..5cc38d6 100644 --- a/pilco/models/smgpr.py +++ b/pilco/models/smgpr.py @@ -4,46 +4,44 @@ from .mgpr import MGPR -float_type = gpflow.settings.dtypes.float_type +from gpflow import config +float_type = config.default_float() class SMGPR(MGPR): - def __init__(self, X, Y, num_induced_points, name=None): - gpflow.Parameterized.__init__(self, name) + def __init__(self, data, num_induced_points, name=None): self.num_induced_points = num_induced_points - MGPR.__init__(self, X, Y, name) + MGPR.__init__(self, data, name) - def create_models(self, X, Y): + def create_models(self, data): self.models = [] for i in range(self.num_outputs): - kern = gpflow.kernels.RBF(input_dim=X.shape[1], ARD=True) + kern = gpflow.kernels.SquaredExponential(lengthscales=tf.ones([data[0].shape[1],], dtype=float_type)) Z = np.random.rand(self.num_induced_points, self.num_dims) #TODO: Maybe fix noise for better conditioning - self.models.append(gpflow.models.SGPR(X, Y[:, i:i+1], kern, Z=Z)) - self.models[i].clear(); self.models[i].compile() - + self.models.append(gpflow.models.GPRFITC((data[0], data[1][:, i:i+1]), kern, inducing_variable=Z)) + def calculate_factorizations(self): batched_eye = tf.eye(self.num_induced_points, batch_shape=[self.num_outputs], dtype=float_type) # TODO: Change 1e-6 to the respective constant of GPflow Kmm = self.K(self.Z) + 1e-6 * batched_eye Kmn = self.K(self.Z, self.X) - L = tf.cholesky(Kmm) - V = tf.matrix_triangular_solve(L, Kmn) + L = tf.linalg.cholesky(Kmm) + V = tf.linalg.triangular_solve(L, Kmn) G = self.variance[:, None] - tf.reduce_sum(tf.square(V), axis=[1]) G = tf.sqrt(1.0 + G/self.noise[:, None]) V = V/G[:, None] - Am = tf.cholesky(tf.matmul(V, V, transpose_b=True) + \ + Am = tf.linalg.cholesky(tf.matmul(V, V, transpose_b=True) + \ self.noise[:, None, None] * batched_eye) At = tf.matmul(L, Am) - iAt = tf.matrix_triangular_solve(At, batched_eye) + iAt = tf.linalg.triangular_solve(At, batched_eye) Y_ = tf.transpose(self.Y)[:, :, None] - beta = tf.matrix_triangular_solve(L, - tf.cholesky_solve(Am, (V/G[:, None]) @ Y_), + beta = tf.linalg.triangular_solve(L, + tf.linalg.cholesky_solve(Am, (V/G[:, None]) @ Y_), adjoint=True )[:, :, 0] iB = tf.matmul(iAt, iAt, transpose_a=True) * self.noise[:, None, None] - iK = tf.cholesky_solve(L, batched_eye) - iB - + iK = tf.linalg.cholesky_solve(L, 
batched_eye) - iB return iK, beta def centralized_input(self, m): @@ -51,4 +49,4 @@ def centralized_input(self, m): @property def Z(self): - return self.models[0].feature.Z.parameter_tensor \ No newline at end of file + return self.models[0].inducing_variable.Z diff --git a/pilco/rewards.py b/pilco/rewards.py index b2e48e4..aeb1085 100644 --- a/pilco/rewards.py +++ b/pilco/rewards.py @@ -1,34 +1,21 @@ -import abc import tensorflow as tf -from gpflow import Parameterized, Param, params_as_tensors, settings import numpy as np +from gpflow import Parameter, Module, config +float_type = config.default_float() -float_type = settings.dtypes.float_type - -class Reward(Parameterized): - def __init__(self): - Parameterized.__init__(self) - - @abc.abstractmethod - def compute_reward(self, m, s): - raise NotImplementedError - - -class ExponentialReward(Reward): +class ExponentialReward(Module): def __init__(self, state_dim, W=None, t=None): - Reward.__init__(self) self.state_dim = state_dim if W is not None: - self.W = Param(np.reshape(W, (state_dim, state_dim)), trainable=False) + self.W = Parameter(np.reshape(W, (state_dim, state_dim)), trainable=False) else: - self.W = Param(np.eye(state_dim), trainable=False) + self.W = Parameter(np.eye(state_dim), trainable=False) if t is not None: - self.t = Param(np.reshape(t, (1, state_dim)), trainable=False) + self.t = Parameter(np.reshape(t, (1, state_dim)), trainable=False) else: - self.t = Param(np.zeros((1, state_dim)), trainable=False) + self.t = Parameter(np.zeros((1, state_dim)), trainable=False) - @params_as_tensors def compute_reward(self, m, s): ''' Reward function, calculating mean and variance of rewards, given @@ -45,14 +32,14 @@ def compute_reward(self, m, s): SW = s @ self.W iSpW = tf.transpose( - tf.matrix_solve( (tf.eye(self.state_dim, dtype=float_type) + SW), + tf.linalg.solve( (tf.eye(self.state_dim, dtype=float_type) + SW), tf.transpose(self.W), adjoint=True)) muR = tf.exp(-(m-self.t) @ iSpW @ tf.transpose(m-self.t)/2) / \ tf.sqrt( tf.linalg.det(tf.eye(self.state_dim, dtype=float_type) + SW) ) i2SpW = tf.transpose( - tf.matrix_solve( (tf.eye(self.state_dim, dtype=float_type) + 2*SW), + tf.linalg.solve( (tf.eye(self.state_dim, dtype=float_type) + 2*SW), tf.transpose(self.W), adjoint=True)) r2 = tf.exp(-(m-self.t) @ i2SpW @ tf.transpose(m-self.t)) / \ @@ -63,35 +50,32 @@ def compute_reward(self, m, s): sR.set_shape([1, 1]) return muR, sR -class LinearReward(Reward): +class LinearReward(Module): def __init__(self, state_dim, W): - Reward.__init__(self) self.state_dim = state_dim - self.W = Param(np.reshape(W, (state_dim, 1)), trainable=False) + self.W = Parameter(np.reshape(W, (state_dim, 1)), trainable=False) - @params_as_tensors def compute_reward(self, m, s): - muR = m @ self.W + muR = tf.reshape(m, (1, self.state_dim)) @ self.W sR = tf.transpose(self.W) @ s @ self.W return muR, sR -class CombinedRewards(Reward): +class CombinedRewards(Module): def __init__(self, state_dim, rewards=[], coefs=None): - Reward.__init__(self) self.state_dim = state_dim self.base_rewards = rewards if coefs is not None: - self.coefs = coefs + self.coefs = Parameter(coefs, trainable=False) else: - self.coefs = np.ones(len(list)) + self.coefs = Parameter(np.ones(len(rewards)), dtype=float_type, trainable=False) - @params_as_tensors def compute_reward(self, m, s): - muR = 0 - sR = 0 - for c,r in enumerate(self.base_rewards): - tmp1, tmp2 = r.compute_reward(m, s) - muR += self.coefs[c] * tmp1 - sR += self.coefs[c]**2 * tmp2 - return muR, sR + total_output_mean = 0 
+ total_output_covariance = 0 + for reward, coef in zip(self.base_rewards, self.coefs): + output_mean, output_covariance = reward.compute_reward(m, s) + total_output_mean += coef * output_mean + total_output_covariance += coef**2 * output_covariance + + return total_output_mean, total_output_covariance diff --git a/requirements.txt b/requirements.txt index 0b862eb..d4a2b37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,7 @@ -git+git://github.com/GPflow/GPflow.git@v1.5.0 +tensorflow==2.1.0 +gpflow==2.0.0 +matplotlib +pandas +gym +oct2py +pytest \ No newline at end of file diff --git a/safe_pilco_extension/README.md b/safe_pilco_extension/README.md new file mode 100644 index 0000000..5e60cf8 --- /dev/null +++ b/safe_pilco_extension/README.md @@ -0,0 +1,2 @@ +## Safe PILCO extension +In this folder, we include an extension of the standard PILCO algorithm that takes safety constraints (defined on the environment's state space) into account as in [https://arxiv.org/abs/1712.05556](https://arxiv.org/pdf/1712.05556.pdf). The `safe_swimmer_run.py` and `safe_cars_run.py` in the `../examples` folder demonstrate the use of this extension. \ No newline at end of file diff --git a/safe_pilco_extension/__init__.py b/safe_pilco_extension/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/safe_pilco_extension/rewards_safe.py b/safe_pilco_extension/rewards_safe.py new file mode 100644 index 0000000..fc6ecca --- /dev/null +++ b/safe_pilco_extension/rewards_safe.py @@ -0,0 +1,73 @@ +import pilco +import tensorflow as tf +import numpy as np +from scipy.stats import mvn +import tensorflow_probability as tfp +from tensorflow_probability import distributions as tfd + +import gpflow +from gpflow import Module +float_type = gpflow.config.default_float() +from gpflow import set_trainable + +class RiskOfCollision(Module): + def __init__(self, state_dim, low, high): + Module.__init__(self) + self.state_dim = state_dim + self.low = tf.cast(low, float_type) + self.high = tf.cast(high, float_type) + + def compute_reward(self, m, s): + infl_diag_S = 2*tf.linalg.diag_part(s) + dist1 = tfd.Normal(loc=m[0,0], scale=infl_diag_S[0]) + dist2 = tfd.Normal(loc=m[0,2], scale=infl_diag_S[2]) + risk = (dist1.cdf(self.high[0]) - dist1.cdf(self.low[0])) * (dist2.cdf(self.high[1]) - dist2.cdf(self.low[1])) + return risk, 0.0001 * tf.ones(1, dtype=float_type) + +class SingleConstraint(Module): + def __init__(self, dim, high=None, low=None, inside=True): + Module.__init__(self) + if high is None: + self.high = False + else: + self.high = high + if low is None: + self.low = False + else: + self.low = low + if high is None and low is None: + raise Exception("At least one of bounds (high,low) has to be defined") + self.dim = tf.constant(dim, tf.int32) + if inside: + self.inside = True + else: + self.inside = False + + + def compute_reward(self, m, s): + # Risk refers to the space between the low and high value -> 1 + # otherwise self.in = 0 + if not self.high: + dist = tfd.Normal(loc=m[0, self.dim], scale=s[self.dim, self.dim]) + risk = 1 - dist.cdf(self.low) + elif not self.low: + dist = tfd.Normal(loc=m[0, self.dim], scale=s[self.dim, self.dim]) + risk = dist.cdf(self.high) + else: + dist = tfd.Normal(loc=m[0, self.dim], scale=s[self.dim, self.dim]) + risk = dist.cdf(self.high) - dist.cdf(self.low) + if not self.inside: + risk = 1 - risk + return risk, 0.0001 * tf.ones(1, dtype=float_type) + +class ObjectiveFunction(Module): + def __init__(self, reward_f, risk_f, mu=1.0): + Module.__init__(self) + 
self.reward_f = reward_f
+        self.risk_f = risk_f
+        self.mu = gpflow.Parameter(mu, dtype=float_type, trainable=False)
+
+    def compute_reward(self, m, s):
+        reward, var = self.reward_f.compute_reward(m, s)
+        risk, _ = self.risk_f.compute_reward(m, s)
+        return reward - self.mu * risk, var
diff --git a/safe_pilco_extension/safe_pilco.py b/safe_pilco_extension/safe_pilco.py
new file mode 100644
index 0000000..c9cb45e
--- /dev/null
+++ b/safe_pilco_extension/safe_pilco.py
@@ -0,0 +1,50 @@
+from pilco.models import PILCO
+import tensorflow as tf
+import gpflow
+import time
+import numpy as np
+import pandas as pd
+
+from pilco.models.mgpr import MGPR
+from pilco.models.smgpr import SMGPR
+from pilco import controllers
+from pilco import rewards
+from gpflow import Module
+
+float_type = gpflow.config.default_float()
+from gpflow import set_trainable
+
+class SafePILCO(PILCO):
+    def __init__(self, data, num_induced_points=None, horizon=30, controller=None,
+            reward_add=None, reward_mult=None, m_init=None, S_init=None, name=None, mu=5.0):
+        super(SafePILCO, self).__init__(data, num_induced_points=num_induced_points, horizon=horizon, controller=controller,
+            reward=reward_add, m_init=m_init, S_init=S_init)
+        if reward_mult is None:
+            raise Exception("A multiplicative reward (reward_mult) has to be defined")
+
+        self.mu = gpflow.Parameter(mu, trainable=False)
+
+        self.reward_mult = reward_mult
+
+    def predict(self, m_x, s_x, n):
+        loop_vars = [
+            tf.constant(0, tf.int32),
+            m_x,
+            s_x,
+            tf.constant([[0]], float_type),
+            tf.constant([[1.0]], float_type)
+        ]
+
+        _, m_x, s_x, reward_add, reward_mult = tf.while_loop(
+            # Termination condition
+            lambda j, m_x, s_x, reward_add, reward_mult: j < n,
+            # Body function
+            lambda j, m_x, s_x, reward_add, reward_mult: (
+                j + 1,
+                *self.propagate(m_x, s_x),
+                tf.add(reward_add, self.reward.compute_reward(m_x, s_x)[0]),
+                tf.multiply(reward_mult, 1.0-self.reward_mult.compute_reward(m_x, s_x)[0])
+            ), loop_vars
+        )
+        reward_total = reward_add + self.mu * (1.0 - reward_mult)
+        return m_x, s_x, reward_total
diff --git a/setup.py b/setup.py
index 9762b3d..cd67394 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 
 # Dependencies of PILCO
 requirements = [
-    'gpflow>=1.5.0'
+    'gpflow==2.0.0'
 ]
 
 packages = find_packages('.')
diff --git a/tests/test_cascade.py b/tests/test_cascade.py
index 280630b..cf085a1 100644
--- a/tests/test_cascade.py
+++ b/tests/test_cascade.py
@@ -2,25 +2,17 @@
 from pilco.models.pilco import PILCO
 import numpy as np
 import os
-from gpflow import autoflow
-from gpflow import settings
 import oct2py
 import logging
+from gpflow import config
+
 
 octave = oct2py.Oct2Py(logger=oct2py.get_log())
 octave.logger = oct2py.get_log('new_log')
 octave.logger.setLevel(logging.INFO)
 dir_path = os.path.dirname(os.path.realpath("__file__")) + "/tests/Matlab Code"
 octave.addpath(dir_path)
-float_type = settings.dtypes.float_type
-
-@autoflow((float_type,[None, None]), (float_type,[None, None]), (np.int32, []))
-def predict_wrapper(pilco, m, s, horizon):
-    return pilco.predict(m, s, horizon)
-
-@autoflow((float_type,[None, None]), (float_type,[None, None]))
-def compute_action_wrapper(pilco, m, s):
-    return pilco.controller.compute_action(m, s)
+float_type = config.default_float()
 
 def test_cascade():
     np.random.seed(0)
@@ -33,8 +25,9 @@ def test_cascade():
     X0 = np.random.rand(100, d + k)
     A = np.random.rand(d + k, d)
     Y0 = np.sin(X0).dot(A) + 1e-3*(np.random.rand(100, d) - 0.5) # Just something smooth
-    pilco = PILCO(X0, Y0)
+    pilco = PILCO((X0, Y0))
     pilco.controller.max_action = e
+
pilco.optimize_models(restarts=5) pilco.optimize_policy(restarts=5) @@ -43,19 +36,19 @@ def test_cascade(): s = np.random.rand(d, d) s = s.dot(s.T) # Make s positive semidefinite - M, S, reward = predict_wrapper(pilco, m, s, horizon) + M, S, reward = pilco.predict(m, s, horizon) # convert data to the struct expected by the MATLAB implementation policy = oct2py.io.Struct() policy.p = oct2py.io.Struct() - policy.p.w = pilco.controller.W.value - policy.p.b = pilco.controller.b.value.T + policy.p.w = pilco.controller.W.numpy() + policy.p.b = pilco.controller.b.numpy().T policy.maxU = e # convert data to the struct expected by the MATLAB implementation - lengthscales = np.stack([model.kern.lengthscales.value for model in pilco.mgpr.models]) - variance = np.stack([model.kern.variance.value for model in pilco.mgpr.models]) - noise = np.stack([model.likelihood.variance.value for model in pilco.mgpr.models]) + lengthscales = np.stack([model.kernel.lengthscales.numpy() for model in pilco.mgpr.models]) + variance = np.stack([model.kernel.variance.numpy() for model in pilco.mgpr.models]) + noise = np.stack([model.likelihood.variance.numpy() for model in pilco.mgpr.models]) hyp = np.log(np.hstack( (lengthscales, diff --git a/tests/test_controllers.py b/tests/test_controllers.py index 82107f4..e78cd11 100644 --- a/tests/test_controllers.py +++ b/tests/test_controllers.py @@ -1,19 +1,14 @@ from pilco.controllers import RbfController, LinearController, squash_sin import numpy as np import os -from gpflow import autoflow -from gpflow import settings import tensorflow as tf import oct2py octave = oct2py.Oct2Py() dir_path = os.path.dirname(os.path.realpath("__file__")) + "/tests/Matlab Code" octave.addpath(dir_path) -float_type = settings.dtypes.float_type - -@autoflow((float_type,[None, None]), (float_type,[None, None])) -def compute_action_wrapper(controller, m, s, squash=False): - return controller.compute_action(m, s, squash) +from gpflow import config +float_type = config.default_float() def test_rbf(): np.random.seed(0) @@ -26,19 +21,19 @@ def test_rbf(): A = np.random.rand(d, k) Y0 = np.sin(X0).dot(A) + 1e-3*(np.random.rand(100, k) - 0.5) # Just something smooth rbf = RbfController(3, 2, b) - rbf.set_XY(X0, Y0) + rbf.set_data((X0, Y0)) # Generate input m = np.random.rand(1, d) # But MATLAB defines it as m' s = np.random.rand(d, d) s = s.dot(s.T) # Make s positive semidefinite - M, S, V = compute_action_wrapper(rbf, m, s) + M, S, V = rbf.compute_action(m, s, squash=False) # convert data to the struct expected by the MATLAB implementation - lengthscales = np.stack([model.kern.lengthscales.value for model in rbf.models]) - variance = np.stack([model.kern.variance.value for model in rbf.models]) - noise = np.stack([model.likelihood.variance.value for model in rbf.models]) + lengthscales = np.stack([model.kernel.lengthscales.numpy() for model in rbf.models]) + variance = np.stack([model.kernel.variance.numpy() for model in rbf.models]) + noise = np.stack([model.likelihood.variance.numpy() for model in rbf.models]) hyp = np.log(np.hstack( (lengthscales, @@ -74,10 +69,10 @@ def test_linear(): b = np.random.rand(1, k) linear = LinearController(d, k) - linear.W = W - linear.b = b + linear.W.assign(W) + linear.b.assign(b) - M, S, V = compute_action_wrapper(linear, m, s) + M, S, V = linear.compute_action(m, s, squash=False) # convert data to the struct expected by the MATLAB implementation policy = oct2py.io.Struct() @@ -105,8 +100,6 @@ def test_squash(): e = 7.0 M, S, V = squash_sin(m, s, e) - sess = 
tf.Session() - M, S, V = sess.run([M, S, V]) M_mat, S_mat, V_mat = octave.gSin(m.T, s, e, nout=3) M_mat = np.asarray(M_mat) diff --git a/tests/test_predictions.py b/tests/test_predictions.py index 70c66b8..9ec0341 100644 --- a/tests/test_predictions.py +++ b/tests/test_predictions.py @@ -1,18 +1,14 @@ from pilco.models import MGPR import numpy as np import os -from gpflow import autoflow -from gpflow import settings import oct2py octave = oct2py.Oct2Py() dir_path = os.path.dirname(os.path.realpath("__file__")) + "/tests/Matlab Code" octave.addpath(dir_path) -float_type = settings.dtypes.float_type +from gpflow import config +float_type = config.default_float() -@autoflow((float_type,[None, None]), (float_type,[None, None])) -def predict_wrapper(mgpr, m, s): - return mgpr.predict_on_noisy_inputs(m, s) def test_predictions(): np.random.seed(0) @@ -23,7 +19,7 @@ def test_predictions(): X0 = np.random.rand(100, d) A = np.random.rand(d, k) Y0 = np.sin(X0).dot(A) + 1e-3*(np.random.rand(100, k) - 0.5) # Just something smooth - mgpr = MGPR(X0, Y0) + mgpr = MGPR((X0, Y0)) mgpr.optimize() @@ -32,18 +28,18 @@ def test_predictions(): s = np.random.rand(d, d) s = s.dot(s.T) # Make s positive semidefinite - M, S, V = predict_wrapper(mgpr, m, s) + M, S, V = mgpr.predict_on_noisy_inputs(m, s) # Change the dataset and predict again. Just to make sure that we don't cache something we shouldn't. X0 = 5*np.random.rand(100, d) - mgpr.set_XY(X0, Y0) + mgpr.set_data((X0, Y0)) - M, S, V = predict_wrapper(mgpr, m, s) + M, S, V = mgpr.predict_on_noisy_inputs(m, s) # convert data to the struct expected by the MATLAB implementation - lengthscales = np.stack([model.kern.lengthscales.value for model in mgpr.models]) - variance = np.stack([model.kern.variance.value for model in mgpr.models]) - noise = np.stack([model.likelihood.variance.value for model in mgpr.models]) + lengthscales = np.stack([model.kernel.lengthscales.value() for model in mgpr.models]) + variance = np.stack([model.kernel.variance.value() for model in mgpr.models]) + noise = np.stack([model.likelihood.variance.value() for model in mgpr.models]) hyp = np.log(np.hstack( (lengthscales, diff --git a/tests/test_rewards.py b/tests/test_rewards.py index f90f299..5fb31ec 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -1,18 +1,14 @@ from pilco.rewards import ExponentialReward import numpy as np import os -from gpflow import autoflow -from gpflow import settings import oct2py octave = oct2py.Oct2Py() dir_path = os.path.dirname(os.path.realpath("__file__")) + "/tests/Matlab Code" octave.addpath(dir_path) -float_type = settings.dtypes.float_type +from gpflow import config +float_type = config.default_float() -@autoflow((float_type, [None, None]), (float_type, [None, None])) -def reward_wrapper(reward, m, s): - return reward.compute_reward(m, s) def test_reward(): ''' @@ -24,10 +20,10 @@ def test_reward(): s = s.dot(s.T) reward = ExponentialReward(k) - W = reward.W.value - t = reward.t.value + W = reward.W.numpy() + t = reward.t.numpy() - M, S = reward_wrapper(reward, m, s) + M, S = reward.compute_reward(m, s) M_mat, _, _, S_mat = octave.reward(m.T, s, t.T, W, nout=4) diff --git a/tests/test_sparse_predictions.py b/tests/test_sparse_predictions.py index dea7ca1..019e256 100644 --- a/tests/test_sparse_predictions.py +++ b/tests/test_sparse_predictions.py @@ -1,22 +1,13 @@ from pilco.models import SMGPR import numpy as np import os -from gpflow import autoflow -from gpflow import settings import oct2py octave = oct2py.Oct2Py() dir_path = 
os.path.dirname(os.path.realpath("__file__")) + "/tests/Matlab Code" octave.addpath(dir_path) -float_type = settings.dtypes.float_type - -@autoflow((float_type,[None, None]), (float_type,[None, None])) -def predict_wrapper(smgpr, m, s): - return smgpr.predict_on_noisy_inputs(m, s) - -@autoflow() -def get_induced_points(smgpr): - return smgpr.Z +from gpflow import config +float_type = config.default_float() def test_sparse_predictions(): np.random.seed(0) @@ -27,7 +18,7 @@ def test_sparse_predictions(): X0 = np.random.rand(100, d) A = np.random.rand(d, k) Y0 = np.sin(X0).dot(A) + 1e-3*(np.random.rand(100, k) - 0.5) # Just something smooth - smgpr = SMGPR(X0, Y0, num_induced_points=30) + smgpr = SMGPR((X0, Y0), num_induced_points=30) smgpr.optimize() @@ -36,12 +27,12 @@ def test_sparse_predictions(): s = np.random.rand(d, d) s = s.dot(s.T) # Make s positive semidefinite - M, S, V = predict_wrapper(smgpr, m, s) + M, S, V = smgpr.predict_on_noisy_inputs(m, s) # convert data to the struct expected by the MATLAB implementation - lengthscales = np.stack([model.kern.lengthscales.value for model in smgpr.models]) - variance = np.stack([model.kern.variance.value for model in smgpr.models]) - noise = np.stack([model.likelihood.variance.value for model in smgpr.models]) + lengthscales = np.stack([model.kernel.lengthscales.value() for model in smgpr.models]) + variance = np.stack([model.kernel.variance.value() for model in smgpr.models]) + noise = np.stack([model.likelihood.variance.value() for model in smgpr.models]) hyp = np.log(np.hstack( (lengthscales, @@ -53,7 +44,7 @@ def test_sparse_predictions(): gpmodel.hyp = hyp gpmodel.inputs = X0 gpmodel.targets = Y0 - gpmodel.induce = get_induced_points(smgpr) + gpmodel.induce = smgpr.Z.numpy() # Call function in octave M_mat, S_mat, V_mat = octave.gp1(gpmodel, m.T, s, nout=3)
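
To see how the migrated pieces above fit together, the sketch below builds a PILCO model end to end with the new `(X, Y)` data-tuple interface and eager execution. It is an illustration only: the synthetic dataset, dimensions, horizon, and iteration counts are placeholder values in the spirit of `tests/test_cascade.py`, not settings taken from the examples.

```python
# Illustrative only: synthetic data in the style of tests/test_cascade.py.
import numpy as np
from pilco.models import PILCO
from pilco.rewards import ExponentialReward

state_dim, control_dim = 3, 1
X = np.random.rand(100, state_dim + control_dim)   # rows are [state, control]
Y = np.sin(X[:, :state_dim]) + 1e-3 * (np.random.rand(100, state_dim) - 0.5)

# GPflow 2 style: the dataset is passed as a single (X, Y) tuple.
reward = ExponentialReward(state_dim=state_dim, t=np.zeros(state_dim))
pilco = PILCO((X, Y), horizon=20, reward=reward)
pilco.controller.max_action = 1.0

pilco.optimize_models(maxiter=100, restarts=1)   # GP hyperparameters
pilco.optimize_policy(maxiter=20, restarts=1)    # controller, via gpflow.optimizers.Scipy

# No autoflow wrappers or sessions any more: calls run eagerly.
u = pilco.compute_action(X[0:1, :state_dim])
```

The safe extension plugs into the same workflow: `SafePILCO` additionally takes a multiplicative risk term (`reward_mult`, e.g. `RiskOfCollision`) alongside the additive reward and trades the two off through its `mu` parameter.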