In the first part of this series, we verified that the multinomial and the conditional binomial probability models produce identical likelihoods when calculated with numpyro. In this part, we take a look at extreme survival observations and survival probabilities very close to zero. We will find out when these break the solver and disrupt the sampler, and step by step identify the causes of the problem and ways to deal with them.

import jax
import numpy as np
import arviz as az

from matplotlib import pyplot as plt
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Tsit5, Euler
from numpyro.infer import initialization as init

rng = np.random.default_rng(seed=1)

Testing the conditional binomial model for values close to zero

In previous case studies, I observed that under extreme survival data (i.e. very fast mortality), the solver reaches a large number of steps and inference becomes difficult. This suggests that the error model and the solver interact in a way that is problematic when the survival function approaches zero. To investigate, I'm simulating a simple model with a constant hazard rate and paying attention to the interaction between solver tolerances and the behavior of the survival function close to zero, and especially to the values of the conditional survival under different post-processing schemes, which enter the likelihood function.


# generate survival data 
rng = np.random.default_rng(1)
t = np.arange(0,11)
n_trials = np.array([10, 15, 20, 25, 30, 35, 40, 45, 50, 55]) * 100

def generate_lethality_from_parametric_function(b_mean):
    """Large b (> 1) generate steep survival functions, and small b (< 1) 
    generate, moderately sloped survival functions
    """
    b = rng.lognormal(np.log(b_mean), 0.1, size=(10))
    print("hazard rates:\n", b)
    probs = -1 * np.diff(np.exp(- (np.einsum("t,k->kt", t, b))))
    probs = np.column_stack([probs, 1 - probs.sum(axis=1)])
    lethality = np.array(list(map(lambda n, p: rng.multinomial(n=n, pvals=p), n_trials, probs)))
    print("lethality:\n", lethality)
    return b, probs, lethality

Programmatically, I'm approaching the problem with diffrax, an autodifferentiable ODE solver that, like numpyro, is built on JAX.


def ode(t, y, params):
    H, S = y
    b, = params

    # constant hazard
    h = b
    dH_dt = h
    dS_dt = -h * S
    return jnp.array([dH_dt, dS_dt])


def solve(t, b, atol, rtol):
    sol = diffeqsolve(
        terms=ODETerm(ode), # type: ignore
        solver=Tsit5(),
        t0=t.min(),
        t1=t.max(),
        dt0=0.1, # type: ignore
        y0=jnp.array([0.0, 1.0], ndmin=1), # type: ignore
        saveat=SaveAt(ts=jnp.array(t)), # type: ignore
        args=(b,), # type: ignore
        # max_steps=int(1e7),
        stepsize_controller=PIDController(rtol=rtol, atol=atol), # type: ignore
    )

    return sol

# parameters
t_ = np.linspace(0,50,1000)
mf = 3.0
atol = 1e-8
b_mean = 3

truncate = lambda y, eps: jnp.trunc(y/(eps*mf))*(eps*mf)
clip = lambda y, eps: jnp.clip(y, eps*mf, 1.0-eps*mf)
maxim = lambda y, eps: jnp.maximum(y, eps*mf)
softmax = lambda y, eps: jnp.maximum(y, eps) + (eps-jnp.maximum(y, eps)) / (1 + jnp.exp(0.0001 * (jnp.log(jnp.maximum(y, eps)) - jnp.log(eps) )))
as_is = lambda y, eps: y


sols = jax.vmap(lambda atol: solve(t_, b=b_mean, atol=atol, rtol=1e-7).ys)(np.array([atol]))

fig, (ax1, ax2) = plt.subplots(1,2, figsize=(14,5))
ax1.hlines(atol * mf,t_.min(), t_.max(), color="black")
ax1.hlines(-atol * mf,t_.min(), t_.max(), color="black")
ax1.plot(t_, sols[:,:,1].T, alpha=1, lw=1, label="untransformed", color="black")
ax1.plot(t_, truncate(sols[:,:,1].T, atol), alpha=1, lw=1, label="truncated", color="tab:blue")
ax1.plot(t_, clip(sols[:,:,1].T, atol*0.9), alpha=1, lw=1, label="clipped", color="tab:orange")
ax1.plot(t_, np.exp(-sols[:,:,0].T), alpha=1, lw=1.5, label="untruncated exp(-H)", ls="--", color="tab:green")
# ax1.plot(t_, softmax(sols[:,:,1].T, atol/10**30), alpha=1, lw=1, label="softmax")
ax1.set_ylim(-atol*10,atol*10)
ax1.set_xlabel("Time")
ax1.set_ylabel("S")
ax1.set_xlim(t_.min(), t_.max())
ax1.legend()
# ax1.set_yscale("log")
ax1.set_xlim(t_.min(),t_.max()/1)

ax2.plot(t_[1:], sols[:, 1:, 1].T / sols[:, :-1, 1].T, label="untransformed", color="black")
ax2.plot(t_[1:], truncate(sols[:, 1:, 1].T, atol) / truncate(sols[:, :-1, 1].T, atol), label="truncated", color="tab:blue")
ax2.plot(t_[1:], clip(sols[:, 1:, 1].T, atol) / clip(sols[:, :-1, 1].T, atol), label="clipped", color="tab:orange")
ax2.plot(t_[1:], jnp.exp(-(sols[:, 1:, 0].T - sols[:, :-1, 0].T)), label="untruncated (H)", ls="--", color="tab:green", lw=1.5)
# ax2.plot(t_[1:], softmax(sols[:, 1:, 1].T, atol/10**30) / softmax(sols[:, :-1, 1].T, atol/10**30), label="softmax")
ax2.set_xlim(t_.min(),t_.max()/1)
ax2.set_ylim(0.0,1.1)
ax2.legend()

png

It turns out that the untransformed survival function and every post-processing scheme tried so far are problematic for the likelihood function.

  • untransformed: values above and below zero occur and fluctuate strongly; therefore unusable as a probability vector
  • truncated: erratic before zero is reached; unusable
  • clipped: relatively stable, but the transition to a conditional probability of 1.0 is wrong: since the hazard is constant, the conditional survival probability should also be constant

When using a constant step size, I also noticed that the conditional survival function switches at approximately $t=30$ and jumps to 1.0, caused by floating-point underflow once the smallest representable value is reached. This is a hard constraint for the solver that can only be remedied by using a higher floating-point precision.
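
To see where that hard limit sits, here is a minimal sketch (my own addition, not part of the original analysis) comparing 32-bit and 64-bit floats: the survival probability $\exp(-H)$ underflows to exactly zero once the cumulative hazard grows large, while $H$ itself remains perfectly representable.


H = np.array([30.0, 90.0, 120.0, 150.0])   # cumulative hazard values b*t
print(np.exp(-H.astype(np.float32)))       # the larger values underflow to subnormals and then to exactly 0.0
print(np.exp(-H.astype(np.float64)))       # all values remain positive in 64-bit precision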

Using the cumulative (integrated) hazard directly to compute conditional survival

At this point I realized that I can use the cumulative hazard directly in the calculation of the conditional survival, as it is much better behaved and does not suffer from underflow. The standard calculation of the conditional survival function is:

$$\Pr(t < T~|~t_0 < T) = \frac{S(t)}{S(t_0)}$$

However, $S(t) = \exp(-H(t))$ can be substituted into the equation, leading to:

$$\Pr(t < T~|~t_0 < T) = \frac{e^{-H(t)}}{e^{-H(t_0)}} = e^{-H(t) - (- H(t_0))} = e^{-H(t) + H(t_0)} = e^{-(H(t) - H(t_0))} = e^{H(t_0) - H(t)}$$

Take some time to think about the equation; it makes sense. If, for instance, the cumulative hazard becomes constant because the hazard rate is zero, the exponent becomes zero and the conditional survival probability goes towards 1.0. Conversely, if the hazard rate is very large, the exponent goes towards negative infinity and the conditional survival probability approaches zero. Here, too, there is a hard limit on how extreme the hazard rate can become, but it is exponentially larger than when using the survival function directly.
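
A quick numerical check of the two formulations in 32-bit precision (a sketch of my own, with made-up values for $b$, $t_0$ and $t$): once $S$ underflows, the ratio $S(t)/S(t_0)$ degenerates to 0/0, while $e^{H(t_0) - H(t)}$ still returns the correct conditional survival.


b_, t0, t1 = 3.0, 40.0, 41.0
S0, S1 = jnp.exp(-b_ * t0), jnp.exp(-b_ * t1)   # both underflow to 0.0 in float32
H0, H1 = b_ * t0, b_ * t1                       # 120.0 and 123.0, exactly representable
print(S1 / S0)                                  # nan, because 0.0 / 0.0
print(jnp.exp(H0 - H1))                         # exp(-3) ≈ 0.0498, the correct value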


def conditional_binomial_model(lethality, n_trials):
    batch_size, observation_times = lethality.shape 

    # linear hazard model to artificially produce high hazards
    b = numpyro.sample("b", numpyro.distributions.HalfNormal(scale=np.repeat([10.0],repeats=10)))
    atol = 1e-5
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-4))(b)

    # Survival probability
    S = sol.ys[:,:,1]
    S_ = clip(jnp.column_stack([S, jnp.zeros_like(S[:,[0]])]), atol*10)

    # conditional survival probability S(t) / S(t-1)
    S_cond = clip(S_[:, 1:] / S_[:, :-1],atol*10)

    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])

    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs

    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    
    # here we can exploit the fact that, due to our transformations, the
    # observations are conditionally independent given that the organisms
    # have survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts


b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)

# allow 5% deviation from any true probability value
np.testing.assert_allclose(b_est, b, atol=0.05)
hazard rates:
 [0.10351625 0.10856315 0.10335957 0.08778182 0.10947605 0.10456487
 0.09477208 0.10598335 0.103713   0.10298501]
lethality:
 [[  95   91   78   74   70   71   46   40   37   35  363]
 [ 169  117  136  114   90   83   85   55   74   70  507]
 [ 188  172  181  163  117  119  115  103   88   85  669]
 [ 210  181  174  178  155  149  129  106   99   99 1020]
 [ 329  279  235  232  198  182  159  165  120  107  994]
 [ 353  311  271  273  262  191  169  163  139  129 1239]
 [ 369  344  286  250  242  243  207  175  202  140 1542]
 [ 441  420  358  316  301  266  255  232  194  152 1565]
 [ 492  450  394  376  300  285  251  249  214  190 1799]
 [ 524  443  448  327  360  317  305  281  235  245 2015]]


sample: 100%|██████████| 1500/1500 [00:10<00:00, 143.48it/s, 3 steps of size 7.37e-01. acc. prob=0.86]


estimates b: [0.10272734 0.10828657 0.10805375 0.0894981  0.11091767 0.1046254
 0.09517204 0.10565186 0.10257991 0.09936014]

The above function works well for moderate hazards and delivers unbiased estimates. Clipping gets rid of the most easily diagnosable evils, namely probabilities above 1 and below 0. These lead to non-finite values in the likelihood function and are therefore observable even when simply running the model forward, i.e. computing the likelihood. So the model does not throw errors, but as we saw above, the conditional probabilities are still quite mad. If we did not clip the results, the model would not compute at all. Nevertheless, the model is still severely wrong, as we will see when we confront it with more extreme survival data.
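
To make these evils concrete, here is a small check of my own (with made-up numbers): a probability that is even marginally negative, as the raw solver output can be, turns the Binomial log-probability into NaN, a probability of exactly zero with surviving organisms turns it into negative infinity, and the clipped value stays finite (if biased).


p_neg = jnp.float32(-1e-6)    # raw solver output can dip slightly below zero
p_zero = jnp.float32(0.0)     # or underflow to exactly zero
print(dist.Binomial(total_count=100, probs=p_neg).log_prob(5))               # nan
print(dist.Binomial(total_count=100, probs=p_zero).log_prob(5))              # -inf
print(dist.Binomial(total_count=100, probs=clip(p_neg, 1e-5)).log_prob(5))   # finite, using the clip helper from above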


b, probs, lethality = generate_lethality_from_parametric_function(b_mean=3)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)

try:
    mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
except RuntimeError as err:
    print(err.args[0].split("\n")[0])

hazard rates:
 [2.89542491 3.09905782 2.89982806 2.98213132 3.07464539 2.78416219
 3.21069122 2.86229079 2.75011732 3.02319897]
lethality:
 [[ 937   58    5    0    0    0    0    0    0    0    0]
 [1443   53    4    0    0    0    0    0    0    0    0]
 [1893  102    4    1    0    0    0    0    0    0    0]
 [2389  105    6    0    0    0    0    0    0    0    0]
 [2868  129    3    0    0    0    0    0    0    0    0]
 [3259  223   16    1    1    0    0    0    0    0    0]
 [3848  150    2    0    0    0    0    0    0    0    0]
 [4210  267   22    1    0    0    0    0    0    0    0]
 [4683  299   17    1    0    0    0    0    0    0    0]
 [5223  263   12    2    0    0    0    0    0    0    0]]


  0%|          | 0/1500 [00:06<?, ?it/s]

INTERNAL: Generated function failed: CpuCallback error: EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

The solver breaks down because it runs out of steps, supposedly because there are weird gradients. Let's check the log-likelihood and its gradient directly:

from functools import partial
loglik = lambda x: numpyro.infer.log_likelihood(partial(conditional_binomial_model, lethality, n_trials), x)["lethality"].sum()

theta = {"b": jnp.expand_dims(jnp.linspace(8.106, 8.11, 10), axis=0)}
print("Log likelihood:", loglik(theta))
print("Grad Log-likelihood:", jax.grad(loglik)(theta))
Log likelihood: -19883.049
Grad Log-likelihood: {'b': Array([[ -16682.068  ,  -17545.986  ,  -39604.836  ,  -50014.46   ,
         -77475.55   , -182425.44   , -195783.94   ,    -289.02518,
              0.     ,       0.     ]], dtype=float32)}

Interestingly, the gradients and the log-likelihoods look okay. The only odd thing is that the gradients go to zero at large values of $b$: there is a narrow parameter range in which a very sharp transition occurs, presumably where the solved survival values fall below the clipping floor, so that the conditional probabilities become clipped constants and the likelihood goes flat.

The conditional binomial hazard model

Next, I’m looking at the binomial hazard model that I had previously identified as a numerically more stable alternative to using the survival probabilities directly in the computation of the conditional probabilities.

$$\Pr(t < T~|~t_0 < T) = \frac{S(t)}{S(t_0)} = e^{H(t_0) - H(t)}$$

def conditional_binomial_hazard_model(lethality, n_trials):
    batch_size, observation_times = lethality.shape 

    # linear hazard model to artificially produce high hazards
    b = numpyro.sample("b", numpyro.distributions.HalfNormal(scale=np.repeat([10.0],repeats=10)))
    atol = 1e-6
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-3))(b)

    # Survival probability
    H = sol.ys[:,:,0]
    H = jnp.column_stack([H, jnp.full_like(H[:, 0], jnp.inf)])

    # conditional survival probability S(t) / S(t-1)
    S_cond = jnp.exp(- (H[:, 1:] - H[:, :-1]))

    # TODO: This may not be necessary, because the final observation is not
    # connected to any free variable. It could also be done with the last
    # modeled survival probability, in which case the conditional probability
    # becomes 0 / S(t[-1]) = 0
    # S_cond = jnp.clip(jnp.column_stack([S_cond, jnp.zeros(batch_size)]), 1e-20, 1-1e-20)

    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])

    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs

    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    
    # lp = dist.Binomial(total_count=S_before_t, probs=S_cond).log_prob(S_obs[:, 1:-1]).sum()
    # jax.debug.print("{x}", x=lp)
    # here we can exploit the fact that, due to our transformations, the
    # observations are conditionally independent given that the organisms
    # have survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts



b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)

# allow 5% deviation from any true probability value
np.testing.assert_allclose(b_est, b, atol=0.05)

hazard rates:
 [0.09530118 0.09322559 0.10139153 0.09713273 0.1154754  0.10000202
 0.10329215 0.10998812 0.09703722 0.11545073]
lethality:
 [[  91   86   83   65   62   55   59   44   47   28  380]
 [ 126  124  117  106   85   78   62   69   67   50  616]
 [ 203  147  155  146  130  117  116   83   83   78  742]
 [ 231  218  192  163  165  150  131  134   95   95  926]
 [ 335  337  279  221  215  163  158  141  126  125  900]
 [ 333  298  286  262  214  191  201  157  143  148 1267]
 [ 367  337  327  288  258  225  201  188  202  164 1443]
 [ 472  447  377  356  308  254  234  212  184  167 1489]
 [ 481  440  359  338  300  299  264  220  216  189 1894]
 [ 588  549  490  419  371  317  287  300  216  187 1776]]


sample: 100%|██████████| 1500/1500 [00:11<00:00, 133.45it/s, 3 steps of size 7.16e-01. acc. prob=0.88]


estimates b: [0.09747483 0.08992738 0.09922651 0.09927562 0.12132029 0.10146473
 0.10105705 0.11152633 0.09744926 0.11399455]

The conditional binomial hazard model is also an unbiased estimator of the parameters. One open question is whether it is necessary to include the remaining survival (the interval from the last observation to infinity) in the likelihood function.

b, probs, lethality = generate_lethality_from_parametric_function(b_mean=2)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model, init_strategy=init.init_to_value(values={"b": np.repeat([2], 10)}))
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
try:
    mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)
except RuntimeError as err:
    print(err.args[0].split("\n")[0])
hazard rates:
 [1.7656863  2.11675695 2.41223889 2.25432954 2.20837715 2.00403966
 2.20672677 1.81581408 2.15598705 1.98262622]
lethality:
 [[ 841  134   20    5    0    0    0    0    0    0    0]
 [1316  168   16    0    0    0    0    0    0    0    0]
 [1820  160   18    2    0    0    0    0    0    0    0]
 [2230  242   26    2    0    0    0    0    0    0    0]
 [2653  310   33    3    1    0    0    0    0    0    0]
 [3012  436   44    5    2    1    0    0    0    0    0]
 [3534  423   36    7    0    0    0    0    0    0    0]
 [3730  641  111   18    0    0    0    0    0    0    0]
 [4414  511   65    9    1    0    0    0    0    0    0]
 [4726  669   90   12    3    0    0    0    0    0    0]]


  0%|          | 0/1500 [00:06<?, ?it/s]

INTERNAL: Generated function failed: CpuCallback error: EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

Difficulties of the solver to initialize

The model can also not be computed. After intense debugging (debug print statements after drawing samples and after computing likelihoods were the key), I realized that under strong declines of the survival function (extreme hazards) the solver breaks down; this happens at hazard rates beyond roughly 1000.

Apparently, just after initialization the sampler explores extreme parameters. This is probably driven by the gradients, which in turn are driven by the survival observations, because it does not happen under well-distributed survival observations. The problem cannot be solved by a different post-processing of the state variables, because it arises from the little information contained in the data.

If extreme values of $b$ can also reproduce the data, then the sampler should and will explore them. Only setting hard bounds on $b$ prevents the sampler from running into extreme regions that the solver cannot handle.

from functools import partial
loglik = lambda x: numpyro.infer.log_likelihood(partial(conditional_binomial_hazard_model, lethality, n_trials), x)["lethality"].sum()

theta = {"b": jnp.expand_dims(jnp.linspace(0,100, 10), axis=0)}
print("Log likelihood:", loglik(theta))
print("Grad Log-likelihood:", jax.grad(loglik)(theta))


try:
    theta = {"b": jnp.expand_dims(jnp.linspace(0,10000, 10), axis=0)}
    print("Log likelihood:", loglik(theta))
    print("Grad Log-likelihood:", jax.grad(loglik)(theta))

except RuntimeError as err:
    print("Computation for extreme values of 'b' > 100")
    print(err.args[0].split("\n")[0])
Log likelihood: -309155.6
Grad Log-likelihood: {'b': Array([[   0.     , -199.97758, -202.     , -300.00006, -388.99997,
        -551.99994, -516.00006, -917.0001 ,        nan,        nan]],      dtype=float32)}
Computation for extreme values of 'b' > 100
The maximum number of solver steps was reached. Try increasing `max_steps`.

So, in order to avoid extreme parameter samples, I'm using a TruncatedCauchy prior with hard bounds. In addition, I avoid the computation of the probability from the last observation to infinity, because the information about the number of organisms still alive is already accounted for by comparing the last survival probability to the data.

def conditional_binomial_hazard_model_trunc_prior(lethality, n_trials):
    batch_size, observation_times = lethality.shape 

    # linear hazard model to artificially produce high hazards
    b = numpyro.sample("b", numpyro.distributions.TruncatedCauchy(scale=np.repeat([1.0],repeats=10), low=0, high=1000))
    atol = 1e-6
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-3))(b)

    # Survival probability
    H = sol.ys[:,:,0]
    H = jnp.column_stack([H, jnp.full_like(H[:, 0], jnp.inf)])

    # conditional survival probability S(t) / S(t-1)
    S_cond = jnp.exp(- (H[:, 1:] - H[:, :-1]))

    # cumulative observed deaths, used to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])

    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs

    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    
    # observations are conditionally independent given survival until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts


b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)

# allow 5% deviation from any true probability value
np.testing.assert_allclose(b_est, b, atol=0.05)
hazard rates:
 [0.10011198 0.08686346 0.10129305 0.10928235 0.09759295 0.10757679
 0.1074746  0.10457167 0.11872323 0.1080997 ]
lethality:
 [[ 104   84   72   79   57   54   57   49   39   43  362]
 [ 126  139   99   90   89   78   75   74   69   52  609]
 [ 177  185  147  180  123  114   89   99   93   79  714]
 [ 260  236  208  201  165  137  140  128  118   80  827]
 [ 285  241  237  202  198  176  150  139  147  100 1125]
 [ 339  348  277  227  228  216  174  178  143  148 1222]
 [ 403  356  293  277  304  228  219  213  150  175 1382]
 [ 454  414  340  342  274  286  220  202  210  174 1584]
 [ 557  518  426  404  340  321  262  249  230  205 1488]
 [ 569  520  452  389  380  302  297  262  238  222 1869]]


sample: 100%|██████████| 1500/1500 [00:11<00:00, 134.33it/s, 3 steps of size 6.81e-01. acc. prob=0.88]

estimates b: [0.10174821 0.09040698 0.10257039 0.11098085 0.09811853 0.10506512
 0.10538018 0.10456848 0.12049202 0.10804637]

init.init_to_sample (i.e. numpyro.infer.initialization.init_to_sample) is the right way to initialize this model. init_to_uniform places initial values anywhere within a fixed radius on the unconstrained scale, including values that the solver can't handle.
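
A rough illustration of why the init strategy matters (a sketch of my own, not part of the original analysis): numpyro draws initial values on the unconstrained scale and maps them back through the prior's bijector, so for the TruncatedCauchy prior bounded at 1000 even a uniform draw within the default radius of 2 lands at hazard rates of several hundred, whereas init_to_sample starts from draws of the prior itself.


from numpyro.distributions.transforms import biject_to

prior = dist.TruncatedCauchy(scale=1.0, low=0.0, high=1000.0)
to_b = biject_to(prior.support)    # maps the unconstrained real line to (0, 1000)
for u in [-2.0, 0.0, 2.0]:         # the range probed by init_to_uniform(radius=2)
    print(u, "->", float(to_b(jnp.array(u))))   # ≈ 119, 500 and 881 — all extreme hazard rates for this problem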

b, probs, lethality = generate_lethality_from_parametric_function(b_mean=5)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)

# allow 5% deviation from any true probability value
np.testing.assert_allclose(b_est, b, atol=0.1, rtol=0.1)
hazard rates:
 [5.29213585 5.37714278 5.2645693  4.67145342 5.7766158  5.01292862
 5.8152567  5.7321696  4.39097979 4.52223617]
lethality:
 [[ 991    9    0    0    0    0    0    0    0    0    0]
 [1491    9    0    0    0    0    0    0    0    0    0]
 [1989   11    0    0    0    0    0    0    0    0    0]
 [2479   20    1    0    0    0    0    0    0    0    0]
 [2994    6    0    0    0    0    0    0    0    0    0]
 [3475   25    0    0    0    0    0    0    0    0    0]
 [3982   18    0    0    0    0    0    0    0    0    0]
 [4487   12    1    0    0    0    0    0    0    0    0]
 [4925   75    0    0    0    0    0    0    0    0    0]
 [5440   58    2    0    0    0    0    0    0    0    0]]


sample: 100%|██████████| 1500/1500 [00:13<00:00, 111.99it/s, 3 steps of size 6.67e-01. acc. prob=0.89]

estimates b: [4.7346883 5.1417646 5.2161317 4.748071  6.2511463 4.9498925 5.4192424
 5.792143  4.2159777 4.496989 ]

And finally we can sample fast 🚀 even from extreme survival data 🎉. The sampler still slows down when the lethality data become more extreme, but it handles them well. Despite setting throw_exception=True, no runtime errors are raised by the solver. The truncated prior together with init_to_sample seems to improve the situation quite a bit.

b, probs, lethality = generate_lethality_from_parametric_function(b_mean=10)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)

# allow 5% deviation from any true probability value
try:
    np.testing.assert_allclose(b_est, b, atol=0.1, rtol=0.1)
except AssertionError:
    print("The mean model parameter estimates diverge from the true values, because the uncertainty is very large")

print("But, HDI estimates contain the true value 🎉")
idata = az.convert_to_inference_data(
    {"b":np.expand_dims(samples_conditional_binomial["b"], axis=0)},
    coords={"sample":range(10)},
    dims={"b":["chain", "draw", "sample"]}
)
az.plot_forest(idata, hdi_prob=0.94)
_ = plt.plot(b, 22.4-np.arange(0,10)* 2.5, marker="o", ls="", color="black", ms=3)
hazard rates:
 [10.37056053  8.4157622   8.51890879 12.41824709  8.38021188 10.59177766
  9.18675297  9.26523369  8.20162744  9.13863687]
lethality:
 [[1000    0    0    0    0    0    0    0    0    0    0]
 [1500    0    0    0    0    0    0    0    0    0    0]
 [2000    0    0    0    0    0    0    0    0    0    0]
 [2500    0    0    0    0    0    0    0    0    0    0]
 [2999    1    0    0    0    0    0    0    0    0    0]
 [3499    1    0    0    0    0    0    0    0    0    0]
 [4000    0    0    0    0    0    0    0    0    0    0]
 [4500    0    0    0    0    0    0    0    0    0    0]
 [4997    3    0    0    0    0    0    0    0    0    0]
 [5498    2    0    0    0    0    0    0    0    0    0]]


sample: 100%|██████████| 1500/1500 [04:42<00:00,  5.32it/s, 31 steps of size 1.08e-01. acc. prob=0.94]


estimates b: [35.194286  36.551746  30.150806  42.831306   8.2548485  8.377418
 83.403984  42.569454   7.504727   8.017168 ]
The mean model parameter estimates diverge from the true values, because the uncertainty is very large
But, HDI estimates contain the true value 🎉

png

Slow sampling on such survival data is to be expected, because here any parameter value > 10 produces essentially the same results. This makes for a difficult posterior geometry and for parameter values that probably lead to a higher number of solver steps. Still, the probability model handles it well.

See also how the uncertainty intervals become much smaller as soon as even a single organism survives the first interval. The other intervals are much wider, because the data contain no information about the slope of the survival curve beyond the fact that, in a binomial world, any slope steep enough to produce 100% mortality within the first day is equally plausible.

Old conditional binomial model did not work well with wider priors


def conditional_binomial_model_trunc_cauchy(lethality, n_trials):
    batch_size, observation_times = lethality.shape 

    # linear hazard model to artificially produce high hazards
    b = numpyro.sample("b", numpyro.distributions.TruncatedCauchy(scale=np.repeat([1.0],repeats=10), low=0, high=1000))
    atol = 1e-5
    sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-4))(b)

    # Survival probability
    S = sol.ys[:,:,1]
    S_ = clip(jnp.column_stack([S, jnp.zeros_like(S[:,[0]])]), atol*10)

    # conditional survival probability S(t) / S(t-1)
    S_cond = clip(S_[:, 1:] / S_[:, :-1],atol*10)

    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])

    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs

    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    
    # here we can exploit the fact that, due to our transformations, the
    # observations are conditionally independent given that the organisms
    # have survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts


b, probs, lethality = generate_lethality_from_parametric_function(b_mean=0.1)

# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model_trunc_cauchy)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)

# allow 10% deviation from any true probability value
try:
    np.testing.assert_allclose(b_est, b, atol=0.1, rtol=0.1)
except AssertionError:
    print("The mean model parameter estimates diverge from the true values.")

idata = az.convert_to_inference_data(
    {"b":np.expand_dims(samples_conditional_binomial["b"], axis=0)},
    coords={"sample":range(10)},
    dims={"b":["chain", "draw", "sample"]}
)
print(az.hdi(idata.posterior.b.values))
az.plot_forest(idata, hdi_prob=0.94)
_ = plt.plot(b, 22.4-np.arange(0,10)* 2.5, marker="o", ls="", color="black", ms=3)
hazard rates:
 [0.09596538 0.09692836 0.09255667 0.08576109 0.09105337 0.10334469
 0.08884172 0.09372573 0.11784689 0.11192419]
lethality:
 [[ 101   81   81   57   62   53   45   44   38   40  398]
 [ 143  136  127  114   85   81   74   74   73   58  535]
 [ 171  147  143  114  134  110  107   98   95   78  803]
 [ 222  142  194  167  139  122  125  117   96   88 1088]
 [ 267  245  216  202  180  179  136  128  122  104 1221]
 [ 354  313  318  226  211  195  196  173  155  118 1241]
 [ 343  305  288  249  233  226  205  213  138  160 1640]
 [ 354  364  350  313  313  255  240  206  208  188 1709]
 [ 554  490  434  395  354  289  257  243  221  173 1590]
 [ 619  502  480  403  335  330  296  260  261  214 1800]]


sample: 100%|██████████| 1500/1500 [37:37<00:00,  1.51s/it, 229 steps of size 8.90e-03. acc. prob=0.90]  


estimates b: [ 0.09291796  0.10308105  0.0903046   0.08379658 26.807642    0.1040244
  0.08909638 83.624176   18.493137   19.66022   ]
The mean model parameter estimates diverge from the true values.
[[8.56394246e-02 1.00866236e-01]
 [9.69580039e-02 1.09456278e-01]
 [8.59246626e-02 9.49095562e-02]
 [7.95475245e-02 8.80869478e-02]
 [8.13150120e+00 6.32290077e+01]
 [1.00612350e-01 1.07925475e-01]
 [8.56746361e-02 9.24930796e-02]
 [8.14574051e+00 3.29237640e+02]
 [8.12389755e+00 4.51659050e+01]
 [8.14356709e+00 5.07798958e+01]]

png

This behavior reminds me strongly of the previous numpyro runs. Although the problem is simple ($b=0.1$), the algorithm takes a very long time. This is most likely caused by NaN gradients that the sampler sporadically encounters due to the underflow problems. It even reproduces runs that seemed to converge fine and then suddenly drifted off into a local minimum that is complete nonsense.
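
A probe one could use to hunt for such sporadic NaN gradients (my own addition, with arbitrarily chosen hazard levels): evaluate the gradient of the clipped model's log-likelihood at a few values of $b$ and count non-finite entries.


from functools import partial

loglik_clip = lambda x: numpyro.infer.log_likelihood(partial(conditional_binomial_model_trunc_cauchy, lethality, n_trials), x)["lethality"].sum()

for b_val in [0.1, 1.0, 10.0]:
    grad = jax.grad(loglik_clip)({"b": jnp.full((1, 10), b_val)})["b"]
    print(b_val, "non-finite gradient entries:", int((~jnp.isfinite(grad)).sum()))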

In the next part of the series, Conditional survival part 3: Missing values, we will look at missing values. So far, our models have only been confronted with data that had observations at every time point. What happens when this is not the case?