diff --git a/pilco/models/pilco.py b/pilco/models/pilco.py
index b09e7cb..c9f66e9 100644
--- a/pilco/models/pilco.py
+++ b/pilco/models/pilco.py
@@ -115,6 +115,7 @@ def optimize_policy(self, maxiter=50, restarts=1):
     def compute_action(self, x_m):
         return self.controller.compute_action(x_m, tf.zeros([self.state_dim, self.state_dim], float_type))[0]
 
+    @tf.function
     def predict(self, m_x, s_x, n):
         loop_vars = [
             tf.constant(0, tf.int32),
diff --git a/pilco/rewards.py b/pilco/rewards.py
index aeb1085..c2f68e7 100644
--- a/pilco/rewards.py
+++ b/pilco/rewards.py
@@ -16,6 +16,7 @@ def __init__(self, state_dim, W=None, t=None):
         else:
             self.t = Parameter(np.zeros((1, state_dim)), trainable=False)
 
+    @tf.function
     def compute_reward(self, m, s):
         '''
         Reward function, calculating mean and variance of rewards, given