This repository has been archived by the owner on Sep 7, 2023. It is now read-only.
forked from HIPS/autograd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgaussian_process.py
110 lines (88 loc) · 4 KB
/
gaussian_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import absolute_import
from __future__ import print_function
import matplotlib.pyplot as plt
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.numpy.linalg import solve
import autograd.scipy.stats.multivariate_normal as mvn
from autograd import value_and_grad
from scipy.optimize import minimize
def make_gp_funs(cov_func, num_cov_params):
    """Functions that perform Gaussian process regression.

    The flat parameter vector used by the returned functions is laid out as
    ``[mean, log_noise, *cov_params]``.

    Args:
        cov_func: Covariance function with signature
            ``(cov_params, x, x')``.
        num_cov_params: Number of parameters ``cov_func`` expects. It is only
            used to compute the total parameter count that is returned.

    Returns:
        A tuple ``(num_params, predict, log_marginal_likelihood)`` where
        ``num_params == num_cov_params + 2``.
    """

    def unpack_kernel_params(params):
        """Split the flat vector into (mean, cov_params, noise_scale)."""
        mean = params[0]
        cov_params = params[2:]
        # exp() keeps the noise term positive; the added constant bounds it
        # away from zero.
        noise_scale = np.exp(params[1]) + 0.0001
        return mean, cov_params, noise_scale

    def predict(params, x, y, xstar):
        """Returns the predictive mean and covariance at locations xstar,
        of the latent function value f (without observation noise)."""
        mean, cov_params, noise_scale = unpack_kernel_params(params)
        cov_f_f = cov_func(cov_params, xstar, xstar)
        cov_y_f = cov_func(cov_params, x, xstar)
        cov_y_y = cov_func(cov_params, x, x) + noise_scale * np.eye(len(y))
        # Solve the linear system once and reuse it for both the mean and the
        # covariance instead of repeating the same solve.
        solved_t = solve(cov_y_y, cov_y_f).T
        pred_mean = mean + np.dot(solved_t, y - mean)
        pred_cov = cov_f_f - np.dot(solved_t, cov_y_f)
        return pred_mean, pred_cov

    def log_marginal_likelihood(params, x, y):
        """Log density of y under a Gaussian with constant mean and
        covariance ``cov_func(x, x) + noise_scale * I``."""
        mean, cov_params, noise_scale = unpack_kernel_params(params)
        cov_y_y = cov_func(cov_params, x, x) + noise_scale * np.eye(len(y))
        prior_mean = mean * np.ones(len(y))
        return mvn.logpdf(y, prior_mean, cov_y_y)

    return num_cov_params + 2, predict, log_marginal_likelihood
# Define an example covariance function.
def rbf_covariance(kernel_params, x, xp):
    """Squared-exponential (RBF) covariance between the rows of x and xp.

    Args:
        kernel_params: Vector whose first entry is the log of the output
            scale and whose remaining entries are the logs of the
            lengthscales that divide the inputs.
        x: Array of inputs, one point per row, e.g. shape (n, D).
        xp: Array of inputs, one point per row, e.g. shape (m, D).

    Returns:
        Covariance matrix; for inputs of shape (n, D) and (m, D) it has
        shape (n, m).
    """
    output_scale = np.exp(kernel_params[0])
    lengthscales = np.exp(kernel_params[1:])
    # Broadcast (n, 1, D) against (1, m, D) to get all pairwise scaled
    # differences; the last axis is summed out below.
    diffs = (np.expand_dims(x / lengthscales, 1)
             - np.expand_dims(xp / lengthscales, 0))
    return output_scale * np.exp(-0.5 * np.sum(diffs ** 2, axis=2))
def build_toy_dataset(D=1, n_data=20, noise_std=0.1):
    """Build a small noisy-cosine regression dataset with a gap in the inputs.

    Inputs are drawn from two evenly spaced clusters, [0, 3] and [6, 8], then
    shifted and scaled by ``(inputs - 4) / 2``. Targets are
    ``(cos(raw_inputs) + noise) / 2`` with noise from a fixed seed, so the
    dataset is identical on every call.

    Args:
        D: Number of input columns. The inputs are reshaped to
            ``(n_data, D)``, which only succeeds for ``D == 1``.
        n_data: Total number of points. For odd values the second cluster
            receives the extra point.
        noise_std: Standard deviation of the Gaussian noise added before the
            targets are halved.

    Returns:
        A tuple ``(inputs, targets)`` of shapes ``(n_data, D)`` and
        ``(n_data,)``.
    """
    rs = npr.RandomState(0)
    # np.linspace needs an integer count; true division would pass a float.
    n_first = n_data // 2
    n_second = n_data - n_first
    inputs = np.concatenate([np.linspace(0, 3, num=n_first),
                             np.linspace(6, 8, num=n_second)])
    targets = (np.cos(inputs) + rs.randn(n_data) * noise_std) / 2.0
    inputs = (inputs - 4.0) / 2.0
    inputs = inputs.reshape((len(inputs), D))
    return inputs, targets
if __name__ == '__main__':
    # Demo script: fit the GP hyperparameters on the toy dataset by
    # minimizing the negative log marginal likelihood, redrawing the
    # posterior after every optimizer iteration.

    D = 1

    # Build model and objective function.
    num_params, predict, log_marginal_likelihood = \
        make_gp_funs(rbf_covariance, num_cov_params=D + 1)

    X, y = build_toy_dataset(D=D)
    objective = lambda params: -log_marginal_likelihood(params, X, y)

    # Set up figure.
    fig = plt.figure(figsize=(12,8), facecolor='white')
    ax = fig.add_subplot(111, frameon=False)
    plt.show(block=False)

    def callback(params):
        """Print the current log likelihood and redraw the posterior plot."""
        print("Log likelihood {}".format(-objective(params)))
        plt.cla()

        # Show posterior marginals.
        plot_xs = np.reshape(np.linspace(-7, 7, 300), (300,1))
        pred_mean, pred_cov = predict(params, X, y, plot_xs)
        marg_std = np.sqrt(np.diag(pred_cov))
        ax.plot(plot_xs, pred_mean, 'b')
        # Shaded band of +/- 1.96 marginal standard deviations, drawn as a
        # closed polygon (forward along the lower edge, back along the upper).
        ax.fill(np.concatenate([plot_xs, plot_xs[::-1]]),
                np.concatenate([pred_mean - 1.96 * marg_std,
                                (pred_mean + 1.96 * marg_std)[::-1]]),
                alpha=.15, fc='Blue', ec='None')

        # Show samples from posterior.
        rs = npr.RandomState(0)
        sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10)
        ax.plot(plot_xs, sampled_funcs.T)

        ax.plot(X, y, 'kx')
        ax.set_ylim([-1.5, 1.5])
        ax.set_xticks([])
        ax.set_yticks([])
        plt.draw()
        plt.pause(1.0/60.0)

    # Initialize covariance parameters
    rs = npr.RandomState(0)
    init_params = 0.1 * rs.randn(num_params)

    print("Optimizing covariance parameters...")
    # value_and_grad makes the objective return (value, gradient), which is
    # the form minimize expects when jac=True.
    cov_params = minimize(value_and_grad(objective), init_params, jac=True,
                          method='CG', callback=callback)
    # Keep the final figure on screen for a while before the script exits.
    plt.pause(10.0)