In the last part, Conditional survival part 2: Survival probabilities very close to zero, I dealt with boundary problems of extreme survival observations: they push the survival function toward the machine's precision limit, and the resulting underflow makes the computation of the likelihood problematic. Using the cumulative hazard directly in the computation of the conditional probability resolves the issue and simplifies the model. Below are the functions defined in part 1 and part 2.
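As a brief recap of parts 1 and 2, the conditional survival between two observation times can be written entirely in terms of differences of the cumulative hazard $H(t) = -\ln S(t)$, so no near-zero survival probabilities ever need to be divided:
$$S(t_2 \mid t_1) = \frac{S(t_2)}{S(t_1)} = \exp\big(-(H(t_2) - H(t_1))\big)$$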
import jax
import numpy as np
import arviz as az
from matplotlib import pyplot as plt
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from diffrax import diffeqsolve, ODETerm, PIDController, SaveAt, Tsit5, Euler
from numpyro.infer import initialization as init
rng = np.random.default_rng(seed=1)
# generate survival data
rng = np.random.default_rng(1)
t = np.arange(0,11)
n_trials = np.array([10, 15, 20, 25, 30, 35, 40, 45, 50, 55]) * 1000
def generate_lethality_from_parametric_function(b_mean):
"""Large b (> 1) generate steep survival functions, and small b (< 1)
generate, moderately sloped survival functions
"""
b = rng.lognormal(np.log(b_mean), 0.1, size=(10))
print("hazard rates:\n", b)
probs = -1 * np.diff(np.exp(- (np.einsum("t,k->kt", t, b))))
probs = np.column_stack([probs, 1 - probs.sum(axis=1)])
lethality = np.array(list(map(lambda n, p: rng.multinomial(n=n, pvals=p), n_trials, probs)))
print("lethality:\n", lethality)
return b, probs, lethality
def ode(t, y, params):
H, S = y
b, = params
# constant hazard
h = b
dH_dt = h
dS_dt = -h * S
return jnp.array([dH_dt, dS_dt])
def solve(t, b, atol, rtol):
sol = diffeqsolve(
terms=ODETerm(ode), # type: ignore
solver=Tsit5(),
t0=t.min(),
t1=t.max(),
dt0=0.1, # type: ignore
y0=jnp.array([0.0, 1.0], ndmin=1), # type: ignore
saveat=SaveAt(ts=jnp.array(t)), # type: ignore
args=(b,), # type: ignore
# max_steps=int(1e7),
stepsize_controller=PIDController(rtol=rtol, atol=atol), # type: ignore
)
return sol
Missing values
Missing values are a more intricate problem. First I'm generating a boolean mask of the size of the survival array. Why not the lethality array? Because it does not make sense to mask the incidence. In the reality of the laboratory day, we record the surviving organisms; if no count is taken on a given day (and that record is discarded), we simply do not know how many organisms were alive on that day. Of course we could mask lethality observations by accumulating the incidence over the masked days, but again, using survival is the more straightforward choice.
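To illustrate the point, here is a minimal sketch with hypothetical daily counts: masking lethality would require folding the unobserved deaths into the next recorded interval, whereas masking survival simply leaves a gap in the cumulative counts.
# hypothetical illustration: daily death counts with day 2 not recorded
deaths = np.array([5, 3, 2, 1])                 # deaths per interval
observed = np.array([True, True, False, True])  # day 2 is missing
# masking lethality would require reporting 2 + 1 = 3 deaths for the last
# observed interval; masking survival just leaves a NaN in the counts
survival_counts = 100 - deaths.cumsum()         # [95, 92, 90, 89]
survival_masked = np.where(observed, survival_counts, np.nan)
print(survival_masked)                          # [95. 92. nan 89.]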
b, probs, lethality = generate_lethality_from_parametric_function(b_mean=3)
def survival_from_lethality(lethality):
"""assumes that the death counts stem from the intervals.
Parameters
----------
lethality : NDarray[I, T, int]
Death counts from T intervals from T+1 timepoints where t=0 is the number
of organisms at the start of the experiment
Returns
-------
out : NDarray[I, T+1, int]
"""
batch_size, observation_times = lethality.shape
L_cumulative_obs = np.column_stack([np.zeros(batch_size, dtype=int), lethality.cumsum(axis=1)])
survival = n_trials.reshape((batch_size, 1)) - L_cumulative_obs
return survival.astype(float)
def preprocess_lethality(survival, mask):
# in order to get our naive model working, we need to fill forward the S_before_t observations
# first prepend survival at t=-1, in order to carry survival at t=0 forward
# survival = np.column_stack([survival[:, 0], survival])
# trials start at t=-1 (in order to include the "observation" at t=0)
# the first observation will contribute a log prob of zero to the overall
# likelihood, but it completes the observation array, so we keep it
trials = survival[:, :-1]
idx = np.where(mask[:, :-1].astype(bool),np.arange(mask[:, :-1].shape[1]),0)
np.maximum.accumulate(idx,axis=1, out=idx)
trials_ = trials[np.arange(idx.shape[0])[:,None], idx]
    # duplicate the first entry (the starting number of survivors) so that
    # trials has the same shape as the survival array (which includes t=0)
trials_ = np.column_stack([trials_[:, 0], trials_])
return survival, trials_.astype(int), mask
survival = survival_from_lethality(lethality)
# artificially generate masked observations
rng.bit_generator.state = np.random.PCG64(seed=12).state
mask = rng.binomial(n=1, p=0.8, size=survival.shape)
mask[:, 0] = 1
survival, trials, mask = preprocess_lethality(survival, mask)
survival = np.where(mask, survival, np.nan)
print("Comparing survival and trials\n", "Trials: ", trials[1], "\n Survival: ",survival[1])
print("Comparing survival and trials\n", "Trials: ", trials[3], "\n Survival: ",survival[3])
print("Mask:\n", mask)
np.testing.assert_array_equal(survival.shape, mask.shape)
np.testing.assert_array_equal(trials.shape, mask.shape)
hazard rates:
[3.1054875 3.2568944 3.10078714 2.63345472 3.2842814 3.1369461
2.84316243 3.17950047 3.11138986 3.08955028]
lethality:
[[ 9573 410 16 1 0 0 0 0 0 0 0]
[14399 579 22 0 0 0 0 0 0 0 0]
[19123 836 36 4 1 0 0 0 0 0 0]
[23166 1717 109 8 0 0 0 0 0 0 0]
[28871 1091 37 1 0 0 0 0 0 0 0]
[33479 1461 56 4 0 0 0 0 0 0 0]
[37677 2200 115 7 1 0 0 0 0 0 0]
[43163 1771 61 4 1 0 0 0 0 0 0]
[47739 2149 109 3 0 0 0 0 0 0 0]
[52526 2344 124 6 0 0 0 0 0 0 0]]
Comparing survival and trials
Trials: [15000 15000 601 22 0 0 0 0 0 0 0 0]
Survival: [15000. 601. 22. 0. 0. nan 0. 0. 0. nan
nan nan]
Comparing survival and trials
Trials: [25000 25000 25000 117 8 0 0 0 0 0 0 0]
Survival: [2.50e+04 nan 1.17e+02 8.00e+00 0.00e+00 0.00e+00 0.00e+00 0.00e+00
nan 0.00e+00 0.00e+00 0.00e+00]
Mask:
[[1 0 1 1 1 1 1 1 0 0 1 1]
[1 1 1 1 1 0 1 1 1 0 0 0]
[1 0 1 0 1 1 1 1 1 1 1 1]
[1 0 1 1 1 1 1 1 0 1 1 1]
[1 1 0 1 1 0 0 1 0 1 0 1]
[1 1 1 1 1 1 1 0 0 1 1 1]
[1 0 1 0 1 1 1 1 1 1 1 1]
[1 0 1 1 1 1 1 0 1 1 1 1]
[1 0 1 1 0 1 1 1 1 1 1 1]
[1 0 1 1 1 1 1 1 1 1 1 1]]
For a naive treatment of the masked data, we create a numpyro model in which the binomial conditional survival distribution is masked.
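As a reminder of what the masking does (a minimal sketch, separate from the model): numpyro's .mask() sets the log probability to zero wherever the mask is False, so masked observations do not contribute to the likelihood.
# sketch: the second entry is masked and contributes a log probability of 0
d = dist.Binomial(total_count=10, probs=0.5).mask(jnp.array([True, False]))
print(d.log_prob(jnp.array([5.0, 3.0])))  # [log P(5 | 10, 0.5), 0.0]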
def conditional_binomial_hazard_model_trunc_prior_masked(survival, trials, mask):
batch_size, observation_times = survival.shape
b = numpyro.sample("b", numpyro.distributions.TruncatedCauchy(scale=np.repeat([1.0],repeats=10), low=0, high=1000))
atol = 1e-6
sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-3))(b)
H = sol.ys[:,:,0]
H = jnp.column_stack([jnp.zeros_like(H[:, 0]), H, jnp.full_like(H[:, 0], jnp.inf)])
S_cond = jnp.exp(- (H[:, 1:] - H[:, :-1]))
# S_before_t = S_obs[:, :-1]
with numpyro.plate("time", observation_times):
with numpyro.plate("batch", batch_size):
# use the .mask() method of the distribution
counts = numpyro.sample("lethality", dist.Binomial(total_count=trials, probs=S_cond).mask(mask), obs=survival)
return counts
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior_masked, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(1), survival=survival, trials=trials, mask=mask)
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow 5% deviation from any true probability value
try:
np.testing.assert_allclose(b_est, b, atol=0.05)
except AssertionError:
print("The mean model parameter estimates diverge from the true values")
sample: 100%|██████████| 1500/1500 [00:18<00:00, 82.32it/s, 7 steps of size 7.34e-01. acc. prob=0.85]
estimates b: [6.3404913 3.220138 6.169644 5.305844 3.3157148 3.1383803 5.78178
6.4429493 6.081432 6.0035124]
The mean model parameter estimates diverge from the true values
idata = az.convert_to_inference_data(
{"b":np.expand_dims(samples_conditional_binomial["b"], axis=0)},
coords={"sample":range(10)},
dims={"b":["chain", "draw", "sample"]}
)
az.plot_forest(idata)
plt.plot(b, 22.4-np.arange(0,10)* 2.5, marker="o", ls="", color="black", ms=3)
(Forest plot of the posterior b estimates; black dots mark the true parameter values.)
So we can very clearly see that the probability model overestimates the $b$ coefficient. This happens when the first observation is masked. The mechanism is that the model uses S(t=2) for the survivors and S(t=0) for the trials (because S(t=1) is masked), but uses S(t=2)/S(t=1) to calculate the conditional probability. This results in a disproportionately large conditional survival probability at t=2, because S(t=1) is much smaller than S(t=0) = 1.0, which would be the correct survival probability to use in the calculation of the conditional. This overly large conditional survival probability at $t=2$ can only be compensated by overestimating the slope of the survival function. Hence we observe overestimated $b$ parameters. When NaNs occur later on, the biases probably cancel each other out; the exact mechanism is unclear.
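A small numeric illustration of this mechanism, with hypothetical counts and the example probabilities introduced in the next section:
# S(0) = 1.0, S(1) = 0.75, S(2) = 0.1, and the observation at t=1 is masked
trials_t2 = 100        # survivors carried forward from t=0 (t=1 is missing)
survivors_t2 = 10      # observed at t=2, i.e. a fraction of S(2)/S(0) = 0.1
p_naive = 0.1 / 0.75   # the naive model still uses S(2)/S(1) ~= 0.133
# the only way to reconcile the observed fraction 10/100 with p_naive is to
# steepen the survival curve, i.e. to overestimate b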
Using the correct conditional probabilities
Imagine we know the survival probabilities $S(t=0) = 1.0$, $S(t=1) = 0.75$, $S(t=2) = 0.1$, $S(t=3) = 0.05$, $S(t=4) = 0.01$, $S(t=\infty) = 0.0$. An exemplary survival matrix (ID x Time) could look like this:
$$S^{obs} =\begin{bmatrix} t=0 & t=1 & t=2 & t=3 & t=4 & t=\infty\\ 100 & 80 & 5 & 5 & 1 & 0\\ 100 & NaN & 20 & 7 & 3 & 0 \\ 100 & NaN & NaN & 2 & 0 & 0 \\ 100 & 75 & NaN & 2 & 2 & 0 \\ 100 & NaN & 10 & NaN & 1 & 0 \\ \end{bmatrix}$$
The conditional survival probability vector for the case without NaN values looks like this:
$$S_c = \begin{bmatrix} t=0 & t=1 & t=2 & t=3 & t=4 & t=\infty\\ 1.0 & \frac{S(1)}{S(0)} & \frac{S(2)}{S(1)} & \frac{S(3)}{S(2)} & \frac{S(4)}{S(3)} & 0.0\\ \end{bmatrix}$$
This design will assume incorrect values for time series that contain NaN values.
The correct conditional probabilities are the fractions of survival probabilities that span the entire interval in which no observations have been made. This can be achieved by removing the NaNs from the calculation. For the above observation matrix the conditional probability matrix looks like this:
$$S_c = \begin{bmatrix} t=0 & t=1 & t=2 & t=3 & t=4 & t=\infty\\ 1.0 & \frac{S(1)}{S(0)} = 0.75 & \frac{S(2)}{S(1)} = 0.133 & \frac{S(3)}{S(2)} = 0.5 & \frac{S(4)}{S(3)} = 0.2 & 0.0\\ 1.0 & NaN & \mathbf{\frac{S(2)}{S(0)} = 0.1} & \frac{S(3)}{S(2)} = 0.5 & \frac{S(4)}{S(3)} = 0.2 & 0.0\\ 1.0 & NaN & NaN & \mathbf{\frac{S(3)}{S(0)} = 0.05} & \frac{S(4)}{S(3)} = 0.2 & 0.0\\ 1.0 & \frac{S(1)}{S(0)} = 0.75 & NaN & \mathbf{\frac{S(3)}{S(1)} = 0.0667} & \frac{S(4)}{S(3)} = 0.2 & 0.0\\ 1.0 & NaN & \mathbf{\frac{S(2)}{S(0)} = 0.1} & NaN & \mathbf{\frac{S(4)}{S(2)} = 0.1} & 0.0\\ \end{bmatrix}$$
Note that $S_{c,1}(2)=\frac{S(2)}{S(0)}$ instead of $\frac{S(2)}{S(1)}$, resulting in a conditional survival probability of 0.1 instead of 0.133. The stronger the survival difference in the preceding NaN interval, the more pronounced the bias. Let's reproduce the above example in Python.
A computationally efficient way to use the correct survival probabilities
In a 1-D case with only a single observation series, the issue would not arise at all, because the NaN entries could simply be dropped. However, vectorization is usually computationally more efficient, so we need to find an efficient way to select the right survival probabilities $S(t)$ in the calculation of the conditional survival probability.
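The core trick, sketched here on a single row for clarity, is to turn the mask into an index of the last observed position and then gather from that index; the options below vectorize this idea over the batch dimension.
# sketch of the last-observed-index gather on a single row
row = np.array([1.0, np.nan, 0.1, np.nan, 0.01])
valid = ~np.isnan(row)
idx = np.where(valid, np.arange(row.size), 0)  # observed positions, 0 otherwise
idx = np.maximum.accumulate(idx)               # carry the last observed index forward
print(row[idx])                                # [1.0, 1.0, 0.1, 0.1, 0.01]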
S = np.array([1.0, 0.75, 0.1, 0.05, 0.01, 0.0])
S_obs = np.array([
[100, 80, 5, 5, 1, 0],
[100, np.nan, 20, 7, 3, 0],
[100, np.nan, np.nan, 2, 0, 0],
[100, 75, np.nan, 2, 2, 0],
[100, np.nan, 10, np.nan, 1, 0],
])
S_c_true = np.array([
[1.0 , 0.75 , 0.133 , 0.5 , 0.2 , 0.0],
[1.0 , np.nan , 0.1 , 0.5 , 0.2 , 0.0],
[1.0 , np.nan , np.nan , 0.05 , 0.2 , 0.0],
[1.0 , 0.75 , np.nan , 0.067 , 0.2 , 0.0],
[1.0 , np.nan , 0.1 , np.nan , 0.1 , 0.0]
])
# Produce a mask to see which indices are NaN
mask_obs = jnp.isnan(S_obs)
# clones the survival probability vector into a matrix, to compute the conditional probabilities for each ID i
S_i = np.tile(S, (len(S_obs),1))
# computation of the conditional probability:
# Option 1 (while loop):
# results in a number of iterations equivalent to the maximum number of consecutive nans.
S_i_corrected = np.where(mask_obs, np.nan, S_i)
while np.isnan(S_i_corrected).sum() > 0:
S_i_corrected[:, 1:] = np.where(mask_obs[:, 1:], S_i_corrected[:,:-1], S_i_corrected[:, 1:])
# Option 2a (accumulate):
# Works with plain NumPy, but cannot be used in the JAX model, because jnp
# does not support maximum.accumulate directly
idx = np.where(~mask_obs,np.arange(mask_obs.shape[1]),0)
idx = np.maximum.accumulate(idx,axis=1)
out = S_i[np.arange(idx.shape[0])[:,None], idx]
# Option 2b (manual accumulate)
# accumulate is not implemented in jax
def max1(a,b):
return (a + b + abs(a - b)) / 2
maximum = jnp.frompyfunc(max1, nin=2, nout=1)
idx = np.where(~mask_obs,np.arange(mask_obs.shape[1]),0)
idx = maximum.accumulate(jnp.array(idx), axis=1).astype(int)
out = S_i[np.arange(idx.shape[0])[:,None], idx]
# Option 2c (accumulate with jnp.maximum)
maximum = jnp.frompyfunc(jnp.maximum, nin=2, nout=1, identity=None)
idx = np.where(~mask_obs,np.arange(mask_obs.shape[1]),0)
idx_ = maximum.accumulate(jnp.array(idx), axis=1)
out = S_i[np.arange(idx.shape[0])[:,None], idx_]
Next we test the performance of the forward filling method in terms of computation time.
maximum = jnp.frompyfunc(jnp.maximum, nin=2, nout=1, identity=None)
# maximum = jnp.frompyfunc(max1, nin=2, nout=1)
@jax.jit
def ffill_na(x, mask):
"""Forward-fill nan values in x. If a mask is provided, assume
Parameters
----------
x : _type_
_description_
mask : _type_
_description_
Returns
-------
_type_
_description_
"""
if mask is None:
mask = jnp.logical_not(jnp.isnan(x))
idx = jnp.where(mask,jnp.arange(mask.shape[1]),0)
idx_ = maximum.accumulate(jnp.array(idx), axis=1).astype(int)
return x[jnp.arange(idx.shape[0])[:,None], idx_]
mask_obs_ = jnp.logical_not(mask_obs)
print("False indicates a missing value: \n", mask_obs_)
print(ffill_na(S_i, mask_obs_))
import time
start = time.time()
for i in range(1000):
ffill_na(S_i, mask_obs_)
stop = time.time()
print(round(stop-start, 4), "ms per iteration")
False indicates a missing value:
[[ True True True True True True]
[ True False True True True True]
[ True False False True True True]
[ True True False True True True]
[ True False True False True True]]
[[1. 0.75 0.1 0.05 0.01 0. ]
[1. 1. 0.1 0.05 0.01 0. ]
[1. 1. 1. 0.05 0.01 0. ]
[1. 0.75 0.75 0.05 0.01 0. ]
[1. 1. 0.1 0.1 0.01 0. ]]
0.0162 ms per iteration
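One caveat on such micro-benchmarks: JAX dispatches jitted computations asynchronously, so a stricter measurement would block on the device result. A sketch of the adjusted loop (the numbers reported above were taken without it):
# block on the device result so asynchronous dispatch does not skew the timing
start = time.time()
for i in range(1000):
    ffill_na(S_i, mask_obs_).block_until_ready()
stop = time.time()
print(round(stop - start, 4), "ms per iteration")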
Both implementations, with max1 and with jnp.maximum, are equally fast at roughly 0.006 ms per iteration, so the choice between them does not contribute a significant speedup. The while loop is more difficult to implement, because it involves a call to the jax.lax primitive while_loop.
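For completeness, here is a rough sketch of how such a variant could look; ffill_na_while is a hypothetical name, and the sketch assumes, as everywhere in this post, that the first column is always observed:
def ffill_na_while(x, mask):
    # treat entries with mask == False as missing; if the first column were
    # also missing, the loop below would never terminate
    x0 = jnp.where(mask, x, jnp.nan)
    def cond(state):
        # keep looping while any masked entry is still unfilled
        return jnp.isnan(state).sum() > 0
    def body(state):
        # shift each row one step to the right and fill the missing entries
        shifted = jnp.column_stack([state[:, :1], state[:, :-1]])
        return jnp.where(jnp.isnan(state), shifted, state)
    return jax.lax.while_loop(cond, body, x0)
print(ffill_na_while(jnp.array(S_i), mask_obs_))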
Calculation of conditional probabilities
Having established the forward filling of survival probabilities, we test whether the concept also applies to cumulative hazards.
import warnings
# test if the normal calculation works out
S_c = jnp.round(S_i[:, 1:] / ffill_na(S_i, ~mask_obs)[:, :-1], 3)
S_c = np.where(mask_obs[:,1:], np.nan, S_c)
np.testing.assert_almost_equal(S_c, S_c_true[:, 1:])
with warnings.catch_warnings(action="ignore"):
    # the runtime warning can be ignored because S(t=infty) = 0 -> -infty is correct
H_i = -np.log(S_i)
H_im1 = ffill_na(H_i, ~mask_obs)
# version a (- (H_1 - H_0)):
S_c_hazard = jnp.round(jnp.exp(- (H_i[:, 1:] - H_im1[:, :-1] )), 3)
S_c_hazard = np.where(mask_obs[:,1:], np.nan, S_c_hazard)
np.testing.assert_almost_equal(S_c_hazard, S_c_true[:, 1:])
# version b (H_0 - H_1):
# slightly more efficient, because two negative ops less
S_c_hazard = jnp.round(jnp.exp(H_im1[:, :-1] - (H_i[:, 1:])), 3)
S_c_hazard = np.where(mask_obs[:,1:], np.nan, S_c_hazard)
np.testing.assert_almost_equal(S_c_hazard, S_c_true[:, 1:])
# version c (H_0 - H_1):
# forward fills directly in the computation
S_c_hazard = jnp.round(jnp.exp(ffill_na(H_i, ~mask_obs)[:, :-1] - (H_i[:, 1:])), 3)
S_c_hazard = np.where(mask_obs[:,1:], np.nan, S_c_hazard)
np.testing.assert_almost_equal(S_c_hazard, S_c_true[:, 1:])
The above shows that both methods, using the hazard or using the survival probabilities directly, yield equivalent and correct conditional survival probabilities.
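The equivalence is simply the hazard identity applied to the last observed time point $t_{\mathrm{prev}}$ before $t$, which is exactly what the forward-filled cumulative hazard provides:
$$\exp\big(H^{\mathrm{ffill}}(t-1) - H(t)\big) = \exp\big(H(t_{\mathrm{prev}}) - H(t)\big) = \frac{S(t)}{S(t_{\mathrm{prev}})} = S_c(t), \qquad H(t) = -\ln S(t)$$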
def conditional_survival_from_hazard(x, mask):
"""Calculates the conditional survival from cumulative hazard values.
    This equation is used when survival is repeatedly observed over time and
    some of the observations are missing (masked).
Parameters
----------
x : np.ndarray[I,T, float]
A 2-dimensional I x T array of cumulative hazards defined as H = -ln(S).
I is the batch dimension and T is the time dimension
mask : np.ndarray[I,T, bool]
A 2-dimensional array of the same shape as x, taking True if the survival
was observed for the given index (i,t) and taking False if survival was
not observed for the given index (i,t).
Returns
-------
out : np.ndarray[I,T, float]
A matrix with conditional probabilities and nans in place where the
mask has nans. Output has the same shape as input.
Example
-------
Calculation example from survival probabilities to conditional survival
probabilities given some masked values.
>>> S_i = np.array([
>>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
>>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
>>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
>>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
>>> [1. , 0.75, 0.1 , 0.05, 0.01, 0. ],
>>> ])
>>> mask_obs = np.array([
>>> [ True, True, True, True, True, True],
>>> [ True, False, True, True, True, True],
>>> [ True, False, False, True, True, True],
>>> [ True, True, False, True, True, True],
>>> [ True, False, True, False, True, True],
>>> ])
>>> conditional_survival_from_hazard(-jnp.log(S_i), mask_obs)
array([
[1.0 0.75 0.133 0.5 0.2 0.0]
[1.0 nan 0.1 0.5 0.2 0.0]
[1.0 nan nan 0.05 0.2 0.0]
[1.0 0.75 nan 0.0667 0.2 0.0]
[1.0 nan 0.1 nan 0.1 0.0]
])
"""
# Append zeros (hazard) to the beginning of the array (this aligns with the
# safe assumption that before the zeroth observation S(t=-1) = 1.0)
x_ = jnp.column_stack([jnp.zeros_like(x[:, 0]), x])
# also mask needs to be expanded accordingly
mask_ = jnp.column_stack([jnp.ones_like(mask[:, 0]), mask])
# fill NaNs with forward
x_filled = ffill_na(x_, mask_)
# calculate the conditional survival.
conditional_survival = jnp.exp(x_filled[:, :-1] - (x_[:, 1:]))
# add nans and return
return jnp.where(
mask, conditional_survival, jnp.nan
)
S_c_no_nans = conditional_survival_from_hazard(-jnp.log(S_i), jnp.ones_like(S_i))
S_c_masked = conditional_survival_from_hazard(-jnp.log(S_i), ~mask_obs)
print(S_c_masked)
np.testing.assert_almost_equal(np.round(S_c_masked,3), S_c_true)
start = time.time()
for i in range(1000):
conditional_survival_from_hazard(-jnp.log(S_i), ~mask_obs)
stop = time.time()
print(round(stop-start, 4), "ms per iteration")
[[1. 0.75 0.13333333 0.5 0.19999999 0. ]
[1. nan 0.09999999 0.5 0.19999999 0. ]
[1. nan nan 0.05 0.19999999 0. ]
[1. 0.75 nan 0.06666666 0.19999999 0. ]
[1. nan 0.09999999 nan 0.09999999 0. ]]
1.7509 ms per iteration
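The gap to the 0.016 ms measured for ffill_na alone is presumably mostly Python dispatch and tracing overhead, since conditional_survival_from_hazard is not jitted here; inside the numpyro model this should not matter, because the NUTS kernel compiles the whole log density anyway. A quick sketch to check this assumption (conditional_survival_jit is just a local name):
# sketch: jit the full conditional-survival computation and time it again
conditional_survival_jit = jax.jit(conditional_survival_from_hazard)
_ = conditional_survival_jit(-jnp.log(S_i), ~mask_obs)  # trigger compilation
start = time.time()
for i in range(1000):
    conditional_survival_jit(-jnp.log(S_i), ~mask_obs).block_until_ready()
stop = time.time()
print(round(stop - start, 4), "ms per iteration")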
def conditional_binomial_hazard_model_trunc_prior_masked_corrected(survival, trials, mask):
batch_size, observation_times = survival.shape
b = numpyro.sample("b", numpyro.distributions.TruncatedCauchy(scale=np.repeat([1.0],repeats=10), low=0, high=1000))
atol = 1e-6
sol = jax.vmap(lambda b: solve(t, b=b, atol=atol, rtol=1e-3))(b)
H = sol.ys[:,:,0]
H = jnp.column_stack([H, jnp.full_like(H[:, 0], jnp.inf)])
S_cond = conditional_survival_from_hazard(H, mask)
with numpyro.plate("time", observation_times):
with numpyro.plate("batch", batch_size):
# use the .mask() method of the distribution
counts = numpyro.sample("lethality", dist.Binomial(total_count=trials, probs=S_cond).mask(mask), obs=survival)
return counts
print(survival, trials, mask)
# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_hazard_model_trunc_prior_masked_corrected, init_strategy=init.init_to_sample)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(1), survival=survival, trials=trials, mask=mask.astype(bool))
# Get samples
samples_conditional_binomial = mcmc.get_samples()
b_est = samples_conditional_binomial["b"].mean(axis=0)
print("estimates b:", b_est)
# allow 5% deviation from any true probability value
try:
np.testing.assert_allclose(b_est, b, atol=0.05)
except AssertionError:
print("The mean model parameter estimates diverge from the true values")
idata = az.convert_to_inference_data(
{"b":np.expand_dims(samples_conditional_binomial["b"], axis=0)},
coords={"sample":range(10)},
dims={"b":["chain", "draw", "sample"]}
)
az.plot_forest(idata, hdi_prob=0.94)
plt.plot(b, 22.4-np.arange(0,10)* 2.5, marker="o", ls="", color="black", ms=3)
[[1.000e+04 nan 1.700e+01 1.000e+00 0.000e+00 0.000e+00 0.000e+00
0.000e+00 nan nan 0.000e+00 0.000e+00]
[1.500e+04 6.010e+02 2.200e+01 0.000e+00 0.000e+00 nan 0.000e+00
0.000e+00 0.000e+00 nan nan nan]
[2.000e+04 nan 4.100e+01 nan 1.000e+00 0.000e+00 0.000e+00
0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]
[2.500e+04 nan 1.170e+02 8.000e+00 0.000e+00 0.000e+00 0.000e+00
0.000e+00 nan 0.000e+00 0.000e+00 0.000e+00]
[3.000e+04 1.129e+03 nan 1.000e+00 0.000e+00 nan nan
0.000e+00 nan 0.000e+00 nan 0.000e+00]
[3.500e+04 1.521e+03 6.000e+01 4.000e+00 0.000e+00 0.000e+00 0.000e+00
nan nan 0.000e+00 0.000e+00 0.000e+00]
[4.000e+04 nan 1.230e+02 nan 1.000e+00 0.000e+00 0.000e+00
0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]
[4.500e+04 nan 6.600e+01 5.000e+00 1.000e+00 0.000e+00 0.000e+00
nan 0.000e+00 0.000e+00 0.000e+00 0.000e+00]
[5.000e+04 nan 1.120e+02 3.000e+00 nan 0.000e+00 0.000e+00
0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]
[5.500e+04 nan 1.300e+02 6.000e+00 0.000e+00 0.000e+00 0.000e+00
0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]] [[10000 10000 10000 17 1 0 0 0 0 0 0 0]
[15000 15000 601 22 0 0 0 0 0 0 0 0]
[20000 20000 20000 41 41 1 0 0 0 0 0 0]
[25000 25000 25000 117 8 0 0 0 0 0 0 0]
[30000 30000 1129 1129 1 0 0 0 0 0 0 0]
[35000 35000 1521 60 4 0 0 0 0 0 0 0]
[40000 40000 40000 123 123 1 0 0 0 0 0 0]
[45000 45000 45000 66 5 1 0 0 0 0 0 0]
[50000 50000 50000 112 3 3 0 0 0 0 0 0]
[55000 55000 55000 130 6 0 0 0 0 0 0 0]] [[1 0 1 1 1 1 1 1 0 0 1 1]
[1 1 1 1 1 0 1 1 1 0 0 0]
[1 0 1 0 1 1 1 1 1 1 1 1]
[1 0 1 1 1 1 1 1 0 1 1 1]
[1 1 0 1 1 0 0 1 0 1 0 1]
[1 1 1 1 1 1 1 0 0 1 1 1]
[1 0 1 0 1 1 1 1 1 1 1 1]
[1 0 1 1 1 1 1 0 1 1 1 1]
[1 0 1 1 0 1 1 1 1 1 1 1]
[1 0 1 1 1 1 1 1 1 1 1 1]]
sample: 100%|██████████| 1500/1500 [00:14<00:00, 104.87it/s, 7 steps of size 6.53e-01. acc. prob=0.88]
estimates b: [3.1883357 3.221586 3.0853176 2.683745 3.2806582 3.1381066 2.8905814
3.2526445 3.0562105 3.0263512]
The mean model parameter estimates diverge from the true values
(Forest plot of the posterior b estimates from the corrected model; black dots mark the true parameter values.)
The coefficients are estimated much better with the improved model. When N is small (n=2500), ID=3 is estimated incorrectly (not shown). This probably has to do with the fact that the 1st observation is censored and the 2nd observation is exceptionally large; such a sample is simply unlikely, hence there is a small bias in the observation. All the more reason to include many observations early on. This suspicion is confirmed when using a larger n: then all true values lie within the 94% HDI, and the IDs where all observations are present (IDs: 1, 4, 5) are hit with high precision and high confidence.
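One way to make the HDI statement explicit is a quick coverage check with ArviZ (a sketch reusing the idata object and the true b values from above):
# sketch: check whether each true b lies inside the 94% HDI of its posterior
hdi = az.hdi(idata, hdi_prob=0.94)["b"].values  # shape (10, 2): lower, upper bound
inside = (hdi[:, 0] <= b) & (b <= hdi[:, 1])
print("true b inside 94% HDI:", inside)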
In the final part of the series, Conditional survival part 4: Writing a probability distribution, all previous insights are combined into a single distribution made available for numpyro and, optionally, for scipy.