diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index f54bcd3..c9dc0f9 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -31,7 +31,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install tox
-      - name: Run format, sort, lints and types
+      - name: Run format, lints and types
         run: |
           tox -e format,lints,types
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b9cfdf4..39a10b4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,17 +31,6 @@ repos:
     args: ["--ignore-missing-imports"]
     files: "(ramsey|examples)"
 
-- repo: https://github.com/nbQA-dev/nbQA
-  rev: 1.6.3
-  hooks:
-  - id: nbqa-black
-  - id: nbqa-pyupgrade
-    args: [--py39-plus]
-  - id: nbqa-isort
-    args: ['--profile=black']
-  - id: nbqa-flake8
-    args: ['--ignore=E501,E203,E302,E402,E731,W503']
-
 - repo: https://github.com/jorisroovers/gitlint
   rev: v0.19.1
   hooks:
diff --git a/docs/conf.py b/docs/conf.py
index cbb1e42..f64461b 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -7,7 +7,7 @@
 extensions = [
     "nbsphinx",
     "sphinx.ext.autodoc",
-    'sphinx_autodoc_typehints',
+    "sphinx_autodoc_typehints",
     "sphinx.ext.autosummary",
     "sphinx.ext.doctest",
     "sphinx.ext.intersphinx",
@@ -18,13 +18,13 @@
     "sphinx_copybutton",
     "sphinx_math_dollar",
     "IPython.sphinxext.ipython_console_highlighting",
-    'sphinx_design'
+    "sphinx_design",
 ]
 
 templates_path = ["_templates"]
 html_static_path = ["_static"]
 
-html_css_files = ['theme.css']
+html_css_files = ["theme.css"]
 
 autodoc_default_options = {
     "member-order": "bysource",
@@ -39,7 +39,7 @@
     ".DS_Store",
     "notebooks/.ipynb_checkpoints",
     "examples/*ipynb",
-    "examples/*py"
+    "examples/*py",
 ]
 
 html_theme = "sphinx_book_theme"
@@ -50,9 +50,7 @@
     "use_download_button": False,
     "use_fullscreen_button": False,
     "extra_navbar": "",
-    "launch_buttons": {
-        "colab_url": "https://colab.research.google.com"
-    },
+    "launch_buttons": {"colab_url": "https://colab.research.google.com"},
 }
 
 html_title = "Ramsey 🚀"
diff --git a/docs/notebooks/inference_with_flax_and_numpyro.ipynb b/docs/notebooks/inference_with_flax_and_numpyro.ipynb
index 6afc3f5..b04ff92 100644
--- a/docs/notebooks/inference_with_flax_and_numpyro.ipynb
+++ b/docs/notebooks/inference_with_flax_and_numpyro.ipynb
@@ -93,12 +93,8 @@
     "def model(y=None):\n",
     "    loc = numpyro.sample(\"loc\", dist.Normal(0.0, 1.0))\n",
     "    scale = numpyro.sample(\"scale\", dist.HalfNormal(0.05))\n",
-    "    ar_coefficients = numpyro.sample(\n",
-    "        \"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0)\n",
-    "    )\n",
-    "    numpyro.sample(\n",
-    "        \"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y\n",
-    "    )"
+    "    ar_coefficients = numpyro.sample(\"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0))\n",
+    "    numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y)"
    ]
   },
   {
@@ -248,12 +244,8 @@
     "def model(y=None):\n",
     "    loc = numpyro.sample(\"loc\", dist.Normal(0.0, 1.0))\n",
     "    scale = numpyro.sample(\"scale\", dist.HalfNormal(1.0))\n",
-    "    ar_coefficients = numpyro.sample(\n",
-    "        \"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0)\n",
-    "    )\n",
-    "    numpyro.sample(\n",
-    "        \"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y\n",
-    "    )"
+    "    ar_coefficients = numpyro.sample(\"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0))\n",
+    "    numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y)"
    ]
   },
   {
@@ -450,12 +442,8 @@
     "def model(y=None):\n",
     "    loc = numpyro.sample(\"loc\", dist.Normal(0.0, 1.0))\n",
     "    scale = numpyro.sample(\"scale\", dist.HalfNormal(1.0))\n",
-    "    ar_coefficients = numpyro.sample(\n",
-    "        \"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0)\n",
-    "    )\n",
-    "    numpyro.sample(\n",
-    "        \"y\", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y\n",
-    "    )"
+    "    ar_coefficients = numpyro.sample(\"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0))\n",
+    "    numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y)"
    ]
   },
   {
@@ -659,12 +647,8 @@
     "def model(y=None):\n",
     "    loc = numpyro.param(\"loc\", 0.0)\n",
     "    scale = numpyro.param(\"scale\", 1.0, constraints=constraints.positive)\n",
-    "    ar_coefficients = numpyro.param(\n",
-    "        \"ar_coefficients\", jnp.array([-1.0, 0.0, 1.0])\n",
-    "    )\n",
-    "    numpyro.sample(\n",
-    "        \"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y\n",
-    "    )"
+    "    ar_coefficients = numpyro.param(\"ar_coefficients\", jnp.array([-1.0, 0.0, 1.0]))\n",
+    "    numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y)"
    ]
   },
   {
@@ -882,9 +866,7 @@
     "\n",
     "    def setup(self):\n",
     "        self.loc = self.param(\"loc\", initializers.zeros, (1, 1), jnp.float32)\n",
-    "        self.log_scale = self.param(\n",
-    "            \"log_scale\", initializers.ones, (1, 1), jnp.float32\n",
-    "        )\n",
+    "        self.log_scale = self.param(\"log_scale\", initializers.ones, (1, 1), jnp.float32)\n",
     "        self.ar_coefficients = self.param(\n",
     "            \"ar_coefficients\", initializers.zeros, (self.order, 1), jnp.float32\n",
     "        )\n",
@@ -950,10 +932,7 @@
     "    @jax.jit\n",
     "    def step(rngs, state, **batch):\n",
     "        current_step = state.step\n",
-    "        rngs = {\n",
-    "            name: jr.fold_in(rng, current_step)\n",
-    "            for name, rng in rngs.items()\n",
-    "        }\n",
+    "        rngs = {name: jr.fold_in(rng, current_step) for name, rng in rngs.items()}\n",
     "\n",
     "        def obj_fn(params):\n",
     "            obj = state.apply_fn(variables=params, rngs=rngs, **batch)\n",
@@ -969,9 +948,7 @@
     "    return state.params, objectives\n",
     "\n",
     "\n",
-    "state = create_train_state(\n",
-    "    jr.PRNGKey(123), ARModel(3), optax.adam(0.01), inputs=y\n",
-    ")\n",
+    "state = create_train_state(jr.PRNGKey(123), ARModel(3), optax.adam(0.01), inputs=y)\n",
     "params, objectives = train(jr.PRNGKey(1), y, state, 1000)"
    ]
   },
diff --git a/docs/notebooks/neural_processes.ipynb b/docs/notebooks/neural_processes.ipynb
index 4666071..fb5e66d 100644
--- a/docs/notebooks/neural_processes.ipynb
+++ b/docs/notebooks/neural_processes.ipynb
@@ -68,9 +68,7 @@
     "rng_key = jr.PRNGKey(1)\n",
     "sample_key, rng_key = jr.split(rng_key)\n",
     "\n",
-    "data = sample_from_gaussian_process(\n",
-    "    sample_key, batch_size=10, num_observations=200\n",
-    ")\n",
+    "data = sample_from_gaussian_process(sample_key, batch_size=10, num_observations=200)\n",
     "(x_target, y_target), f_target = (data.x, data.y), data.f"
    ]
   },
@@ -344,9 +342,7 @@
     "        latent_encoder=(MLP([dim] * 3), MLP([dim, dim * 2])),\n",
     "        deterministic_encoder=(\n",
     "            MLP([dim] * 3),\n",
-    "            MultiHeadAttention(\n",
-    "                num_heads=4, head_size=16, embedding=MLP([dim] * 2)\n",
-    "            ),\n",
+    "            MultiHeadAttention(num_heads=4, head_size=16, embedding=MLP([dim] * 2)),\n",
     "        ),\n",
     "    )\n",
     "    return np\n",
diff --git a/examples/attentive_neural_process.py b/examples/attentive_neural_process.py
index 15c9994..b13b755 100644
--- a/examples/attentive_neural_process.py
+++ b/examples/attentive_neural_process.py
@@ -1,16 +1,15 @@
-"""
-Attentive neural process
-========================
+# ruff: noqa: D103,PLR0913
+"""Attentive neural process example.
 
 Here, we implement and train an attentive neural process
 and visualize predictions thereof.
 
 References
 ----------
-
 [1] Kim, Hyunjik, et al. "Attentive Neural Processes."
     International Conference on Learning Representations. 2019.
 """
+
 import argparse
 
 import matplotlib.pyplot as plt
@@ -114,7 +113,7 @@ def plot(
 
 
 def run(args):
-    n_context, n_target = 10, 20
+    n_context, n_target = (5, 10), (20, 30)
     data_rng_key, train_rng_key, plot_rng_key = jr.split(jr.PRNGKey(0), 3)
     (x_target, y_target), f_target = data(data_rng_key)
 
@@ -129,8 +128,8 @@
         x_target,
         y_target,
         f_target,
-        n_context,
-        n_target,
+        n_context=10,
+        n_target=20,
     )
 
 
diff --git a/examples/experimental/bayesian_neural_network.py b/examples/experimental/bayesian_neural_network.py
index 7675f36..83b8b8c 100644
--- a/examples/experimental/bayesian_neural_network.py
+++ b/examples/experimental/bayesian_neural_network.py
@@ -8,7 +8,6 @@
 
 References
 ----------
-
 [1] Blundell C., Cornebise J., Kavukcuoglu K., Wierstra D.
     "Weight Uncertainty in Neural Networks". ICML, 2015.
 """
diff --git a/examples/experimental/gaussian_process.py b/examples/experimental/gaussian_process.py
index dd27b39..9712712 100644
--- a/examples/experimental/gaussian_process.py
+++ b/examples/experimental/gaussian_process.py
@@ -1,23 +1,21 @@
-"""
-Gaussian process regression
-===========================
+"""Gaussian process regression.
 
 This example implements the training and prediction of
 a Gaussian process regression model.
 
 References
 ----------
-
 [1] Rasmussen, Carl E and Williams, Chris KI.
     "Gaussian Processes for Machine Learning". MIT press, 2006.
 """
+
 import argparse
 
+import jax
 import matplotlib.patches as mpatches
 import matplotlib.pyplot as plt
 from jax import numpy as jnp
 from jax import random as jr
-from jax.config import config
 
 from ramsey.data import sample_from_gaussian_process
 from ramsey.experimental import (
@@ -26,7 +24,7 @@
     train_gaussian_process,
 )
 
-config.update("jax_enable_x64", True)
+jax.config.update("jax_enable_x64", True)
 
 
 def data(key, rho, sigma, n=1000):
diff --git a/examples/experimental/sparse_gaussian_process.py b/examples/experimental/sparse_gaussian_process.py
index b6ca6f0..094bf2f 100644
--- a/examples/experimental/sparse_gaussian_process.py
+++ b/examples/experimental/sparse_gaussian_process.py
@@ -1,24 +1,22 @@
-"""
-Sparse Gaussian process regression
-==================================
+"""Sparse Gaussian process regression example.
 
 This example implements the training and prediction of
 a sparse Gaussian process regression model.
 
 References
 ----------
-
 [1] Titsias, Michalis K.
     "Variational Learning of Inducing Variables in Sparse Gaussian Processes".
     AISTATS, 2009.
 """
+
 import argparse
 
+import jax
 import matplotlib.patches as mpatches
 import matplotlib.pyplot as plt
 from jax import numpy as jnp
 from jax import random as jr
-from jax.config import config
 
 from ramsey.data import sample_from_gaussian_process
 from ramsey.experimental import (
@@ -27,7 +25,7 @@
     train_sparse_gaussian_process,
 )
 
-config.update("jax_enable_x64", True)
+jax.config.update("jax_enable_x64", True)
 
 
 def data(key, rho, sigma, n=1000):
diff --git a/pyproject.toml b/pyproject.toml
index 752d238..aa9846f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,13 +6,13 @@
 skips = ["B101", "B310"]
 
 [tool.ruff]
 line-length = 80
-exclude = ["*_test.py", "setup.py"]
+exclude = ["*_test.py", "setup.py", "docs/**", "examples/experimental/**"]
 
 [tool.ruff.lint]
 ignore= ["S101", "ANN1", "ANN2", "ANN0"]
 select = ["ANN", "D", "E", "F"]
 extend-select = [
-  "UP", "D", "I", "PL", "S"
+  "UP", "I", "PL", "S"
 ]
 
 [tool.ruff.lint.pydocstyle]
diff --git a/ramsey/__init__.py b/ramsey/__init__.py
index 7402114..e59baba 100644
--- a/ramsey/__init__.py
+++ b/ramsey/__init__.py
@@ -5,7 +5,7 @@
 from ramsey._src.neural_process.neural_process import NP
 from ramsey._src.neural_process.train_neural_process import train_neural_process
 
-__version__ = "0.2.1"
+__version__ = "0.2.2"
 
 __all__ = [
     "ANP",
diff --git a/ramsey/_src/neural_process/train_neural_process.py b/ramsey/_src/neural_process/train_neural_process.py
index fad4b1d..0f1cd2b 100644
--- a/ramsey/_src/neural_process/train_neural_process.py
+++ b/ramsey/_src/neural_process/train_neural_process.py
@@ -1,3 +1,5 @@
+from typing import Tuple, Union
+
 import jax
 import numpy as np
 import optax
@@ -31,8 +33,8 @@ def train_neural_process(
     neural_process: NP,  # pylint: disable=invalid-name
     x: Array,  # pylint: disable=invalid-name
     y: Array,  # pylint: disable=invalid-name
-    n_context: int,
-    n_target: int,
+    n_context: Union[int, Tuple[int, int]],
+    n_target: Union[int, Tuple[int, int]],
     batch_size: int,
     optimizer=optax.adam(3e-4),
     n_iter=20000,
@@ -60,10 +62,16 @@ def train_neural_process(
         :math:`b \times n \times q` where :math:`b` and :math:`n` are the
         same as for :math:`x` and :math:`q` is the number of outputs
-    n_context: int
-        number of context points
-    n_target: int
-        number of target points
+    n_context: Union[int, Tuple[int, int]]
+        number of context points. If a tuple is given, the number of
+        context points is sampled per iteration from that interval.
+    n_target: Union[int, Tuple[int, int]]
+        number of target points. If a tuple is given, the number of
+        target points is sampled per iteration from that interval.
+        The number of target points includes the number of context
+        points: if n_context=5 and n_target=10, the target set contains
+        5 more points than the context set but includes the context
+        points, too.
     batch_size: int
         number of elements that are samples for each gradient step, i.e.,
         number of elements in first axis of :math:`x` and :math:`y`
@@ -109,22 +117,37 @@
     return state.params, objectives
 
 
-# pylint: disable=too-many-locals
 def _split_data(
     rng_key: jr.PRNGKey,
-    x: Array,  # pylint: disable=invalid-name
-    y: Array,  # pylint: disable=invalid-name
+    x: Array,
+    y: Array,
     batch_size: int,
-    n_context: int,
-    n_target: int,
+    n_context: Union[int, Tuple[int, int]],
+    n_target: Union[int, Tuple[int, int]],
 ):
+    if isinstance(n_context, tuple):
+        cnt_key, rng_key = jr.split(rng_key)
+        n_context = jr.randint(
+            cnt_key, minval=n_context[0], maxval=n_context[1], shape=()
+        )
+    if isinstance(n_target, tuple):
+        trg_key, rng_key = jr.split(rng_key)
+        n_target = jr.randint(
+            trg_key, minval=n_target[0], maxval=n_target[1], shape=()
+        )
+
+    assert n_target > n_context
     batch_rng_key, idx_rng_key, rng_key = jr.split(rng_key, 3)
     ibatch = jr.choice(
         batch_rng_key, x.shape[0], shape=(batch_size,), replace=False
    )
    idxs = jr.choice(
-        idx_rng_key, x.shape[1], shape=(n_context + n_target,), replace=False
+        idx_rng_key,
+        x.shape[1],
+        shape=(n_target,),
+        replace=False,
     )
+    ibatch = np.asarray(ibatch, dtype=np.int32)
     x_context = x[ibatch][:, idxs[:n_context], :]
     y_context = y[ibatch][:, idxs[:n_context], :]
     x_target = x[ibatch][:, idxs, :]
diff --git a/ramsey/_src/neural_process/train_neural_process_test.py b/ramsey/_src/neural_process/train_neural_process_test.py
index 6e1cd50..ee72e52 100644
--- a/ramsey/_src/neural_process/train_neural_process_test.py
+++ b/ramsey/_src/neural_process/train_neural_process_test.py
@@ -12,7 +12,6 @@ def test_neural_process_training(module):
     data = sample_from_gaussian_process(key)
     x_target, y_target = data.x, data.y
 
-    print(module)
     key, train_key = jr.split(key)
     train_neural_process(
         train_key,
@@ -20,7 +19,7 @@
         x=x_target,
         y=y_target,
         n_context=10,
-        n_target=10,
+        n_target=20,
         n_iter=10,
         batch_size=2,
     )
diff --git a/tox.ini b/tox.ini
index f1140bf..7912c4d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,10 +12,10 @@ commands =
 
 [testenv:lints]
 skip_install = true
 commands_pre =
-    pip install ruff
+    pip install ruff bandit
     pip install -e .
 commands =
-    bandit ramsey
+    bandit -r ramsey -c pyproject.toml
     ruff check ramsey
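
For reference, the snippet below sketches how the tuple-valued n_context/n_target arguments introduced in this patch are meant to be called. It is a minimal sketch, not part of the patch itself: the NP constructor arguments and the ramsey.nn.MLP import path are assumed from the notebook code above, and all layer sizes and iteration counts are illustrative only.

    # Usage sketch for the tuple-valued n_context/n_target arguments added
    # above. NP construction and the ramsey.nn import path mirror the notebook
    # code in this patch but are assumptions; sizes are illustrative only.
    from jax import random as jr

    from ramsey import NP, train_neural_process
    from ramsey.data import sample_from_gaussian_process
    from ramsey.nn import MLP

    key = jr.PRNGKey(0)
    data_key, train_key = jr.split(key)
    data = sample_from_gaussian_process(
        data_key, batch_size=10, num_observations=200
    )

    neural_process = NP(
        decoder=MLP([32, 2]),  # final layer parameterizes mean and scale of y
        latent_encoder=(MLP([32, 32]), MLP([32, 64])),
    )

    params, objectives = train_neural_process(
        train_key,
        neural_process,
        x=data.x,
        y=data.y,
        # Context-set sizes are drawn from [5, 10) and target-set sizes from
        # [20, 30) anew at every gradient step (jr.randint excludes the upper
        # bound); the sampled target set always contains the context set.
        n_context=(5, 10),
        n_target=(20, 30),
        n_iter=100,
        batch_size=2,
    )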