
Matérn-5/2 kernel #133

Open
Kamuish opened this issue Jul 30, 2024 · 4 comments

Comments

@Kamuish

Kamuish commented Jul 30, 2024

Hello,

Is there any easy way of implementing a Matérn-5/2 kernel with celerite2?
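For reference, the Matérn-5/2 covariance at lag tau = |t - t'|, with amplitude sigma and length scale ell, is k(tau) = sigma^2 (1 + sqrt(5) tau / ell + 5 tau^2 / (3 ell^2)) exp(-sqrt(5) tau / ell). Below is a minimal dense NumPy sketch of it. The function name `matern52`, its parameter names, and the 0.1 white-noise level are illustrative assumptions, not part of celerite2's API. Building and factorizing the full N x N matrix costs O(N^3), which is exactly the cost that celerite-style solvers are designed to avoid.

```python
import numpy as np


def matern52(t1, t2, sigma=1.0, ell=1.0):
    """Dense Matern-5/2 covariance matrix between two 1-D arrays of times.

    k(tau) = sigma^2 * (1 + r + r^2 / 3) * exp(-r), with r = sqrt(5) * tau / ell,
    which is the same as sigma^2 * (1 + sqrt(5) tau/ell + 5 tau^2/(3 ell^2)) * exp(-r).
    """
    tau = np.abs(np.asarray(t1)[:, None] - np.asarray(t2)[None, :])
    r = np.sqrt(5.0) * tau / ell
    return sigma**2 * (1.0 + r + r**2 / 3.0) * np.exp(-r)


t = np.linspace(0.0, 10.0, 50)
# Add an (assumed) white-noise term with standard deviation 0.1 on the diagonal.
K = matern52(t, t, sigma=0.5, ell=2.0) + 0.01 * np.eye(len(t))
# O(N^3) Cholesky factorization: the step a fast GP solver avoids.
L = np.linalg.cholesky(K)
```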

@dfm
Member

dfm commented Jul 30, 2024

It isn't straightforward because of technical reasons about how the core algorithms are implemented. Depending on your use case you might be able to use tinygp's quasiseparable solver which is a generalization of the method used in celerite2 (where that kernel is implemented!).
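One way to see why a quasiseparable-style solver can handle this kernel is that Matérn-5/2 is the covariance of a three-dimensional linear SDE. In the standard state-space form from the GP literature, k(tau) = H exp(F tau) P_inf H^T for tau >= 0, with lambda = sqrt(5) / ell. The NumPy sketch below checks that identity numerically. The matrices follow the textbook construction and are not copied from tinygp's source, and the helper names are hypothetical. F has the single eigenvalue -lambda with multiplicity three, so N = F + lambda I satisfies N^3 = 0. This means exp(F tau) = exp(-lambda tau) (I + N tau + N^2 tau^2 / 2) exactly, and no general matrix exponential is needed.

```python
import numpy as np


def matern52_closed_form(tau, sigma, ell):
    """Closed-form Matern-5/2 kernel as a function of the lag tau."""
    r = np.sqrt(5.0) * np.abs(tau) / ell
    return sigma**2 * (1.0 + r + r**2 / 3.0) * np.exp(-r)


def matern52_state_space(sigma, ell):
    """Feedback matrix F and stationary covariance P_inf of the Matern-5/2 SDE."""
    lam = np.sqrt(5.0) / ell
    F = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [-lam**3, -3.0 * lam**2, -3.0 * lam]])
    kappa = lam**2 * sigma**2 / 3.0
    P_inf = np.array([[sigma**2, 0.0, -kappa],
                      [0.0, kappa, 0.0],
                      [-kappa, 0.0, lam**4 * sigma**2]])
    return lam, F, P_inf


def transition(F, lam, dt):
    """exp(F * dt), using that N = F + lam * I is nilpotent (N^3 = 0)."""
    N = F + lam * np.eye(3)
    return np.exp(-lam * dt) * (np.eye(3) + N * dt + (N @ N) * dt**2 / 2.0)


sigma, ell = 0.7, 1.3
lam, F, P_inf = matern52_state_space(sigma, ell)
H = np.array([1.0, 0.0, 0.0])  # observation picks out the first state component
taus = np.linspace(0.0, 5.0, 11)
k_ss = np.array([H @ transition(F, lam, tau) @ P_inf @ H for tau in taus])
k_cf = matern52_closed_form(taus, sigma, ell)
print(np.allclose(k_ss, k_cf))  # True
```

The same P_inf also solves the stationary Lyapunov equation F P_inf + P_inf F^T + q e3 e3^T = 0 with q = 16 sigma^2 lambda^5 / 3, which is what makes it the stationary state covariance.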

@Kamuish
Author

Kamuish commented Jul 30, 2024

I have tried to use tinygp, but the nature of what I am trying to implement makes it very hard (as far as I understand) to use JAX efficiently. And even if I solved those issues, I would then run into the same padding issue as the one discussed in tinygp's discussion 108.

This is not the proper repository to discuss/ask this, but do you happen to have a version of the quasiseparable solver that is not written with JAX? I have a fork of tinygp in which I replaced some JAX calls with NumPy equivalents, but that is far from efficient/optimal.

@dfm
Member

dfm commented Jul 30, 2024

That padding discussion relates to a very specific kind of use, and the limitations really only matter when you evaluate the GP just once (i.e., no fitting) for many datasets of different sizes. In the vast majority of use cases this definitely won't be a problem! I'd recommend trying it out for your use case and opening an issue there if you truly find it to be a problem.

I'm not aware of any non-JAX implementations of this solver. Replacing the JAX functions with numpy definitely isn't a good idea! It'll be far far slower. You'd need to implement the backend in a compiled language which would certainly be possible, but a major project.

I'd recommend making sure that the existing implementation doesn't do the trick first!
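To make the cost argument concrete: for a state-space kernel such as Matérn-5/2, the GP log-likelihood can be computed in O(N) with a Kalman filter. This is a closely related recursion, not tinygp's quasiseparable algorithm itself. The sketch below is plain NumPy with a Python loop over data points. Each step does only 3 x 3 work, so interpreter overhead per iteration dominates, which is why a compiled backend (or a jitted scan in JAX) is needed for real speed. The function names and the (sigma, ell, yerr) parameterization are illustrative assumptions. The dense Cholesky version is included only as an O(N^3) reference to check against.

```python
import numpy as np


def matern52_ss(sigma, ell):
    """Stationary covariance and transition function for the Matern-5/2 SDE."""
    lam = np.sqrt(5.0) / ell
    F = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [-lam**3, -3.0 * lam**2, -3.0 * lam]])
    kappa = lam**2 * sigma**2 / 3.0
    P_inf = np.array([[sigma**2, 0.0, -kappa],
                      [0.0, kappa, 0.0],
                      [-kappa, 0.0, lam**4 * sigma**2]])
    N = F + lam * np.eye(3)  # nilpotent: N^3 = 0
    N2 = N @ N

    def transition(dt):
        # Exact exp(F * dt) for this F.
        return np.exp(-lam * dt) * (np.eye(3) + N * dt + N2 * dt**2 / 2.0)

    return P_inf, transition


def kalman_loglike(t, y, yerr, sigma, ell):
    """O(N) log-likelihood of y ~ GP(0, Matern52) + N(0, yerr^2); t must be sorted."""
    P_inf, transition = matern52_ss(sigma, ell)
    m = np.zeros(3)
    P = P_inf.copy()
    ll = 0.0
    for i in range(len(t)):
        if i > 0:
            # Predict step; process noise Q = P_inf - A P_inf A^T keeps the process stationary.
            A = transition(t[i] - t[i - 1])
            m = A @ m
            P = A @ P @ A.T + (P_inf - A @ P_inf @ A.T)
        S = P[0, 0] + yerr[i] ** 2   # innovation variance (H = [1, 0, 0])
        v = y[i] - m[0]              # innovation
        K = P[:, 0] / S              # Kalman gain
        ll -= 0.5 * (np.log(2.0 * np.pi * S) + v**2 / S)
        m = m + K * v
        P = P - np.outer(K, K) * S
    return ll


def dense_loglike(t, y, yerr, sigma, ell):
    """O(N^3) reference: closed-form kernel plus a Cholesky factorization."""
    r = np.sqrt(5.0) * np.abs(t[:, None] - t[None, :]) / ell
    C = sigma**2 * (1.0 + r + r**2 / 3.0) * np.exp(-r) + np.diag(yerr**2)
    L = np.linalg.cholesky(C)
    alpha = np.linalg.solve(L, y)
    return -0.5 * (alpha @ alpha) - np.log(np.diag(L)).sum() - 0.5 * len(t) * np.log(2.0 * np.pi)


rng = np.random.default_rng(42)
t = np.sort(rng.uniform(0.0, 20.0, 200))
yerr = np.full_like(t, 0.1)
y = np.sin(t) + yerr * rng.standard_normal(len(t))
print(kalman_loglike(t, y, yerr, 1.0, 2.0), dense_loglike(t, y, yerr, 1.0, 2.0))
```

The two numbers agree to round-off. Only the Kalman version scales linearly with the number of points.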

@Kamuish
Author

Kamuish commented Jul 30, 2024

Thanks for all the help!

I guess that I will run into some limitations, maybe even at the level of how the JIT compilation is "stored" (sorry, I'm very new to JAX and all of this JIT stuff). I am trying to fit GPs in parallel to N datasets (N ~ 100), each with a different size.
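Because `jax.jit` traces and compiles a separate version of a function for each distinct input shape, ~100 datasets of different lengths can mean ~100 compilations. A common workaround is to pad each dataset up to a power-of-two length and carry a boolean mask, so only a handful of shapes ever reach the compiled function. This is an assumption here, not something prescribed in this thread. The hypothetical helper `bucket_datasets` below does the grouping in NumPy. Note that the downstream GP code must still make padded entries have no effect on the likelihood, for example by skipping the measurement update for masked points in a Kalman-style recursion.

```python
from collections import defaultdict

import numpy as np


def next_pow2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def bucket_datasets(datasets):
    """Group non-empty (t, y) datasets by padded length.

    Returns {padded_len: (T, Y, M, idx)}: T and Y are (k, padded_len) arrays,
    M is a boolean mask that is True for real (non-padded) points, and idx lists
    the original dataset indices that landed in that bucket.
    """
    groups = defaultdict(list)
    for i, (t, _) in enumerate(datasets):
        groups[next_pow2(len(t))].append(i)
    out = {}
    for n_pad, idx in groups.items():
        T = np.zeros((len(idx), n_pad))
        Y = np.zeros((len(idx), n_pad))
        M = np.zeros((len(idx), n_pad), dtype=bool)
        for row, i in enumerate(idx):
            t, y = datasets[i]
            n = len(t)
            T[row, :n] = t
            T[row, n:] = t[-1]  # repeat the last time so padded times stay sorted
            Y[row, :n] = y
            M[row, :n] = True
        out[n_pad] = (T, Y, M, idx)
    return out


sizes = [50, 64, 65, 200, 37]
rng = np.random.default_rng(1)
datasets = [(np.sort(rng.uniform(0.0, 10.0, n)), rng.standard_normal(n)) for n in sizes]
buckets = bucket_datasets(datasets)
for n_pad, (T, Y, M, idx) in sorted(buckets.items()):
    print(n_pad, idx, M.sum(axis=1))
```

With power-of-two buckets the number of compiled shapes grows only logarithmically with the largest dataset size, at the price of up to 2x wasted work on padded points.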

> I'm not aware of any non-JAX implementations of this solver. Replacing the JAX functions with numpy definitely isn't a good idea! It'll be far far slower. You'd need to implement the backend in a compiled language which would certainly be possible, but a major project.

It was slower, but not terribly so. By any chance, are you planning to write up the proofs for the quasiseparable solver in any form (either docs or a paper)? It would be easier to follow that than to re-derive or reverse-engineer your implementation :)
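As a starting point, here is a sketch of where the quasiseparable structure comes from. It is not the derivation behind tinygp's implementation. Assume a kernel with state-space form k(tau) = H e^{F tau} P_inf H^T for tau >= 0. For Matérn-5/2, F is the 3 x 3 feedback matrix with rows (0, 1, 0), (0, 0, 1), (-lambda^3, -3 lambda^2, -3 lambda), lambda = sqrt(5)/ell, H = (1, 0, 0), and P_inf is its stationary covariance. For sorted times, the matrix exponential splits into per-step transition matrices:

```math
K_{ij} = H\, e^{F (t_i - t_j)}\, P_\infty H^\top
       = H\, A_i A_{i-1} \cdots A_{j+1}\, P_\infty H^\top,
\qquad A_k \equiv e^{F (t_k - t_{k-1})}, \quad i > j .
```

With p_i = (H A_i)^T and q_j = P_inf H^T, every strictly lower-triangular entry therefore has the quasiseparable form

```math
K_{ij} = p_i^\top \left( A_{i-1} \cdots A_{j+1} \right) q_j, \qquad i > j,
```

where the product in parentheses is the identity when i = j + 1. The upper triangle follows by symmetry, and the diagonal is k(0) plus any white noise. All generators are 3-dimensional, so recursions that touch one A_k per step give factorizations and solves whose cost grows linearly in N for this fixed state size. For Matérn-5/2, each A_k also has the closed form e^{-lambda Delta_k} (I + N Delta_k + N^2 Delta_k^2 / 2), with N = F + lambda I and Delta_k = t_k - t_{k-1}, because (F + lambda I)^3 = 0 by the Cayley-Hamilton theorem.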
