Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding debug mode to inferer #248

Merged
merged 3 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mechanistic_model/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from typing import Union

import bayeux as bx
import jax.numpy as jnp
import jax.typing
import numpy as np
Expand Down Expand Up @@ -336,6 +337,34 @@ def infer(self, obs_metrics: jax.Array) -> MCMC:
self.infer_complete = True
return self.inference_algo

def _debug_likelihood(self, obs_metrics) -> bx.Model:
"""uses Bayeux to recreate the self.likelihood function for purposes of basic sanity checking

Parameters
----------
obs_metrics: jnp.array
observed metrics on which likelihood will be calculated on to tune parameters.
See `likelihood()` method for implemented definition of `obs_metrics`

Returns
-------
Bayeux.Model
model object used to debug
"""
bx_model = bx.Model.from_numpyro(
jax.tree_util.Partial(
self.likelihood,
tf=len(obs_metrics),
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved
obs_metrics=obs_metrics,
),
# this does not work for non-one/sampled self.INITIAL_INFECTIONS_SCALE
initial_state=self.INITIAL_STATE,
)
bx_model.mcmc.numpyro_nuts.debug(
seed=PRNGKey(self.config.INFERENCE_PRNGKEY)
)
return bx_model

def checkpoint(
self, checkpoint_path: str, group_by_chain: bool = True
) -> None:
Expand Down
Loading
Loading