From 7c7e5b1742a699a9b58812274a8e9d37a0ce130e Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Fri, 1 Mar 2024 19:20:42 +0100 Subject: [PATCH] update examples --- examples/experimental/gaussian_process.py | 8 +++----- examples/experimental/sparse_gaussian_process.py | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/experimental/gaussian_process.py b/examples/experimental/gaussian_process.py index cc78b77..9712712 100644 --- a/examples/experimental/gaussian_process.py +++ b/examples/experimental/gaussian_process.py @@ -1,6 +1,4 @@ -""" -Gaussian process regression -=========================== +"""Gaussian process regression. This example implements the training and prediction of a Gaussian process regression model. @@ -12,12 +10,12 @@ """ import argparse +import jax import matplotlib.patches as mpatches import matplotlib.pyplot as plt from jax import numpy as jnp from jax import random as jr -from jax.config import config from ramsey.data import sample_from_gaussian_process from ramsey.experimental import ( @@ -26,7 +24,7 @@ train_gaussian_process, ) -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) def data(key, rho, sigma, n=1000): diff --git a/examples/experimental/sparse_gaussian_process.py b/examples/experimental/sparse_gaussian_process.py index af7dc9a..094bf2f 100644 --- a/examples/experimental/sparse_gaussian_process.py +++ b/examples/experimental/sparse_gaussian_process.py @@ -1,6 +1,4 @@ -""" -Sparse Gaussian process regression -================================== +"""Sparse Gaussian process regression example. This example implements the training and prediction of a sparse Gaussian process regression model. @@ -13,12 +11,12 @@ """ import argparse +import jax import matplotlib.patches as mpatches import matplotlib.pyplot as plt from jax import numpy as jnp from jax import random as jr -from jax.config import config from ramsey.data import sample_from_gaussian_process from ramsey.experimental import ( @@ -27,7 +25,7 @@ train_sparse_gaussian_process, ) -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) def data(key, rho, sigma, n=1000):