-
I'm interested in computing a stochastic gradient estimate with one of Optax's optimizers (and potentially adding control variates). However, I want access to both the gradient estimates AND the evaluated loss, similar to `jax.value_and_grad`. I saw that under the hood some stochastic gradient estimate methods use Jacobians. Any help would be immensely appreciated!
Replies: 3 comments 4 replies
-
Hello @gil2rok
-
Thanks so much for the fast answer! With vanilla jax we can get the loss and gradient:

```python
# mean squared error loss
loss_fn = lambda y_pred, y_true: jnp.mean((y_pred - y_true) ** 2)
loss_val, grads = jax.value_and_grad(loss_fn)(model(X), y)
```

But I have a stochastic loss function (used in variational inference):

```python
true_dist = dist_builder(true_params)  # true dist with intractable sample method

def neg_elbo(approx_params):
    approx_dist = dist_builder(approx_params)
    sample = approx_dist.sample(key)  # generates a single sample
    log_q = approx_dist.logdensity(sample)
    log_p = true_dist.logdensity(sample)
    return log_q - log_p  # negative ELBO (single-sample estimate)
```

I can compute the gradient with Monte Carlo estimates, e.g. using pathwise Jacobians:

```python
# contains expectation w.r.t. approx_params
jacobians = optax.monte_carlo.pathwise_jacobians(
    function=neg_elbo,
    params=approx_params,
    dist_builder=dist_builder,
    rng=jax.random.key(0),
    num_samples=100,
)
grads = jnp.mean(jacobians, axis=0)
```

But how do I get the loss value that generated these Jacobians/grads? Is this what Polyak SGD does? Also, how would I add control variates to this? I couldn't find any examples in the docs.
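For concreteness, here is a self-contained toy version of what I am after, with a hypothetical Gaussian approximation standing in for `dist_builder`: using the reparameterization trick, a single `jax.value_and_grad` call yields both the Monte Carlo loss estimate and its pathwise gradient.

```python
import jax
import jax.numpy as jnp

def neg_elbo(approx_params, eps):
    # Toy setup: q = N(mu, sigma^2) approximating a standard normal
    # target p. Reparameterize the sample so gradients flow through it
    # (the pathwise / reparameterization trick).
    mu, log_sigma = approx_params
    sigma = jnp.exp(log_sigma)
    sample = mu + sigma * eps
    log_q = -0.5 * ((sample - mu) / sigma) ** 2 - log_sigma - 0.5 * jnp.log(2 * jnp.pi)
    log_p = -0.5 * sample ** 2 - 0.5 * jnp.log(2 * jnp.pi)
    return jnp.mean(log_q - log_p)  # Monte Carlo negative-ELBO estimate

eps = jax.random.normal(jax.random.key(0), (100,))  # 100 MC samples
approx_params = (jnp.array(0.5), jnp.array(0.0))
# One call returns both the loss estimate and its pathwise gradient.
loss_val, grads = jax.value_and_grad(neg_elbo)(approx_params, eps)
```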
-
If anyone else needs this or is curious about alternatives, since
I don't think there's currently a way to do that using the methods from `optax.monte_carlo`. As a side note, we're deprecating that module (#1076).
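On the control-variates question: one can be wired in by hand. A minimal sketch (illustrative only, not an Optax API) of a score-function (REINFORCE) estimator with a simple mean baseline as the control variate, for q = N(mu, 1) and the toy integrand f(x) = x²/2:

```python
import jax
import jax.numpy as jnp

def reinforce_with_baseline(key, mu, num_samples=1000):
    # REINFORCE: d/dmu E_q[f(x)] = E_q[f(x) * d/dmu log q(x)],
    # where d/dmu log N(x; mu, 1) = x - mu. Subtracting a baseline b
    # (here the sample mean of f) reduces variance while leaving the
    # estimator's expectation essentially unchanged, since
    # E_q[d/dmu log q] = 0.
    samples = mu + jax.random.normal(key, (num_samples,))
    f = 0.5 * samples ** 2          # example integrand
    score = samples - mu            # score function of q
    baseline = jnp.mean(f)          # simple mean baseline (control variate)
    grad_est = jnp.mean((f - baseline) * score)
    loss_est = jnp.mean(f)          # the loss estimate comes for free
    return grad_est, loss_est
```

For x ~ N(mu, 1), E[x²/2] = (mu² + 1)/2, so at mu = 1 the true loss and gradient are both 1; the estimates should land close to that.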