In the previous part of this series, we verified that the multinomial and the conditional binomial probability models produce identical likelihoods when calculated with numpyro. In this part, we take a look at extreme survival observations and survival probabilities very close to zero. We will find out when these break the solver and disrupt the sampler, and step by step identify the causes of the problems and ways to deal with them.
import jax
import numpy as np
import arviz as az
from matplotlib import pyplot as plt
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Tsit5, Euler
from numpyro.infer import initialization as init
rng = np.random.default_rng(seed=1)
Testing the conditional binomial model for values close to zero
In previous case studies, I observed that under extreme survival data (i.e. very fast mortality), the solver reaches a very large number of steps and inference becomes difficult. This suggests that the error model and the solver interact in a way that is problematic when the survival function approaches zero. To investigate, I'm simulating a simple model with a constant hazard rate and paying attention to the interaction between solver tolerances and the behavior of the survival function close to zero, and especially to the values of the conditional survival under different post-processing schemes, which enter the likelihood function.
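For reference, with a constant hazard $b$ everything has a simple closed form, which is what any sensible post-processing of the solver output should reproduce:
$$S(t) = e^{-bt}, \qquad \Pr(t_1 < T~|~t_0 < T) = \frac{S(t_1)}{S(t_0)} = e^{-b\,(t_1 - t_0)}$$
In particular, the conditional survival probability over intervals of equal length is a constant.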
# generate survival data
rng = np.random.default_rng(1)
t = np.arange(0,11)
n_trials = np.array([10, 15, 20, 25, 30, 35, 40, 45, 50, 55]) * 100
def generate_lethality_from_parametric_function(b_mean):
    """Large b (> 1) generates steep survival functions, and small b (< 1)
    generates moderately sloped survival functions.
    """
    b = rng.lognormal(np.log(b_mean), 0.1, size=(10))
    print("hazard rates:\n", b)
    probs = -1 * np.diff(np.exp(- (np.einsum("t,k->kt", t, b))))
    probs = np.column_stack([probs, 1 - probs.sum(axis=1)])
    lethality = np.array(list(map(lambda n, p: rng.multinomial(n=n, pvals=p), n_trials, probs)))
    print("lethality:\n", lethality)
    return b, probs, lethality
Programmatically, I'm approaching the problem with diffrax, an autodifferentiable ODE solver, which is built on JAX just like numpyro.
def ode(t, y, params):
    H, S = y
    b, = params
    # constant hazard
    h = b
    dH_dt = h
    dS_dt = -h * S
    return jnp.array([dH_dt, dS_dt])
def solve(t, b, atol, rtol):
    sol = diffeqsolve(
        terms=ODETerm(ode),  # type: ignore
        solver=Tsit5(),
        t0=t.min(),
        t1=t.max(),
        dt0=0.1,  # type: ignore
        y0=jnp.array([0.0, 1.0], ndmin=1),  # type: ignore
        saveat=SaveAt(ts=jnp.array(t)),  # type: ignore
        args=(b,),  # type: ignore
        # max_steps=int(1e7),
        stepsize_controller=PIDController(rtol=rtol, atol=atol),  # type: ignore
    )
    return sol
# parameters
t_ = np.linspace(0,50,1000)
mf = 3.0  # multiplier applied to atol/eps in the post-processing thresholds below
atol = 1e-8
b_mean = 3
truncate = lambda y, eps: jnp.trunc(y/(eps*mf))*(eps*mf)
clip = lambda y, eps: jnp.clip(y, eps*mf, 1.0-eps*mf)
maxim = lambda y, eps: jnp.maximum(y, eps*mf)
softmax = lambda y, eps: jnp.maximum(y, eps) + (eps-jnp.maximum(y, eps)) / (1 + jnp.exp(0.0001 * (jnp.log(jnp.maximum(y, eps)) - jnp.log(eps) )))
as_is = lambda y, eps: y
sols = jax.vmap(lambda atol: solve(t_, b=b_mean, atol=atol, rtol=1e-7).ys)(np.array([atol]))
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(14,5))
ax1.hlines(atol * mf,t_.min(), t_.max(), color="black")
ax1.hlines(-atol * mf,t_.min(), t_.max(), color="black")
ax1.plot(t_, sols[:,:,1].T, alpha=1, lw=1, label="untransformed", color="black")
ax1.plot(t_, truncate(sols[:,:,1].T, atol), alpha=1, lw=1, label="truncated", color="tab:blue")
ax1.plot(t_, clip(sols[:,:,1].T, atol*0.9), alpha=1, lw=1, label="clipped", color="tab:orange")
ax1.plot(t_, np.exp(-sols[:,:,0].T), alpha=1, lw=1.5, label="untruncated exp(-H)", ls="--", color="tab:green")
# ax1.plot(t_, softmax(sols[:,:,1].T, atol/10**30), alpha=1, lw=1, label="softmax")
ax1.set_ylim(-atol*10,atol*10)
ax1.set_xlabel("Time")
ax1.set_ylabel("S")
ax1.set_xlim(t_.min(), t_.max())
ax1.legend()
# ax1.set_yscale("log")
ax1.set_xlim(t_.min(),t_.max()/1)
ax2.plot(t_[1:], sols[:, 1:, 1].T / sols[:, :-1, 1].T, label="untransformed", color="black")
ax2.plot(t_[1:], truncate(sols[:, 1:, 1].T, atol) / truncate(sols[:, :-1, 1].T, atol), label="truncated", color="tab:blue")
ax2.plot(t_[1:], clip(sols[:, 1:, 1].T, atol) / clip(sols[:, :-1, 1].T, atol), label="clipped", color="tab:orange")
ax2.plot(t_[1:], jnp.exp(-(sols[:, 1:, 0].T - sols[:, :-1, 0].T)), label="untruncated (H)", ls="--", color="tab:green", lw=1.5)
# ax2.plot(t_[1:], softmax(sols[:, 1:, 1].T, atol/10**30) / softmax(sols[:, :-1, 1].T, atol/10**30), label="softmax")
ax2.set_xlim(t_.min(),t_.max()/1)
ax2.set_ylim(0.0,1.1)
ax2
ax2.legend()
It turns out that the untransformed survival function, as well as every post-processing scheme applied to it, is massively problematic for the likelihood function.
- untransformed: values above and below zero are possible and fluctuate strongly; unusable for a probability vector
- truncated: erratic before zero is reached (see the sketch after this list); unusable
- clipped: relatively stable, but the transition of the conditional probability to 1.0 is wrong: since the hazard is constant, the conditional survival probability should also be constant
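To see why the truncated ratio is so erratic, here is a tiny standalone sketch with hypothetical values near the truncation threshold (made-up numbers, not solver output): snapping values to multiples of eps*mf makes consecutive ratios jump around even though the underlying decay is perfectly smooth.
eps, mf = 1e-8, 3.0
S_small = jnp.array([9.1e-8, 6.4e-8, 4.5e-8, 3.2e-8])  # smooth decay of ~30% per step
print(S_small[1:] / S_small[:-1])                       # ~0.70 at every step
S_trunc = jnp.trunc(S_small / (eps * mf)) * (eps * mf)  # same operation as truncate()
print(S_trunc[1:] / S_trunc[:-1])                       # ~0.67, 0.5, 1.0 -- erratic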
When using a constant step size, I also noticed that the conditional survival function makes a switch at approximately $t=30$ and transitions to 1.0, because the survival values underflow once the smallest representable floating point number is reached. This is a hard constraint for the solver that can only be remedied by using a floating point representation with more bits.
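A quick check of where this bites for a steep survival function ($b = 3$, as in the plot above); the exact switch point depends on the float type and on whether subnormals are flushed to zero, so treat the threshold as approximate:
b_check = 3.0
t_check = np.array([25.0, 29.0, 32.0, 35.0])
print(jnp.exp(jnp.asarray(-b_check * t_check, dtype=jnp.float32)))  # tiny values, then exactly 0.0
print(np.exp(-b_check * t_check))                                   # float64 still resolves all of them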
Using the cumulative (integrated) hazard directly to compute conditional survival
At this point I realized that I can use the cumulative hazard directly in the calculation of the conditional survival probability, as it is much better behaved and does not suffer from underflow. The standard calculation of the conditional survival function is:
$$\Pr(t < T~|~t_0 < T) = \frac{S(t)}{S(t_0)}$$
However, $S(t) = \exp(-H(t))$ can be substituted into the equation, leading to:
$$\Pr(t < T~|~t_0 < T) = \frac{e^{-H(t)}}{e^{-H(t_0)}} = e^{-H(t) - (- H(t_0))} = e^{-H(t) + H(t_0)} = e^{-(H(t) - H(t_0))} = e^{H(t_0) - H(t)}$$
Take some time to think about this equation; it makes sense. If, for instance, the cumulative hazard becomes constant because the hazard rate is zero, then the exponent becomes zero and the conditional survival probability goes towards 1.0. Conversely, if the hazard rate is very large, the exponent goes towards negative infinity and the conditional survival probability approaches zero. Here too there is a hard limit on how extreme the hazard can become before underflow, but that limit is exponentially larger than when working with the survival function directly.
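A minimal numerical sketch of this advantage, using a deliberately extreme, hypothetical hazard (this is not the model code, just the two formulas side by side in float32):
b_demo, t0, t1 = 12.0, 9.0, 10.0      # constant hazard, so H(t) = b * t
H0, H1 = b_demo * t0, b_demo * t1
S0, S1 = jnp.exp(-H0), jnp.exp(-H1)   # both survival probabilities underflow to exactly 0.0
print(S1 / S0)                        # 0/0 -> nan, which poisons the likelihood
print(jnp.exp(-(H1 - H0)))            # exp(-12) ~ 6.1e-6, perfectly usable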
def conditional_binomial_model(lethality, n_trials):
    batch_size, observation_times = lethality.shape
    # constant hazard model; the wide prior can artificially produce high hazards
    b = numpyro.sample("b", numpyro.distributions.HalfNormal(scale=np.repeat([10.0], repeats=10)))
    atol = 1e-5
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-4))(b)
    # Survival probability
    S = sol.ys[:, :, 1]
    S_ = clip(jnp.column_stack([S, jnp.zeros_like(S[:, [0]])]), atol * 10)
    # conditional survival probability S(t) / S(t-1)
    S_cond = clip(S_[:, 1:] / S_[:, :-1], atol * 10)
    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])
    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs
    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    # here we exploit the fact that, thanks to the transformations above, the
    # observations are conditionally independent given that the organisms
    # survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow an absolute deviation of 0.05 from the true hazard rates
np.testing.assert_allclose(b_est, b, atol=0.05)
hazard rates:
[0.10351625 0.10856315 0.10335957 0.08778182 0.10947605 0.10456487
0.09477208 0.10598335 0.103713 0.10298501]
lethality:
[[ 95 91 78 74 70 71 46 40 37 35 363]
[ 169 117 136 114 90 83 85 55 74 70 507]
[ 188 172 181 163 117 119 115 103 88 85 669]
[ 210 181 174 178 155 149 129 106 99 99 1020]
[ 329 279 235 232 198 182 159 165 120 107 994]
[ 353 311 271 273 262 191 169 163 139 129 1239]
[ 369 344 286 250 242 243 207 175 202 140 1542]
[ 441 420 358 316 301 266 255 232 194 152 1565]
[ 492 450 394 376 300 285 251 249 214 190 1799]
[ 524 443 448 327 360 317 305 281 235 245 2015]]
sample: 100%|██████████| 1500/1500 [00:10<00:00, 143.48it/s, 3 steps of size 7.37e-01. acc. prob=0.86]
estimates b: [0.10272734 0.10828657 0.10805375 0.0894981 0.11091767 0.1046254
0.09517204 0.10565186 0.10257991 0.09936014]
The above model works well for moderate hazards and delivers unbiased estimates. Clipping gets rid of the most easily diagnosable evils, namely probabilities above 1 and below 0. These would lead to infinities in the likelihood function and are therefore observable even when simply running the model forward, i.e. computing the likelihood. So the model does not throw errors, but as we saw above the conditional probabilities are still quite mad. If we did not clip the results, the model would not compute at all. Nevertheless, the model is still severely wrong, as we will see when we confront it with more problematic likelihoods.
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=3)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
try:
    mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
except RuntimeError as err:
    print(err.args[0].split("\n")[0])
hazard rates:
[2.89542491 3.09905782 2.89982806 2.98213132 3.07464539 2.78416219
3.21069122 2.86229079 2.75011732 3.02319897]
lethality:
[[ 937 58 5 0 0 0 0 0 0 0 0]
[1443 53 4 0 0 0 0 0 0 0 0]
[1893 102 4 1 0 0 0 0 0 0 0]
[2389 105 6 0 0 0 0 0 0 0 0]
[2868 129 3 0 0 0 0 0 0 0 0]
[3259 223 16 1 1 0 0 0 0 0 0]
[3848 150 2 0 0 0 0 0 0 0 0]
[4210 267 22 1 0 0 0 0 0 0 0]
[4683 299 17 1 0 0 0 0 0 0 0]
[5223 263 12 2 0 0 0 0 0 0 0]]
0%| | 0/1500 [00:06<?, ?it/s]
INTERNAL: Generated function failed: CpuCallback error: EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.
The solver breaks down because it runs out of steps, supposedly because of weird gradients.
from functools import partial
loglik = lambda x: numpyro.infer.log_likelihood(partial(conditional_binomial_model, lethality, n_trials), x)["lethality"].sum()
theta = {"b": jnp.expand_dims(jnp.linspace(8.106, 8.11, 10), axis=0)}
print("Log likelihood:", loglik(theta))
print("Grad Log-likelihood:", jax.grad(loglik)(theta))
Log likelihood: -19883.049
Grad Log-likelihood: {'b': Array([[ -16682.068 , -17545.986 , -39604.836 , -50014.46 ,
-77475.55 , -182425.44 , -195783.94 , -289.02518,
0. , 0. ]], dtype=float32)}
Interestingly, the gradients and the log-likelihoods look okay. The only weird thing is that the gradients go to zero at the largest values of $b$. There is a parameter range in which a very sharp transition occurs.
The conditional binomial hazard model
Next, I’m looking at the binomial hazard model that I had previously identified as a numerically more stable alternative to using the survival probabilities directly in the computation of the conditional probabilities.
$$\Pr(t < T~|~t_0 < T) = \frac{S(t)}{S(t_0)} = e^{H(t_0) - H(t)}$$
def conditional_binomial_hazard_model(lethality, n_trials):
    batch_size, observation_times = lethality.shape
    # constant hazard model; the wide prior can artificially produce high hazards
    b = numpyro.sample("b", numpyro.distributions.HalfNormal(scale=np.repeat([10.0], repeats=10)))
    atol = 1e-6
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-3))(b)
    # Cumulative hazard
    H = sol.ys[:, :, 0]
    H = jnp.column_stack([H, jnp.full_like(H[:, 0], jnp.inf)])
    # conditional survival probability S(t) / S(t-1) = exp(-(H(t) - H(t-1)))
    S_cond = jnp.exp(- (H[:, 1:] - H[:, :-1]))
    # TODO: This may not be necessary, because the data are not connected to any
    # variable. Also, if this is done with the last modeled survival probability,
    # the conditional probability becomes 0 / S(t[-1]) = 0
    # S_cond = jnp.clip(jnp.column_stack([S_cond, jnp.zeros(batch_size)]), 1e-20, 1-1e-20)
    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])
    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs
    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    # lp = dist.Binomial(total_count=S_before_t, probs=S_cond).log_prob(S_obs[:, 1:-1]).sum()
    # jax.debug.print("{x}", x=lp)
    # here we exploit the fact that, thanks to the transformations above, the
    # observations are conditionally independent given that the organisms
    # survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow an absolute deviation of 0.05 from the true hazard rates
np.testing.assert_allclose(b_est, b, atol=0.05)
hazard rates:
[0.09530118 0.09322559 0.10139153 0.09713273 0.1154754 0.10000202
0.10329215 0.10998812 0.09703722 0.11545073]
lethality:
[[ 91 86 83 65 62 55 59 44 47 28 380]
[ 126 124 117 106 85 78 62 69 67 50 616]
[ 203 147 155 146 130 117 116 83 83 78 742]
[ 231 218 192 163 165 150 131 134 95 95 926]
[ 335 337 279 221 215 163 158 141 126 125 900]
[ 333 298 286 262 214 191 201 157 143 148 1267]
[ 367 337 327 288 258 225 201 188 202 164 1443]
[ 472 447 377 356 308 254 234 212 184 167 1489]
[ 481 440 359 338 300 299 264 220 216 189 1894]
[ 588 549 490 419 371 317 287 300 216 187 1776]]
sample: 100%|██████████| 1500/1500 [00:11<00:00, 133.45it/s, 3 steps of size 7.16e-01. acc. prob=0.88]
estimates b: [0.09747483 0.08992738 0.09922651 0.09927562 0.12132029 0.10146473
0.10105705 0.11152633 0.09744926 0.11399455]
The conditional binomial hazard model is also an unbiased estimator of the parameters. One open question is whether it is necessary to include the survival remaining after the last observation in the likelihood function.
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=2)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model, init_strategy=init.init_to_value(values={"b": np.repeat([2], 10)}))
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
try:
    mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
except RuntimeError as err:
    print(err.args[0].split("\n")[0])
hazard rates:
[1.7656863 2.11675695 2.41223889 2.25432954 2.20837715 2.00403966
2.20672677 1.81581408 2.15598705 1.98262622]
lethality:
[[ 841 134 20 5 0 0 0 0 0 0 0]
[1316 168 16 0 0 0 0 0 0 0 0]
[1820 160 18 2 0 0 0 0 0 0 0]
[2230 242 26 2 0 0 0 0 0 0 0]
[2653 310 33 3 1 0 0 0 0 0 0]
[3012 436 44 5 2 1 0 0 0 0 0]
[3534 423 36 7 0 0 0 0 0 0 0]
[3730 641 111 18 0 0 0 0 0 0 0]
[4414 511 65 9 1 0 0 0 0 0 0]
[4726 669 90 12 3 0 0 0 0 0 0]]
0%| | 0/1500 [00:06<?, ?it/s]
INTERNAL: Generated function failed: CpuCallback error: EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.
Difficulties of the solver at initialization
This model can also not be computed. After intense debugging (debug print statements after drawing samples and after computing likelihoods were the key), I realized that under strong declines of the survival function (extreme hazards), the solver breaks down; this happens at hazard rates beyond 1000.
Apparently, just after initialization the sampler explored extreme parameters. This is probably driven by the gradients, which in turn are driven by the survival observations, because it does not happen with well-distributed survival observations. The problem cannot be solved by a different post-processing of the state variables, because it arises from the little information contained in the data.
If extreme values of $b$ can also reproduce the data, then the sampler should and will explore them. Only setting hard bounds on $b$ prevents the sampler from running into extreme regions that the solver cannot handle.
from functools import partial
loglik = lambda x: numpyro.infer.log_likelihood(partial(conditional_binomial_hazard_model, lethality, n_trials), x)["lethality"].sum()
theta = {"b": jnp.expand_dims(jnp.linspace(0,100, 10), axis=0)}
print("Log likelihood:", loglik(theta))
print("Grad Log-likelihood:", jax.grad(loglik)(theta))
try:
    theta = {"b": jnp.expand_dims(jnp.linspace(0, 10000, 10), axis=0)}
    print("Log likelihood:", loglik(theta))
    print("Grad Log-likelihood:", jax.grad(loglik)(theta))
except RuntimeError as err:
    print("Computation for extreme values of 'b' > 100")
    print(err.args[0].split("\n")[0])
Log likelihood: -309155.6
Grad Log-likelihood: {'b': Array([[ 0. , -199.97758, -202. , -300.00006, -388.99997,
-551.99994, -516.00006, -917.0001 , nan, nan]], dtype=float32)}
Computation for extreme values of 'b' > 100
The maximum number of solver steps was reached. Try increasing `max_steps`.
So, in order to avoid extreme parameter samples, I'm using a TruncatedCauchy prior. In addition, I avoid computing the probability from the last observation to infinity, because the information about the number of surviving organisms is already accounted for by comparing the last survival probability to the data.
def conditional_binomial_hazard_model_trunc_prior(lethality, n_trials):
    batch_size, observation_times = lethality.shape
    # constant hazard model; the truncated Cauchy prior puts a hard upper bound on the hazard
    b = numpyro.sample("b", numpyro.distributions.TruncatedCauchy(scale=np.repeat([1.0], repeats=10), low=0, high=1000))
    atol = 1e-6
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-3))(b)
    # Cumulative hazard
    H = sol.ys[:, :, 0]
    H = jnp.column_stack([H, jnp.full_like(H[:, 0], jnp.inf)])
    # conditional survival probability S(t) / S(t-1) = exp(-(H(t) - H(t-1)))
    S_cond = jnp.exp(- (H[:, 1:] - H[:, :-1]))
    # subtract the cumulative observed lethality from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])
    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs
    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    # observations are conditionally independent given that the organisms have
    # survived until the respective beginning of the time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow an absolute deviation of 0.05 from the true hazard rates
np.testing.assert_allclose(b_est, b, atol=0.05)
hazard rates:
[0.10011198 0.08686346 0.10129305 0.10928235 0.09759295 0.10757679
0.1074746 0.10457167 0.11872323 0.1080997 ]
lethality:
[[ 104 84 72 79 57 54 57 49 39 43 362]
[ 126 139 99 90 89 78 75 74 69 52 609]
[ 177 185 147 180 123 114 89 99 93 79 714]
[ 260 236 208 201 165 137 140 128 118 80 827]
[ 285 241 237 202 198 176 150 139 147 100 1125]
[ 339 348 277 227 228 216 174 178 143 148 1222]
[ 403 356 293 277 304 228 219 213 150 175 1382]
[ 454 414 340 342 274 286 220 202 210 174 1584]
[ 557 518 426 404 340 321 262 249 230 205 1488]
[ 569 520 452 389 380 302 297 262 238 222 1869]]
sample: 100%|██████████| 1500/1500 [00:11<00:00, 134.33it/s, 3 steps of size 6.81e-01. acc. prob=0.88]
estimates b: [0.10174821 0.09040698 0.10257039 0.11098085 0.09811853 0.10506512
0.10538018 0.10456848 0.12049202 0.10804637]
numpyro.infer.initialization.init_to_sample is the way to initialize the sampler here. init_to_uniform allows values anywhere on the unconstrained scale, including values that the solver can't handle.
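To get a feeling for the difference, here is a rough sketch of where the two strategies start the chain. I'm assuming the usual behavior here: init_to_sample draws from the prior, while init_to_uniform (the default) draws uniformly from [-2, 2] on the unconstrained scale and maps the draw back to the (0, 1000) support through a sigmoid-style bijection; the transform below is that assumption spelled out, not numpyro internals quoted verbatim.
key = jax.random.PRNGKey(42)
prior = dist.TruncatedCauchy(scale=1.0, low=0.0, high=1000.0)
# init_to_sample: start from prior draws -> mostly small hazards, occasionally larger ones
print(prior.sample(key, (5,)))
# init_to_uniform-like start: uniform on [-2, 2], squashed onto (0, 1000)
u = jnp.linspace(-2.0, 2.0, 5)
print(1000.0 * jax.nn.sigmoid(u))     # roughly 119 ... 881, i.e. hazards of several hundred
Hazards of several hundred are far beyond the values for which the gradients already turned to NaN above, so a uniform initialization hands the solver a problem it cannot integrate.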
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=5)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow 10% relative (and 0.1 absolute) deviation from the true hazard rates
np.testing.assert_allclose(b_est, b, atol=0.1, rtol=0.1)
hazard rates:
[5.29213585 5.37714278 5.2645693 4.67145342 5.7766158 5.01292862
5.8152567 5.7321696 4.39097979 4.52223617]
lethality:
[[ 991 9 0 0 0 0 0 0 0 0 0]
[1491 9 0 0 0 0 0 0 0 0 0]
[1989 11 0 0 0 0 0 0 0 0 0]
[2479 20 1 0 0 0 0 0 0 0 0]
[2994 6 0 0 0 0 0 0 0 0 0]
[3475 25 0 0 0 0 0 0 0 0 0]
[3982 18 0 0 0 0 0 0 0 0 0]
[4487 12 1 0 0 0 0 0 0 0 0]
[4925 75 0 0 0 0 0 0 0 0 0]
[5440 58 2 0 0 0 0 0 0 0 0]]
sample: 100%|██████████| 1500/1500 [00:13<00:00, 111.99it/s, 3 steps of size 6.67e-01. acc. prob=0.89]
estimates b: [4.7346883 5.1417646 5.2161317 4.748071 6.2511463 4.9498925 5.4192424
5.792143 4.2159777 4.496989 ]
And finally we can sample fast 🚀 even from extreme survival data 🎉. The sampler still slows down when the lethality data become more extreme, but it can handle them well. Despite setting throw_exception=True, no runtime errors are raised by the solver. This setup seems to improve the situation quite a bit.
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=10)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow 10% relative (and 0.1 absolute) deviation from the true hazard rates
try:
    np.testing.assert_allclose(b_est, b, atol=0.1, rtol=0.1)
except AssertionError:
    print("The mean model parameter estimates diverge from the true values, because the uncertainty is very large")
    print("But, HDI estimates contain the true value 🎉")
idata = az.convert_to_inference_data(
    {"b": np.expand_dims(samples_conditional_binomial["b"], axis=0)},
    coords={"sample": range(10)},
    dims={"b": ["chain", "draw", "sample"]}
)
az.plot_forest(idata, hdi_prob=0.94)
_ = plt.plot(b, 22.4-np.arange(0,10)* 2.5, marker="o", ls="", color="black", ms=3)
hazard rates:
[10.37056053 8.4157622 8.51890879 12.41824709 8.38021188 10.59177766
9.18675297 9.26523369 8.20162744 9.13863687]
lethality:
[[1000 0 0 0 0 0 0 0 0 0 0]
[1500 0 0 0 0 0 0 0 0 0 0]
[2000 0 0 0 0 0 0 0 0 0 0]
[2500 0 0 0 0 0 0 0 0 0 0]
[2999 1 0 0 0 0 0 0 0 0 0]
[3499 1 0 0 0 0 0 0 0 0 0]
[4000 0 0 0 0 0 0 0 0 0 0]
[4500 0 0 0 0 0 0 0 0 0 0]
[4997 3 0 0 0 0 0 0 0 0 0]
[5498 2 0 0 0 0 0 0 0 0 0]]
sample: 100%|██████████| 1500/1500 [04:42<00:00, 5.32it/s, 31 steps of size 1.08e-01. acc. prob=0.94]
estimates b: [35.194286 36.551746 30.150806 42.831306 8.2548485 8.377418
83.403984 42.569454 7.504727 8.017168 ]
The mean model parameter estimates diverge from the true values, because the uncertainty is very large
But, HDI estimates contain the true value 🎉
Slow performance on such survival data is to be expected, because here any parameter value > 10 would produce the same results. This makes for a difficult posterior geometry and for parameter values that probably lead to a higher number of solver steps. Still, the probability model handles it well.
See also how the uncertainty intervals become much smaller as soon as even a single organism survives the first interval. The other intervals are much wider, because in a binomial world the data cannot distinguish between all the possible slopes that lead to 100% mortality within the first day.
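As a rough back-of-the-envelope check of that intuition (using only the first column of the lethality table above and ignoring the later intervals): a row with $s$ survivors out of $n$ organisms after the first day pins the conditional survival $e^{-b}$ near $s/n$, and therefore $b$ near $\log(n/s)$.
survivors = np.array([1, 1, 3, 2])      # rows of the table above with survivors after day 1
n_exposed = np.array([3000, 3500, 5000, 5500])
print(np.log(n_exposed / survivors))    # ~ [8.0, 8.2, 7.4, 7.9]
These values sit close to the corresponding posterior means (roughly 8.25, 8.38, 7.50 and 8.02), while the rows with zero survivors only constrain $b$ from below, which is why their intervals are so wide.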
Old conditional binomial model did not work well with wider priors
def conditional_binomial_model_trunc_cauchy(lethality, n_trials):
    batch_size, observation_times = lethality.shape
    # constant hazard model; the truncated Cauchy prior puts a hard upper bound on the hazard
    b = numpyro.sample("b", numpyro.distributions.TruncatedCauchy(scale=np.repeat([1.0], repeats=10), low=0, high=1000))
    atol = 1e-5
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-4))(b)
    # Survival probability
    S = sol.ys[:, :, 1]
    S_ = clip(jnp.column_stack([S, jnp.zeros_like(S[:, [0]])]), atol * 10)
    # conditional survival probability S(t) / S(t-1)
    S_cond = clip(S_[:, 1:] / S_[:, :-1], atol * 10)
    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])
    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs
    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    # here we exploit the fact that, thanks to the transformations above, the
    # observations are conditionally independent given that the organisms
    # survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model_trunc_cauchy)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow 10% relative (and 0.1 absolute) deviation from the true hazard rates
try:
    np.testing.assert_allclose(b_est, b, atol=0.1, rtol=0.1)
except AssertionError:
    print("The mean model parameter estimates diverge from the true values.")
idata = az.convert_to_inference_data(
    {"b": np.expand_dims(samples_conditional_binomial["b"], axis=0)},
    coords={"sample": range(10)},
    dims={"b": ["chain", "draw", "sample"]}
)
print(az.hdi(idata.posterior.b.values))
az.plot_forest(idata, hdi_prob=0.94)
_ = plt.plot(b, 22.4-np.arange(0,10)* 2.5, marker="o", ls="", color="black", ms=3)
hazard rates:
[0.09596538 0.09692836 0.09255667 0.08576109 0.09105337 0.10334469
0.08884172 0.09372573 0.11784689 0.11192419]
lethality:
[[ 101 81 81 57 62 53 45 44 38 40 398]
[ 143 136 127 114 85 81 74 74 73 58 535]
[ 171 147 143 114 134 110 107 98 95 78 803]
[ 222 142 194 167 139 122 125 117 96 88 1088]
[ 267 245 216 202 180 179 136 128 122 104 1221]
[ 354 313 318 226 211 195 196 173 155 118 1241]
[ 343 305 288 249 233 226 205 213 138 160 1640]
[ 354 364 350 313 313 255 240 206 208 188 1709]
[ 554 490 434 395 354 289 257 243 221 173 1590]
[ 619 502 480 403 335 330 296 260 261 214 1800]]
sample: 100%|██████████| 1500/1500 [37:37<00:00, 1.51s/it, 229 steps of size 8.90e-03. acc. prob=0.90]
estimates b: [ 0.09291796 0.10308105 0.0903046 0.08379658 26.807642 0.1040244
0.08909638 83.624176 18.493137 19.66022 ]
The mean model parameter estimates diverge from the true values.
[[8.56394246e-02 1.00866236e-01]
[9.69580039e-02 1.09456278e-01]
[8.59246626e-02 9.49095562e-02]
[7.95475245e-02 8.80869478e-02]
[8.13150120e+00 6.32290077e+01]
[1.00612350e-01 1.07925475e-01]
[8.56746361e-02 9.24930796e-02]
[8.14574051e+00 3.29237640e+02]
[8.12389755e+00 4.51659050e+01]
[8.14356709e+00 5.07798958e+01]]
This behavior strongly reminds me of previous numpyro runs. Although the problem is simple ($b=0.1$), the algorithm takes a very long time. This is most likely caused by NaN gradients that the sampler sporadically encounters due to the underflow problems. It even reproduces runs that seem to converge okay and then suddenly drift off to a local minimum that is complete nonsense.
In the next part of the series, Conditional survival part 3: Missing values, we will look at missing values. So far, our models were only confronted with data that had observations for each time point. What happens when this is not the case?