Skip to content

Commit

Permalink
Change training to use tuples (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier authored Mar 1, 2024
1 parent f79b352 commit 8b8fc06
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Install dependencies
run: |
pip install tox
- name: Run format, sort, lints and types
- name: Run format, lints and types
run: |
tox -e format,lints,types
Expand Down
11 changes: 0 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,6 @@ repos:
args: ["--ignore-missing-imports"]
files: "(ramsey|examples)"

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
hooks:
- id: nbqa-black
- id: nbqa-pyupgrade
args: [--py39-plus]
- id: nbqa-isort
args: ['--profile=black']
- id: nbqa-flake8
args: ['--ignore=E501,E203,E302,E402,E731,W503']

- repo: https://github.com/jorisroovers/gitlint
rev: v0.19.1
hooks:
Expand Down
12 changes: 5 additions & 7 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
extensions = [
"nbsphinx",
"sphinx.ext.autodoc",
'sphinx_autodoc_typehints',
"sphinx_autodoc_typehints",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
Expand All @@ -18,13 +18,13 @@
"sphinx_copybutton",
"sphinx_math_dollar",
"IPython.sphinxext.ipython_console_highlighting",
'sphinx_design'
"sphinx_design",
]


templates_path = ["_templates"]
html_static_path = ["_static"]
html_css_files = ['theme.css']
html_css_files = ["theme.css"]

autodoc_default_options = {
"member-order": "bysource",
Expand All @@ -39,7 +39,7 @@
".DS_Store",
"notebooks/.ipynb_checkpoints",
"examples/*ipynb",
"examples/*py"
"examples/*py",
]

html_theme = "sphinx_book_theme"
Expand All @@ -50,9 +50,7 @@
"use_download_button": False,
"use_fullscreen_button": False,
"extra_navbar": "",
"launch_buttons": {
"colab_url": "https://colab.research.google.com"
},
"launch_buttons": {"colab_url": "https://colab.research.google.com"},
}

html_title = "Ramsey 🚀"
45 changes: 11 additions & 34 deletions docs/notebooks/inference_with_flax_and_numpyro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,8 @@
"def model(y=None):\n",
" loc = numpyro.sample(\"loc\", dist.Normal(0.0, 1.0))\n",
" scale = numpyro.sample(\"scale\", dist.HalfNormal(0.05))\n",
" ar_coefficients = numpyro.sample(\n",
" \"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0)\n",
" )\n",
" numpyro.sample(\n",
" \"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y\n",
" )"
" ar_coefficients = numpyro.sample(\"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0))\n",
" numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y)"
]
},
{
Expand Down Expand Up @@ -248,12 +244,8 @@
"def model(y=None):\n",
" loc = numpyro.sample(\"loc\", dist.Normal(0.0, 1.0))\n",
" scale = numpyro.sample(\"scale\", dist.HalfNormal(1.0))\n",
" ar_coefficients = numpyro.sample(\n",
" \"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0)\n",
" )\n",
" numpyro.sample(\n",
" \"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y\n",
" )"
" ar_coefficients = numpyro.sample(\"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0))\n",
" numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y)"
]
},
{
Expand Down Expand Up @@ -450,12 +442,8 @@
"def model(y=None):\n",
" loc = numpyro.sample(\"loc\", dist.Normal(0.0, 1.0))\n",
" scale = numpyro.sample(\"scale\", dist.HalfNormal(1.0))\n",
" ar_coefficients = numpyro.sample(\n",
" \"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0)\n",
" )\n",
" numpyro.sample(\n",
" \"y\", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y\n",
" )"
" ar_coefficients = numpyro.sample(\"ar_coefficients\", dist.Normal(jnp.zeros(3), 1.0))\n",
" numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=10), obs=y)"
]
},
{
Expand Down Expand Up @@ -659,12 +647,8 @@
"def model(y=None):\n",
" loc = numpyro.param(\"loc\", 0.0)\n",
" scale = numpyro.param(\"scale\", 1.0, constraints=constraints.positive)\n",
" ar_coefficients = numpyro.param(\n",
" \"ar_coefficients\", jnp.array([-1.0, 0.0, 1.0])\n",
" )\n",
" numpyro.sample(\n",
" \"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y\n",
" )"
" ar_coefficients = numpyro.param(\"ar_coefficients\", jnp.array([-1.0, 0.0, 1.0]))\n",
" numpyro.sample(\"y\", Autoregressive(loc, ar_coefficients, scale, length=N), obs=y)"
]
},
{
Expand Down Expand Up @@ -882,9 +866,7 @@
"\n",
" def setup(self):\n",
" self.loc = self.param(\"loc\", initializers.zeros, (1, 1), jnp.float32)\n",
" self.log_scale = self.param(\n",
" \"log_scale\", initializers.ones, (1, 1), jnp.float32\n",
" )\n",
" self.log_scale = self.param(\"log_scale\", initializers.ones, (1, 1), jnp.float32)\n",
" self.ar_coefficients = self.param(\n",
" \"ar_coefficients\", initializers.zeros, (self.order, 1), jnp.float32\n",
" )\n",
Expand Down Expand Up @@ -950,10 +932,7 @@
" @jax.jit\n",
" def step(rngs, state, **batch):\n",
" current_step = state.step\n",
" rngs = {\n",
" name: jr.fold_in(rng, current_step)\n",
" for name, rng in rngs.items()\n",
" }\n",
" rngs = {name: jr.fold_in(rng, current_step) for name, rng in rngs.items()}\n",
"\n",
" def obj_fn(params):\n",
" obj = state.apply_fn(variables=params, rngs=rngs, **batch)\n",
Expand All @@ -969,9 +948,7 @@
" return state.params, objectives\n",
"\n",
"\n",
"state = create_train_state(\n",
" jr.PRNGKey(123), ARModel(3), optax.adam(0.01), inputs=y\n",
")\n",
"state = create_train_state(jr.PRNGKey(123), ARModel(3), optax.adam(0.01), inputs=y)\n",
"params, objectives = train(jr.PRNGKey(1), y, state, 1000)"
]
},
Expand Down
8 changes: 2 additions & 6 deletions docs/notebooks/neural_processes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@
"rng_key = jr.PRNGKey(1)\n",
"sample_key, rng_key = jr.split(rng_key)\n",
"\n",
"data = sample_from_gaussian_process(\n",
" sample_key, batch_size=10, num_observations=200\n",
")\n",
"data = sample_from_gaussian_process(sample_key, batch_size=10, num_observations=200)\n",
"(x_target, y_target), f_target = (data.x, data.y), data.f"
]
},
Expand Down Expand Up @@ -344,9 +342,7 @@
" latent_encoder=(MLP([dim] * 3), MLP([dim, dim * 2])),\n",
" deterministic_encoder=(\n",
" MLP([dim] * 3),\n",
" MultiHeadAttention(\n",
" num_heads=4, head_size=16, embedding=MLP([dim] * 2)\n",
" ),\n",
" MultiHeadAttention(num_heads=4, head_size=16, embedding=MLP([dim] * 2)),\n",
" ),\n",
" )\n",
" return np\n",
Expand Down
13 changes: 6 additions & 7 deletions examples/attentive_neural_process.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""
Attentive neural process
========================
# ruff: noqa: D103,PLR0913
"""Attentive neural process example.
Here, we implement and train an attentive neural process
and visualize predictions thereof.
References
----------
[1] Kim, Hyunjik, et al. "Attentive Neural Processes."
International Conference on Learning Representations. 2019.
"""

import argparse

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -114,7 +113,7 @@ def plot(


def run(args):
n_context, n_target = 10, 20
n_context, n_target = (5, 10), (20, 30)
data_rng_key, train_rng_key, plot_rng_key = jr.split(jr.PRNGKey(0), 3)
(x_target, y_target), f_target = data(data_rng_key)

Expand All @@ -129,8 +128,8 @@ def run(args):
x_target,
y_target,
f_target,
n_context,
n_target,
n_context=10,
n_target=20,
)


Expand Down
1 change: 0 additions & 1 deletion examples/experimental/bayesian_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
References
----------
[1] Blundell C., Cornebise J., Kavukcuoglu K., Wierstra D.
"Weight Uncertainty in Neural Networks". ICML, 2015.
"""
Expand Down
10 changes: 4 additions & 6 deletions examples/experimental/gaussian_process.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
"""
Gaussian process regression
===========================
"""Gaussian process regression.
This example implements the training and prediction of a Gaussian process
regression model.
References
----------
[1] Rasmussen, Carl E and Williams, Chris KI.
"Gaussian Processes for Machine Learning". MIT press, 2006.
"""

import argparse
import jax

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from jax import numpy as jnp
from jax import random as jr
from jax.config import config

from ramsey.data import sample_from_gaussian_process
from ramsey.experimental import (
Expand All @@ -26,7 +24,7 @@
train_gaussian_process,
)

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


def data(key, rho, sigma, n=1000):
Expand Down
10 changes: 4 additions & 6 deletions examples/experimental/sparse_gaussian_process.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
"""
Sparse Gaussian process regression
==================================
"""Sparse Gaussian process regression example.
This example implements the training and prediction of a sparse Gaussian process
regression model.
References
----------
[1] Titsias, Michalis K.
"Variational Learning of Inducing Variables in Sparse Gaussian Processes".
AISTATS, 2009.
"""

import argparse
import jax

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from jax import numpy as jnp
from jax import random as jr
from jax.config import config

from ramsey.data import sample_from_gaussian_process
from ramsey.experimental import (
Expand All @@ -27,7 +25,7 @@
train_sparse_gaussian_process,
)

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


def data(key, rho, sigma, n=1000):
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ skips = ["B101", "B310"]

[tool.ruff]
line-length = 80
exclude = ["*_test.py", "setup.py"]
exclude = ["*_test.py", "setup.py", "docs/**", "examples/experimental/**"]

[tool.ruff.lint]
ignore= ["S101", "ANN1", "ANN2", "ANN0"]
select = ["ANN", "D", "E", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
"UP", "I", "PL", "S"
]

[tool.ruff.lint.pydocstyle]
Expand Down
2 changes: 1 addition & 1 deletion ramsey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ramsey._src.neural_process.neural_process import NP
from ramsey._src.neural_process.train_neural_process import train_neural_process

__version__ = "0.2.1"
__version__ = "0.2.2"

__all__ = [
"ANP",
Expand Down
Loading

0 comments on commit 8b8fc06

Please sign in to comment.