Term.get_coefficients fails for jax implementation when built from source #75

Open
bmorris3 opened this issue Dec 2, 2022 · 3 comments

@bmorris3

bmorris3 commented Dec 2, 2022

Hi @dfm,

Today I built celerite2 from source, following the recommendations in the install docs. I'm trying to do something simple, like this:

from celerite2.jax import terms

sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()

but I get the following error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [1], line 4
      1 from celerite2.jax import terms
      3 sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
----> 4 sho.get_coefficients()

File ~/git/celerite2/python/celerite2/jax/terms.py:36, in Term.get_coefficients(self)
     35 def get_coefficients(self):
---> 36     raise NotImplementedError("subclasses must implement this method")

NotImplementedError: subclasses must implement this method

At first I thought this might be caused by the two separate SHOTerm implementations, one here:

def SHOTerm(*args, **kwargs):
    over = OverdampedSHOTerm(*args, **kwargs)
    under = UnderdampedSHOTerm(*args, **kwargs)
    if over.Q < 0.5:
        return over
    return under

and one here:

class SHOTerm(Term):

but commenting out the first one doesn't solve the problem.

Any ideas? Thanks!

@bmorris3
Author

bmorris3 commented Dec 2, 2022

Writing an issue is always the most clarifying exercise.

I was hitting this error because I was using TermConvolution, which uses this line:

class TermConvolution(Term):
    def __init__(self, term, delta):
        self.delta = np.float64(delta)
        try:
            self.coefficients = term.get_coefficients()
        except NotImplementedError:
            raise TypeError(
                "Term operations can only be performed on terms that provide "
                "coefficients"
            )

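but the JAX implementation of SHOTerm defines the overdamped and underdamped coefficients separately (via get_overdamped_coefficients and get_underdamped_coefficients) instead of implementing get_coefficients, so the base-class method raises. Should TermConvolution be modified accordingly?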
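The failure mode can be reproduced without celerite2 at all. Below is a minimal plain-Python sketch; `BaseTerm`, `SplitSHOTerm` and `WrapperTerm` are hypothetical stand-ins for `Term`, the JAX `SHOTerm` and `TermConvolution`, and the coefficient values are placeholders, not real SHO coefficients.

```python
class BaseTerm:
    # Hypothetical stand-in for Term: concrete terms are expected to override this.
    def get_coefficients(self):
        raise NotImplementedError("subclasses must implement this method")


class SplitSHOTerm(BaseTerm):
    # Hypothetical stand-in for the JAX SHOTerm: it only exposes the two halves
    # and never overrides get_coefficients.
    def get_overdamped_coefficients(self):
        return (1.0,), (2.0,)

    def get_underdamped_coefficients(self):
        return (3.0,), (4.0,), (5.0,), (6.0,)


class WrapperTerm(BaseTerm):
    # Hypothetical stand-in for TermConvolution's constructor logic: it calls
    # get_coefficients and converts NotImplementedError into TypeError.
    def __init__(self, term):
        try:
            self.coefficients = term.get_coefficients()
        except NotImplementedError:
            raise TypeError(
                "Term operations can only be performed on terms that provide "
                "coefficients"
            )


try:
    WrapperTerm(SplitSHOTerm())
except TypeError as err:
    # The inherited BaseTerm.get_coefficients raised, so the wrapper rejects the term.
    print("TypeError:", err)
```

Because `SplitSHOTerm` inherits `get_coefficients` unchanged, any wrapper that relies on it fails, even though both coefficient sets are available.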

@dfm
Member

dfm commented Dec 2, 2022

Great question! This issue was introduced by PR #68.

I should probably spend some time thinking about how to fix this properly, but one short-term option (probably at the cost of a minor performance hit) would be something like:

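import jax
import jax.numpy as jnp
from celerite2.jax import terms

def custom_sho_get_coeffs(self):
    # Compute both coefficient sets; only one of them is kept below
    ar, cr = self.get_overdamped_coefficients()
    ac, bc, cc, dc = self.get_underdamped_coefficients()
    # Q < 0.5 means the oscillator is overdamped
    cond = jnp.less(self.Q, 0.5)
    # Keep the overdamped (real) coefficients when cond is true, else zero them
    selectr = lambda x: jax.lax.cond(cond, lambda y: y, jnp.zeros_like, operand=x)
    # Keep the underdamped (complex) coefficients when cond is false, else zero them
    selectc = lambda x: jax.lax.cond(cond, jnp.zeros_like, lambda y: y, operand=x)
    return selectr(ar), selectr(cr), selectc(ac), selectc(bc), selectc(cc), selectc(dc)

# Patch the method onto the JAX SHOTerm class
terms.SHOTerm.get_coefficients = custom_sho_get_coeffs
sho = terms.SHOTerm(S0=1, w0=3000, Q=0.6)
sho.get_coefficients()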
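For comparison, the same selection can be written with `jnp.where` instead of `jax.lax.cond`. This is only a sketch and not part of celerite2: `select_sho_coefficients` is a hypothetical helper that takes already computed overdamped `(ar, cr)` and underdamped `(ac, bc, cc, dc)` tuples as arrays, and the numbers used here are placeholders rather than real SHO coefficients.

```python
import jax
import jax.numpy as jnp


def select_sho_coefficients(Q, over, under):
    """Hypothetical helper: keep either the overdamped or the underdamped set.

    `over` is (ar, cr) and `under` is (ac, bc, cc, dc). The inactive set is
    replaced by zeros of the same shape, so the output always has six arrays.
    """
    is_overdamped = jnp.less(Q, 0.5)
    # jnp.where works on traced values, so this is safe inside jax.jit
    keep_if_over = lambda x: jnp.where(is_overdamped, x, jnp.zeros_like(x))
    keep_if_under = lambda x: jnp.where(is_overdamped, jnp.zeros_like(x), x)
    return tuple(keep_if_over(x) for x in over) + tuple(
        keep_if_under(x) for x in under
    )


# Q is traced under jit, so the same compiled function handles both regimes
select_jit = jax.jit(select_sho_coefficients)

# Placeholder coefficient arrays, for illustration only
over = (jnp.array([1.0]), jnp.array([2.0]))
under = (jnp.array([3.0]), jnp.array([4.0]), jnp.array([5.0]), jnp.array([6.0]))
print(select_jit(0.3, over, under))
print(select_jit(0.6, over, under))
```

Like the `jax.lax.cond` version, this still computes both coefficient sets before zeroing one of them, so it presumably has a similar cost; `jnp.where` simply avoids defining separate branch functions.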

@bmorris3
Author

bmorris3 commented Dec 4, 2022

I can confirm the fix above works for me for now. Thanks Dan!
