
Commit 3a282f5: Move to hatch
dirmeier committed Jul 20, 2024 (1 parent: 4c2fcc6)

Showing 15 changed files with 44 additions and 66 deletions.
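Most of the changes shown below modernize the type hints: the `typing` aliases `Optional`, `Tuple`, and `Union` are replaced by the built-in generics of PEP 585 (Python 3.9+) and the `X | Y` union syntax of PEP 604 (Python 3.10+), and the now-unneeded `typing` imports are dropped. A minimal sketch of the pattern (illustrative names, not code from this repository):

    # before: typing aliases
    from typing import Optional, Tuple, Union

    def load(path: str, sep: Optional[str] = None) -> Tuple[int, int]: ...
    def scale(x: Union[int, float]) -> float: ...

    # after: built-in generics (PEP 585) and union syntax (PEP 604)
    def load(path: str, sep: str | None = None) -> tuple[int, int]: ...
    def scale(x: int | float) -> float: ...

Because these annotations sit on Flax module attributes and are evaluated at class-definition time, the `|` syntax effectively assumes Python 3.10 or newer unless the modules also enable `from __future__ import annotations`.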
5 changes: 2 additions & 3 deletions ramsey/_src/data/dataset_m4.py
@@ -1,7 +1,6 @@
import os
from collections import namedtuple
from dataclasses import dataclass
from typing import Tuple
from urllib.parse import urlparse
from urllib.request import urlretrieve

@@ -114,7 +113,7 @@ class M4Dataset:
os.path.join(os.path.dirname(__file__), ".data")
)

def load(self, interval: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
def load(self, interval: str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Load a M4 data set.
Parameters
@@ -156,7 +155,7 @@ def _download(self, dataset):

def _load(
self, dataset, train_csv_path: str, test_csv_path: str
) -> Tuple[pd.DataFrame, pd.DataFrame]:
) -> tuple[pd.DataFrame, pd.DataFrame]:
self._download(dataset)
train_df = pd.read_csv(train_csv_path, sep=",", header=0, index_col=0)
test_df = pd.read_csv(test_csv_path, sep=",", header=0, index_col=0)
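The `load` method still returns a (train, test) pair of DataFrames; only the return annotation changes to the built-in `tuple`. A hedged usage sketch (the public import path and the interval name are assumptions, not taken from this diff):

    import pandas as pd
    from ramsey.data import M4Dataset  # assumed import path

    train_df, test_df = M4Dataset().load("hourly")  # assumed interval name
    assert isinstance(train_df, pd.DataFrame) and isinstance(test_df, pd.DataFrame)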
ramsey/_src/experimental/bayesian_neural_network/bayesian_linear.py
@@ -1,5 +1,3 @@
from typing import Optional, Tuple

from flax import linen as nn
from flax.linen import initializers
from jax import Array
@@ -49,9 +47,9 @@ class BayesianLinear(nn.Module):
output_size: int
use_bias: bool = True
mc_sample_size: int = 10
w_prior: Optional[dist.Distribution] = dist.Normal(loc=0.0, scale=1.0)
b_prior: Optional[dist.Distribution] = dist.Normal(loc=0.0, scale=1.0)
name: Optional[str] = None
w_prior: dist.Distribution | None = dist.Normal(loc=0.0, scale=1.0)
b_prior: dist.Distribution | None = dist.Normal(loc=0.0, scale=1.0)
name: str | None = None

def setup(self):
"""Construct a linear Bayesian layer."""
@@ -119,7 +117,7 @@ def _get_bias(self, layer_dim, dtype):
def _init_param(self, weight_name, param_name, constraint, shape, dtype):
init = initializers.xavier_normal()

shape = (shape,) if not isinstance(shape, Tuple) else shape
shape = (shape,) if not isinstance(shape, tuple) else shape

Codecov (codecov/patch) warning: added line #L120 in ramsey/_src/experimental/bayesian_neural_network/bayesian_linear.py was not covered by tests.
params = self.param(f"{weight_name}_{param_name}", init, shape, dtype)

params = jnp.where(
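The `isinstance` check in `_init_param` now tests against the built-in `tuple` instead of `typing.Tuple`, which is the idiomatic way to normalize a scalar dimension into a shape tuple. A minimal, self-contained sketch of that normalization (independent of the surrounding Flax code):

    def as_shape(shape: int | tuple[int, ...]) -> tuple[int, ...]:
        # leave tuples untouched, wrap scalars into a 1-tuple
        return (shape,) if not isinstance(shape, tuple) else shape

    assert as_shape(3) == (3,)
    assert as_shape((2, 4)) == (2, 4)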
@@ -1,4 +1,4 @@
from typing import Iterable
from collections.abc import Iterable

from flax import linen as nn
from jax import Array
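This file follows the same convention for abstract base classes: `Iterable` is now imported from `collections.abc`, which has provided the generic, subscriptable ABCs since Python 3.9, while the `typing` aliases are deprecated. A small illustrative sketch, not taken from the repository:

    from collections.abc import Callable, Iterable

    def apply_all(fns: Iterable[Callable[[float], float]], x: float) -> list[float]:
        # apply each function in fns to x
        return [fn(x) for fn in fns]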
9 changes: 4 additions & 5 deletions ramsey/_src/experimental/distributions/autoregressive.py
@@ -1,5 +1,4 @@
from functools import partial
from typing import Optional

import jax
import numpy as np
@@ -53,8 +52,8 @@ def __init__(self, loc, ar_coefficients, scale, length=None):
def sample(
self,
rng_key: jr.PRNGKey,
length: Optional[int] = None,
initial_state: Optional[float] = None,
length: int | None = None,
initial_state: float | None = None,
sample_shape=(),
):
"""Sample from the distribution.
@@ -125,8 +124,8 @@ def log_prob(self, value: Array):

def mean(
self,
length: Optional[int] = None,
initial_state: Optional[float] = None,
length: int | None = None,
initial_state: float | None = None,
):
"""Compute the mean of the autoregressive distribution.
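`sample` and `mean` keep their optional `length` and `initial_state` arguments; only the `Optional[...]` spelling changes. A hedged usage sketch, assuming the class is exported as `Autoregressive` and follows the `__init__` shown above:

    from jax import numpy as jnp, random as jr
    from ramsey.experimental import Autoregressive  # assumed import path and name

    ar = Autoregressive(loc=0.0, ar_coefficients=jnp.array([0.5, -0.1]), scale=1.0)
    draws = ar.sample(jr.PRNGKey(0), length=50)   # length supplied at sampling time
    mu = ar.mean(length=50, initial_state=0.0)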
@@ -1,5 +1,3 @@
from typing import Optional

from flax import linen as nn
from flax.linen import initializers
from jax import Array
@@ -25,7 +23,7 @@ class GP(nn.Module):
"""

kernel: Kernel
sigma_init: Optional[initializers.Initializer] = None
sigma_init: initializers.Initializer | None = None

@nn.compact
def __call__(self, x: Array, **kwargs):
@@ -1,5 +1,3 @@
from typing import Optional

from flax import linen as nn
from flax.linen import initializers
from jax import Array
@@ -23,10 +21,10 @@ class Linear(Kernel, nn.Module):
an initializer object from Flax or None
"""

active_dims: Optional[list] = None
sigma_b_init: Optional[initializers.Initializer] = initializers.uniform()
sigma_v_init: Optional[initializers.Initializer] = initializers.uniform()
offset_init: Optional[initializers.Initializer] = initializers.uniform()
active_dims: list | None = None
sigma_b_init: initializers.Initializer | None = initializers.uniform()
sigma_v_init: initializers.Initializer | None = initializers.uniform()
offset_init: initializers.Initializer | None = initializers.uniform()

def setup(self):
"""Construct parameters."""
16 changes: 7 additions & 9 deletions ramsey/_src/experimental/gaussian_process/kernel/stationary.py
@@ -1,5 +1,3 @@
from typing import Optional, Union

from flax import linen as nn
from flax.linen import initializers
from jax import Array
@@ -26,9 +24,9 @@ class Periodic(Kernel, nn.Module):
"""

period: float
active_dims: Optional[list] = None
rho_init: Optional[initializers.Initializer] = initializers.uniform()
sigma_init: Optional[initializers.Initializer] = initializers.uniform()
active_dims: list | None = None
rho_init: initializers.Initializer | None = initializers.uniform()
sigma_init: initializers.Initializer | None = initializers.uniform()

def setup(self):
"""Construct the covariance function."""
@@ -74,9 +72,9 @@ class ExponentiatedQuadratic(Kernel, nn.Module):
name of the layer
"""

active_dims: Optional[list] = None
rho_init: Optional[initializers.Initializer] = None
sigma_init: Optional[initializers.Initializer] = None
active_dims: list | None = None
rho_init: initializers.Initializer | None = None
sigma_init: initializers.Initializer | None = None

def setup(self):
"""Construct a stationary covariance."""
@@ -117,7 +115,7 @@ def exponentiated_quadratic(
x1: Array,
x2: Array,
sigma: float,
rho: Union[float, jnp.ndarray],
rho: float | jnp.ndarray,
):
"""Exponentiated-quadratic convariance function.
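For reference, the exponentiated-quadratic (RBF) covariance computed here has the standard form k(x1, x2) = sigma^2 * exp(-||x1 - x2||^2 / (2 * rho^2)). A self-contained JAX sketch of that formula (the repository's implementation may differ in batching and broadcasting):

    from jax import numpy as jnp

    def rbf(x1: jnp.ndarray, x2: jnp.ndarray, sigma: float, rho: float) -> jnp.ndarray:
        # x1: (n, d), x2: (m, d) -> covariance matrix of shape (n, m)
        sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
        return sigma**2 * jnp.exp(-sq_dists / (2.0 * rho**2))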
@@ -1,5 +1,3 @@
from typing import Optional

from flax import linen as nn
from flax.linen import initializers
from jax import Array
@@ -38,11 +36,11 @@ class SparseGP(nn.Module):

kernel: Kernel
n_inducing: int
jitter: Optional[float] = 10e-8
log_sigma_init: Optional[initializers.Initializer] = initializers.constant(
jitter: float | None = 10e-8
log_sigma_init: initializers.Initializer | None = initializers.constant(
jnp.log(1.0)
)
inducing_init: Optional[initializers.Initializer] = initializers.uniform(1)
inducing_init: initializers.Initializer | None = initializers.uniform(1)

@nn.compact
def __call__(self, x: Array, **kwargs):
@@ -1,5 +1,3 @@
from typing import Optional, Tuple

from chex import assert_axis_dimension, assert_rank
from flax import linen as nn
from jax import Array
@@ -39,8 +37,8 @@ class RANP(ANP):
"""

decoder: nn.Module
latent_encoder: Optional[Tuple[nn.Module, nn.Module]] = None
deterministic_encoder: Optional[Tuple[nn.Module, Attention]] = None
latent_encoder: tuple[nn.Module, nn.Module] | None = None
deterministic_encoder: tuple[nn.Module, Attention] | None = None
family: Family = Gaussian()

def setup(self):
6 changes: 2 additions & 4 deletions ramsey/_src/neural_process/attentive_neural_process.py
@@ -1,5 +1,3 @@
from typing import Optional

from chex import assert_axis_dimension
from flax import linen as nn
from jax import numpy as jnp
@@ -43,8 +41,8 @@ class ANP(NP):
"""

decoder: nn.Module
latent_encoder: Optional[nn.Module] = None
deterministic_encoder: Optional[nn.Module] = None
latent_encoder: nn.Module | None = None
deterministic_encoder: nn.Module | None = None
family: Family = Gaussian()

def setup(self):
6 changes: 2 additions & 4 deletions ramsey/_src/neural_process/neural_process.py
@@ -1,5 +1,3 @@
from typing import Optional, Tuple

import flax
import jax
import numpyro.distributions as dist
@@ -45,8 +43,8 @@ class NP(nn.Module):
"""

decoder: nn.Module
latent_encoder: Optional[Tuple[flax.linen.Module, flax.linen.Module]] = None
deterministic_encoder: Optional[flax.linen.Module] = None
latent_encoder: tuple[flax.linen.Module, flax.linen.Module] | None = None
deterministic_encoder: flax.linen.Module | None = None
family: Family = Gaussian()

def setup(self):
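The `NP` module keeps its attribute layout: a decoder, an optional pair of latent-encoder modules, an optional deterministic encoder, and an exponential family. A hedged construction sketch using Ramsey's `MLP` (import paths and layer sizes are placeholders, not taken from this diff):

    from ramsey import NP        # assumed public import path
    from ramsey.nn import MLP    # assumed public import path

    np_model = NP(
        decoder=MLP([64, 2]),                            # e.g. mean and scale of a Gaussian
        latent_encoder=(MLP([64, 64]), MLP([64, 32])),   # tuple[Module, Module] | None
        deterministic_encoder=MLP([64, 64]),             # Module | None
    )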
6 changes: 2 additions & 4 deletions ramsey/_src/neural_process/train_neural_process.py
@@ -1,5 +1,3 @@
from typing import Tuple, Union

import jax
import numpy as np
import optax
@@ -33,8 +31,8 @@ def train_neural_process(
neural_process: NP, # pylint: disable=invalid-name
x: Array, # pylint: disable=invalid-name
y: Array, # pylint: disable=invalid-name
n_context: Union[int, Tuple[int]],
n_target: Union[int, Tuple[int]],
n_context: int | tuple[int],
n_target: int | tuple[int],
batch_size: int,
optimizer=optax.adam(3e-4),
n_iter=20000,
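`n_context` and `n_target` are now annotated as `int | tuple[int]`, so the argument can be a single integer or a tuple (the diff does not show how the tuple form is interpreted). On Python 3.10+ the same union syntax also works at runtime; a small hedged check, independent of Ramsey:

    def check_size(n: int | tuple[int]) -> None:
        # PEP 604 unions are valid isinstance targets on Python 3.10+
        assert isinstance(n, int | tuple)

    check_size(10)
    check_size((10,))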
4 changes: 2 additions & 2 deletions ramsey/_src/nn/MLP.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, Optional
from collections.abc import Callable, Iterable

import jax
from flax import linen as nn
@@ -29,7 +29,7 @@ class MLP(nn.Module):
"""

output_sizes: Iterable[int]
dropout: Optional[float] = None
dropout: float | None = None
kernel_init: initializers.Initializer = default_kernel_init
bias_init: initializers.Initializer = initializers.zeros_init()
use_bias: bool = True
4 changes: 1 addition & 3 deletions ramsey/_src/nn/attention/attention.py
@@ -1,5 +1,3 @@
from typing import Optional

import chex
from flax import linen as nn
from jax import Array
@@ -17,7 +15,7 @@ class Attention(nn.Module):
an optional embedding network that embeds keys and queries
"""

embedding: Optional[nn.Module]
embedding: nn.Module | None

@nn.compact
def __call__(self, key: Array, value: Array, query: Array):
14 changes: 7 additions & 7 deletions ramsey/_src/nn/attention/multihead_attention.py
@@ -1,5 +1,5 @@
import functools
from typing import Callable, Optional
from collections.abc import Callable

from flax import linen as nn
from flax.linen import dot_product_attention, initializers
@@ -39,7 +39,7 @@ class MultiHeadAttention(Attention):

num_heads: int
head_size: int
embedding: Optional[nn.Module]
embedding: nn.Module | None

def setup(self) -> None:
"""Construct the networks."""
@@ -78,11 +78,11 @@ class _MultiHeadAttention(nn.Module):
num_heads: int
dtype = None
param_dtype = jnp.float32
qkv_features: Optional[int] = None
out_features: Optional[int] = None
qkv_features: int | None = None
out_features: int | None = None
broadcast_dropout: bool = True
dropout_rate: float = 0.0
deterministic: Optional[bool] = None
deterministic: bool | None = None
precision: PrecisionLike = None
kernel_init: Callable = default_kernel_init
bias_init: Callable = initializers.zeros_init()
@@ -98,8 +98,8 @@ def __call__(
query: Array,
key: Array,
value: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
mask: Array | None = None,
deterministic: bool | None = None,
) -> Array:
features = self.out_features or query.shape[-1]
qkv_features = self.qkv_features or query.shape[-1]
