Skip to content

Commit

Permalink
Only declare classes with fields as dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jun 15, 2024
1 parent c39f36b commit 079a686
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 17 deletions.
3 changes: 0 additions & 3 deletions jax_transforms/alr.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ALR:
def unconstrain(self, x):
return jnp.log(x[..., :-1]) - jnp.log(x[..., -1:])
Expand Down
3 changes: 0 additions & 3 deletions jax_transforms/expanded_softmax.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import distributions


@dataclass
class ExpandedSoftmax:
def unconstrain(self, r_x):
r, x = r_x
Expand Down
3 changes: 0 additions & 3 deletions jax_transforms/ilr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp

Expand All @@ -22,7 +20,6 @@ def _make_semiorthogonal_matrix(N: int):
return V


@dataclass
class ILR:
def unconstrain(self, x):
N = x.shape[-1]
Expand Down
1 change: 0 additions & 1 deletion jax_transforms/normalized_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def default_prior(self, x) -> distributions.Distribution:
return distributions.ExpGamma(concentration=alpha_sum, rate=1)


@dataclass
class NormalizedExponential(NormalizedGamma):
def __init__(self):
super().__init__(alpha=jnp.ones(()))
Expand Down
7 changes: 0 additions & 7 deletions jax_transforms/stickbreaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .util import vmap_over_leading_axes


@dataclass
class StickbreakingBase:
def _unconstrain_single(self, x):
def _running_remainder(remainder, xi):
Expand Down Expand Up @@ -39,7 +38,6 @@ def constrain_with_logdetjac(self, z):
return x, logJ


@dataclass
class StickbreakingCDF:
def get_distribution(self, N: int) -> distributions.Distribution:
raise NotImplementedError
Expand All @@ -65,13 +63,11 @@ def constrain_with_logdetjac(self, y):
return x, logJ


@dataclass
class StickbreakingLogistic(StickbreakingCDF):
def get_distribution(self, N: int) -> distributions.Logistic:
return distributions.Logistic(loc=jnp.log(jnp.arange(N - 1, 0, -1)), scale=1)


@dataclass
class StickbreakingNormal(StickbreakingCDF):
def get_distribution(self, N: int) -> distributions.Normal:
return distributions.Normal(loc=jnp.log(jnp.arange(N - 1, 0, -1)) / 2, scale=1)
Expand Down Expand Up @@ -104,19 +100,16 @@ def constrain_with_logdetjac(self, y):
return x, logJ


@dataclass
class StickbreakingPowerLogistic(StickbreakingPowerCDF):
def __init__(self):
super().__init__(distributions.Logistic(loc=0, scale=1))


@dataclass
class StickbreakingPowerNormal(StickbreakingPowerCDF):
def __init__(self):
super().__init__(distributions.Normal(loc=0, scale=1))


@dataclass
class StickbreakingAngular:
def unconstrain(self, x):
z = StickbreakingBase().unconstrain(x)
Expand Down

0 comments on commit 079a686

Please sign in to comment.