In my work I usually rely on probabilistic programming languages to estimate parameters. The archetypal dataset in environmental toxicology is survival data, i.e. counts of surviving or dead organisms over time. This type of data can be modelled statistically with a conditional binomial distribution [@Delignette-Muller.] or a multinomial distribution [@Jager.2018]. To be on the safe side, I always wanted to verify that the two approaches yield equivalent parameter estimates. Here I reproduce this equivalence with numpyro and extend the concept further to handle missing observations and censored values. An additional motivation is that, at the time of writing, numpyro didn’t support batched multinomial experiments.

import jax
import numpy as np
from matplotlib import pyplot as plt
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

rng = np.random.default_rng(seed=1)

To get started we define some observations of survival of a population of individuals over time.

For this we use a matrix of probabilities with one row per experiment: columns 1 and 2 hold the probability of dying on day 1 and day 2, and column 3 holds the probability of dying at any later time (until infinity).

# Define probabilities for the multinomial distribution
probs = np.array([
    [0.2, 0.5, 0.3],
    [0.3, 0.4, 0.3],
    [0.1, 0.7, 0.2],
    [0.25, 0.25, 0.5],
    [0.6, 0.2, 0.2],
    [0.3, 0.3, 0.4],
    [0.4, 0.4, 0.2],
    [0.2, 0.3, 0.5],
    [0.3, 0.3, 0.4],
    [0.1, 0.6, 0.3]
]) 
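As a quick sanity check on such a matrix, each row must sum to 1, and the cumulative sum along the time axis gives the survival function. The sketch below uses a shortened, illustrative copy of the matrix (`example_probs` and `example_survival` are names introduced here, not part of the original code).

```python
import numpy as np

# Illustrative subset of the probability matrix (first two rows)
example_probs = np.array([
    [0.2, 0.5, 0.3],
    [0.3, 0.4, 0.3],
])

# Every row is a full probability distribution over the three intervals
assert np.allclose(example_probs.sum(axis=1), 1.0)

# Probability of still being alive at the end of each interval
example_survival = 1.0 - example_probs.cumsum(axis=1)
print(example_survival)
```

The last column of `example_survival` is zero because every organism eventually dies in the open-ended final interval.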

In addition to the probabilities we need the number of trials in each experiment. To make it easier to recover the parameters accurately, the number of trials is scaled up by a factor of 100, so that the sampled proportions are close to the true probabilities.

n_trials = np.array([10, 15, 20, 25, 30, 35, 40, 45, 50, 55]) * 100
lethality = jnp.array(list(map(lambda n, p: rng.multinomial(n=n, pvals=p), n_trials, probs)))
lethality
Array([[ 197,  522,  281],
       [ 451,  597,  452],
       [ 222, 1398,  380],
       [ 611,  629, 1260],
       [1812,  595,  593],
       [1006, 1053, 1441],
       [1530, 1624,  846],
       [ 859, 1401, 2240],
       [1505, 1493, 2002],
       [ 529, 3275, 1696]], dtype=int32)
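The effect of the scaling can be illustrated with a small sketch: the empirical proportions of a multinomial draw approach the true probabilities as the number of trials grows. The names `rng_demo`, `true_p` and `max_abs_error`, as well as the trial counts, are chosen here for illustration only.

```python
import numpy as np

rng_demo = np.random.default_rng(seed=1)
true_p = np.array([0.2, 0.5, 0.3])

def max_abs_error(n):
    """Largest absolute gap between empirical proportions and true_p for n trials."""
    counts = rng_demo.multinomial(n=n, pvals=true_p)
    return np.abs(counts / n - true_p).max()

# The gap typically shrinks roughly with 1 / sqrt(n)
for n in (10, 1_000, 100_000):
    print(n, max_abs_error(n))
```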

Multinomial survival model

First we test whether the multinomial model can recover the parameters. The probabilities are modelled with a Dirichlet distribution, which satisfies the requirement that the probabilities in the dependent (event) dimension sum to 1. The model is sampled with the NUTS kernel and the posterior means of the probabilities are compared to the true values. As a rule of thumb, the test passes if the absolute deviation is below 0.05.

def multinomial_model(lethality, n_trials):
    batch_size, num_categories = lethality.shape
    probs = numpyro.sample("p", dist.Dirichlet(jnp.ones((batch_size, num_categories))))
    
    with numpyro.plate("batch", batch_size):
        counts = numpyro.sample("lethality", dist.Multinomial(total_count=n_trials, probs=probs), obs=lethality)
    return counts

# Run MCMC sampling on the model
nuts_kernel = NUTS(multinomial_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples = mcmc.get_samples()["p"]

# allow an absolute deviation of 0.05 from any true probability value
np.testing.assert_allclose(samples.mean(axis=0), probs, atol=0.05)
sample: 100%|██████████| 1500/1500 [00:01<00:00, 1087.60it/s, 7 steps of size 5.68e-01. acc. prob=0.90]

This seems to work 🎉, which is not really a big surprise.
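One reason the result is expected: the Dirichlet prior is conjugate to the multinomial likelihood, so with a flat `Dirichlet(1, 1, 1)` prior the posterior of each experiment is `Dirichlet(1 + counts)` and its mean is available in closed form. A minimal sketch of that analytical check, using the first three rows of the simulated counts shown above (`counts_demo`, `probs_demo` and `posterior_mean` are names introduced here):

```python
import numpy as np

counts_demo = np.array([
    [197, 522, 281],
    [451, 597, 452],
    [222, 1398, 380],
])
probs_demo = np.array([
    [0.2, 0.5, 0.3],
    [0.3, 0.4, 0.3],
    [0.1, 0.7, 0.2],
])

# Posterior of a Dirichlet(1, 1, 1) prior after observing multinomial counts
posterior_alpha = 1.0 + counts_demo
posterior_mean = posterior_alpha / posterior_alpha.sum(axis=1, keepdims=True)

# Same tolerance as used for the MCMC estimate
np.testing.assert_allclose(posterior_mean, probs_demo, atol=0.05)
```

The MCMC posterior means should land close to these closed-form values.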

Conditional binomial model

Next we take a look at the conditional binomial model to test the equivalence. This model requires some transformations, but it can be computed quickly.

def conditional_binomial_model(lethality, n_trials):
    batch_size, observation_times = lethality.shape 
    probs = numpyro.sample("p", dist.Dirichlet(jnp.ones((batch_size, observation_times))))
    
    # sum the probabilities along the "time" axis [[0, t1), [t1, t2), [t2, np.inf)]
    probs_sum = probs.cumsum(axis=1)

    # probability of having died at time t [0, t1, t2, inf]
    F = jnp.column_stack([jnp.zeros(batch_size), probs_sum])

    # Survival probability
    S = 1 - F

    # conditional survival probability S(t) / S(t-1)
    S_cond = S[:, 1:] / S[:, :-1]

    # We have to do the same with the data and subtract it from the number of trials
    # to obtain the number of survivors at time t [0, t1, t2, inf]
    L_cumulative_obs = jnp.column_stack([jnp.zeros(batch_size), lethality.cumsum(axis=1)])

    # the number of survivors at time t [0, t1, t2, inf]
    S_obs = n_trials.reshape((batch_size, 1)) - L_cumulative_obs

    # obtain the number of survivors before the start of the next interval [0, t1, t2]
    # also note that the last observation at infinity is only theoretical.
    S_before_t = S_obs[:, :-1]
    
    # here we can exploit the fact that, after all our transformations, the
    # observations are conditionally independent, given that the organisms
    # have survived until the beginning of the respective time window
    with numpyro.plate("time", observation_times):
        with numpyro.plate("batch", batch_size):
            counts = numpyro.sample("lethality", dist.Binomial(total_count=S_before_t, probs=S_cond), obs=S_obs[:, 1:])
    return counts


# Run MCMC sampling on the model
nuts_kernel = NUTS(conditional_binomial_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), lethality=lethality, n_trials=n_trials)

# Get samples
samples_conditional_binomial = mcmc.get_samples()

# allow an absolute deviation of 0.05 from any true probability value
np.testing.assert_allclose(samples_conditional_binomial["p"].mean(axis=0), probs, atol=0.05)
sample: 100%|██████████| 1500/1500 [00:01<00:00, 1101.91it/s, 7 steps of size 5.50e-01. acc. prob=0.90]
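The transformation between interval probabilities and conditional survival probabilities is one-to-one, which is why the same parameter `p` can be estimated from either likelihood. The following sketch, with hypothetical helper names `to_conditional_survival` and `to_interval_probs`, converts one probability vector back and forth:

```python
import numpy as np

def to_conditional_survival(p):
    """Interval death probabilities -> conditional survival per interval."""
    S = np.concatenate([[1.0], 1.0 - np.cumsum(p)])  # survival at 0, t1, t2, inf
    return S[1:] / S[:-1]

def to_interval_probs(s_cond):
    """Conditional survival per interval -> interval death probabilities."""
    # Probability of being alive at the start of each interval
    alive_at_start = np.concatenate([[1.0], np.cumprod(s_cond)[:-1]])
    return alive_at_start * (1.0 - s_cond)

p = np.array([0.2, 0.5, 0.3])
s_cond = to_conditional_survival(p)   # approximately [0.8, 0.375, 0.0]
p_back = to_interval_probs(s_cond)    # recovers p up to rounding
```

The sketch assumes that the survival probability stays above zero before the final interval; otherwise the division is undefined.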

We take away from this that we need the following formula for the observed survivors in the conditional binomial model

$$S^{obs}_i = N - \sum_{j=0}^{i} L^{obs}_j$$

where $N$ is the number of trials, $L^{obs}_j$ denotes the lethality, i.e. the number of deaths within time interval $j$, and $L^{obs}_0 = 0$. The time intervals are $[0, t_1), [t_1, t_2), \dots, [t_n, \infty)$.

def survival(L):
    ...
    # TODO: adapt the function that I had prepared in the bufferguts notebook and
    # sim.py to process survival or lethality data for survival analysis.
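Until that function is adapted, the transformation of the data can be sketched as follows. This is not the function from the bufferguts notebook or sim.py; `survivors_from_lethality` is a hypothetical stand-in that mirrors the steps used inside `conditional_binomial_model`, and it takes the number of trials as an extra argument.

```python
import numpy as np

def survivors_from_lethality(L, n_trials):
    """Return survivors at the start and at the end of each interval.

    L: (batch, intervals) deaths per interval; n_trials: (batch,) initial counts.
    """
    L = np.asarray(L)
    n_trials = np.asarray(n_trials).reshape(-1, 1)
    # cumulative deaths at times [0, t1, t2, inf]
    L_cumulative = np.column_stack([np.zeros(len(L), dtype=L.dtype), L.cumsum(axis=1)])
    S_obs = n_trials - L_cumulative
    return S_obs[:, :-1], S_obs[:, 1:]

at_risk, survived = survivors_from_lethality([[197, 522, 281]], [1000])
# at_risk  -> [[1000, 803, 281]]
# survived -> [[ 803, 281,   0]]
```

`at_risk` plays the role of the binomial `total_count` and `survived` the role of the observations.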

Multinomial and conditional binomial model are equivalent

We have seen that both models recover the parameters equally well. To be certain that both models are identical, we evaluate the log-likelihood of both models on the same posterior samples and indeed see that they are equivalent.
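The reason can be sketched for three intervals. This is a worked factorization in the notation used above, written for a single experiment with $N = L_1 + L_2 + L_3$ trials and interval probabilities $p_1, p_2, p_3$:

$$\frac{N!}{L_1!\,L_2!\,L_3!}\, p_1^{L_1} p_2^{L_2} p_3^{L_3} = \underbrace{\binom{N}{L_1} p_1^{L_1} (1 - p_1)^{N - L_1}}_{\text{interval 1}} \cdot \underbrace{\binom{N - L_1}{L_2} \left(\frac{p_2}{1 - p_1}\right)^{L_2} \left(\frac{p_3}{1 - p_1}\right)^{L_3}}_{\text{interval 2}}$$

Here $1 - p_1 = S_1$ is the survival probability after the first interval and $p_3 / (1 - p_1) = S_2 / S_1$ is the conditional survival probability of the second interval. The factors $(1 - p_1)^{N - L_1}$ and $(1 - p_1)^{-(L_2 + L_3)}$ cancel because $N - L_1 = L_2 + L_3$. The third interval contributes a factor of 1, since all remaining organisms die there with probability 1.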

log_lik_conditional_binomial = numpyro.infer.log_likelihood(conditional_binomial_model, samples_conditional_binomial, lethality, n_trials)
log_lik_multinomial = numpyro.infer.log_likelihood(multinomial_model, samples_conditional_binomial, lethality, n_trials)

# we take the sum over the independent time dim (as independent probabilities are multiplied) and then take the mean over the MCMC draws
ll_conditional_binomial_sum_mean = log_lik_conditional_binomial["lethality"].sum(axis=-1).mean(axis=0)

# for the multinomial only the mean has to be computed.
ll_multinomial_mean = log_lik_multinomial["lethality"].mean(axis=0)

fig, ax = plt.subplots(1,1,figsize=(8,2))
ax.plot(ll_conditional_binomial_sum_mean, ls="", marker="o", alpha=.5, label="conditional binomial")
ax.plot(ll_multinomial_mean, ls="", marker="o", alpha=.5, label="multinomial")
ax.set_ylabel("log likelihood")
ax.set_xlabel("batch")
ax.legend()
[Figure: log likelihood per batch for the conditional binomial and multinomial models; x-axis: batch, y-axis: log likelihood.]
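The agreement can also be checked without MCMC for a single experiment. The sketch below uses only the standard library and numpy; the helper functions are written here for illustration and are not numpyro API. It evaluates both log-likelihoods for the first simulated row at its true probabilities.

```python
import math
import numpy as np

def log_binom_coef(n, k):
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)

def multinomial_logpmf(x, p):
    n = int(np.sum(x))
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(int(k) + 1) for k in x)
    return log_coef + float(np.sum(np.asarray(x) * np.log(p)))

def conditional_binomial_logpmf(x, p):
    """Sum of binomial log-probabilities of the survivors of each interval."""
    x = np.asarray(x)
    p = np.asarray(p, dtype=float)
    S = np.concatenate([[1.0], 1.0 - np.cumsum(p)])              # survival probabilities
    alive = np.concatenate([[x.sum()], x.sum() - np.cumsum(x)])  # survivor counts
    total = 0.0
    # the last interval is skipped: nobody survives it, with probability 1
    for i in range(len(x) - 1):
        n, k = int(alive[i]), int(alive[i + 1])
        s = S[i + 1] / S[i]
        total += log_binom_coef(n, k) + k * math.log(s) + (n - k) * math.log(1 - s)
    return total

x, p = [197, 522, 281], [0.2, 0.5, 0.3]
print(multinomial_logpmf(x, p), conditional_binomial_logpmf(x, p))
```

Both printed values should agree up to floating point rounding.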

So far so normal. In the next part, "Conditional survival part 2: Survival probabilities very close to zero", we’ll look at what happens when survival probabilities get extremely low. Such scenarios occur frequently with parametric survival functions and high-lethality treatments.