Gemlib

Gemlib is a library providing Python classes for infectious disease modelling. It provides building blocks for assembling complex models and fitting them to data.

Features

  • programmable classes for deterministic, continuous- and discrete-time state transition models

  • the library is compatible with TensorFlow Probability, allowing complex hierarchical Bayesian models to be built around infectious disease models

  • a suite of MCMC samplers for parameter inference: Gemlib provides random walk Metropolis-Hastings, Hamiltonian Monte Carlo, and specialised samplers for integrating out censored epidemiological event data (e.g. infection times)
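To illustrate what the first of these samplers does, here is a generic random walk Metropolis-Hastings sketch in plain NumPy. It is not Gemlib's sampler API: the function name `rwmh` and its parameters are hypothetical, chosen only to show the algorithm (Gaussian proposal, accept/reject on the log-density difference).

```python
import numpy as np


def rwmh(log_prob_fn, initial_state, num_samples, step_size, seed=0):
    """Generic random walk Metropolis-Hastings (illustrative, not Gemlib's API).

    log_prob_fn: returns the (unnormalised) log target density at a state.
    Returns the chain of samples and the acceptance rate.
    """
    rng = np.random.default_rng(seed)
    current = np.asarray(initial_state, dtype=float)
    current_lp = log_prob_fn(current)
    samples = np.empty((num_samples,) + current.shape)
    num_accepted = 0
    for i in range(num_samples):
        # Symmetric Gaussian proposal centred on the current state
        proposal = current + step_size * rng.standard_normal(current.shape)
        proposal_lp = log_prob_fn(proposal)
        # Accept with probability min(1, pi(proposal) / pi(current));
        # the proposal is symmetric, so no Hastings correction is needed
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
            num_accepted += 1
        samples[i] = current
    return samples, num_accepted / num_samples


# Usage: sample from a standard normal target
samples, acceptance_rate = rwmh(
    lambda x: -0.5 * np.sum(x**2), [0.0], num_samples=20000, step_size=1.0, seed=1
)
```

In practice the step size is tuned so that the acceptance rate is neither close to 0 (proposals too bold) nor close to 1 (proposals too timid).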

Installation

The latest release of Gemlib can be installed from PyPI using pip:

$ pip install gemlib

The current development version of the library can be installed at any time from our GitLab repository with:

$ pip install git+https://gitlab.com/gem-epidemics/gemlib

System requirements:

  • A computer running Linux, Windows 7 (or later), or macOS

  • Python >=3.10,<3.13

  • optionally, an NVIDIA GPU supported by the latest version of TensorFlow, for maximum performance
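A quick way to check an environment against these requirements is sketched below using only the standard library. The helper names `python_supported` and `installed_version` are hypothetical; the distribution names are the PyPI names `gemlib`, `tensorflow` and `tensorflow-probability`. To see whether TensorFlow can use a GPU, `tf.config.list_physical_devices("GPU")` lists the GPUs it detects.

```python
import sys
from importlib import metadata


def python_supported(version_info=sys.version_info):
    """True if the interpreter lies in the documented range >=3.10,<3.13."""
    return (3, 10) <= tuple(version_info[:2]) < (3, 13)


def installed_version(distribution):
    """Installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("Python version supported:", python_supported())
    for name in ("gemlib", "tensorflow", "tensorflow-probability"):
        print(f"{name}: {installed_version(name)}")
```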

Quick example

Gemlib presents a powerful API for constructing Markov state transition models, such as those used in infectious disease modelling. Here’s a quick example of how to implement a stochastic, homogeneously mixing SIR model in discrete time:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from gemlib.distributions import DiscreteTimeStateTransitionModel

# Represent the S -> I -> R model as a graph incidence matrix
incidence_matrix = np.array([[-1,  0],
                            [ 1, -1],
                            [ 0,  1]], dtype=np.float32)

# Initial S, I, and R states for a single population
initial_state = np.array([[99, 1, 0]], dtype=np.float32)

# Define the transition rates
def transition_rate_fn(t, state):
    si_rate = 0.2 * state[:, 1] / tf.reduce_sum(state, axis=-1)
    ir_rate = tf.fill((state.shape[0],), 0.14)
    return si_rate, ir_rate

# Instantiate the model
model = DiscreteTimeStateTransitionModel(
    transition_rate_fn, incidence_matrix, initial_state, num_steps=50
)

# Draw a realisation of the epidemic process
sample = model.sample(seed=[0,0])

# Compute the probability of observing `sample` given the model
log_prob = model.log_prob(sample)

# Convert the transition event tensor output to numbers in each state over time
sample_state = model.compute_state(sample)

# Plot simulation
plt.plot(np.sum(sample_state, axis=1), label=["S", "I", "R"])
plt.xlabel("Time")
plt.ylabel("Number of individuals")
_ = plt.legend()
[Figure: SIR model simulation]
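To make the roles of the incidence matrix and the event tensor concrete, the sketch below re-implements the same SIR simulation in plain NumPy. It assumes the chain-binomial scheme commonly used for discrete-time models: a per-individual rate h becomes a per-step transition probability 1 - exp(-h * time_delta), and the number of events is a binomial draw from the source compartment. Gemlib's internals may differ. The names `sir_rates`, `simulate_chain_binomial` and `compute_state` are hypothetical, and this simple version assumes each state has at most one outgoing transition (competing transitions would need a multinomial draw). Whether states are reported at the start or end of each step is also an assumption here (start).

```python
import numpy as np


def sir_rates(t, state, beta=0.2, gamma=0.14):
    """Per-individual rates, a NumPy mirror of transition_rate_fn above."""
    si_rate = beta * state[:, 1] / state.sum(axis=-1)
    ir_rate = np.full(state.shape[0], gamma)
    return np.stack([si_rate, ir_rate], axis=-1)  # (num_units, num_transitions)


def simulate_chain_binomial(rate_fn, incidence_matrix, initial_state,
                            num_steps, time_delta=1.0, seed=0):
    """Return an event tensor of shape (num_steps, num_units, num_transitions)."""
    rng = np.random.default_rng(seed)
    # Source state of each transition: the row holding -1 in that column
    source = np.argmax(incidence_matrix == -1, axis=0)
    state = np.asarray(initial_state, dtype=np.int64)
    num_units, num_transitions = state.shape[0], incidence_matrix.shape[1]
    events = np.zeros((num_steps, num_units, num_transitions), dtype=np.int64)
    for t in range(num_steps):
        # Probability that an individual makes each transition within one step
        probs = 1.0 - np.exp(-rate_fn(t, state) * time_delta)
        # Number of individuals in each source state that make the transition
        events[t] = rng.binomial(state[:, source], probs)
        # Net change: each event adds its column of the incidence matrix
        state = state + events[t] @ incidence_matrix.T
    return events


def compute_state(initial_state, events, incidence_matrix):
    """State at the start of each step (assumed convention)."""
    deltas = events @ incidence_matrix.T  # (num_steps, num_units, num_states)
    earlier = np.cumsum(deltas, axis=0) - deltas  # changes from earlier steps only
    return np.asarray(initial_state) + earlier


# Usage with the same incidence matrix and initial state as above (integer dtype)
incidence = np.array([[-1, 0], [1, -1], [0, 1]])
init = np.array([[99, 1, 0]])
events = simulate_chain_binomial(sir_rates, incidence, init, num_steps=50)
states = compute_state(init, events, incidence)
```

Because every column of the incidence matrix sums to zero, each event moves individuals between compartments without creating or removing any, so the total population stays at 100 at every step.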

Tip

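Since Gemlib is built on TensorFlow, its functions and methods can be compiled into a TensorFlow graph by wrapping them in tf.function, and further compiled with XLA by passing jit_compile=True. For the complex models Gemlib is designed for, this is often many times faster than the default eager execution. Note that the first call to a tf.function traces (and, with XLA, compiles) the function, so it should be excluded from timings: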

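# Pass the seed as a tensor: new Python-list values would trigger retracing
seed = tf.constant([0, 1], dtype=tf.int32)

@tf.function  # TF graph mode
def fast_sample(seed):
    return model.sample(seed=seed)

sample = fast_sample(seed)  # first call traces the graph
%time sample = fast_sample(seed)  # IPython/Jupyter timing magic

@tf.function(jit_compile=True)  # additionally compile the graph with XLA
def faster_sample(seed):
    return model.sample(seed=seed)

sample = faster_sample(seed)  # first call traces and XLA-compiles
%time sample = faster_sample(seed)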
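`%time` is an IPython/Jupyter magic and is a syntax error in a plain Python script. Outside a notebook, a small standard-library helper such as the hypothetical `time_call` below can report the warm-up call (which, for a tf.function, includes tracing and any XLA compilation) separately from the best steady-state time.

```python
import time


def time_call(fn, *args, repeats=5):
    """Return (warm-up time, best of `repeats` later calls) in seconds."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    # The first call is timed on its own: for a tf.function it includes tracing
    start = time.perf_counter()
    fn(*args)
    warm_up = time.perf_counter() - start
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return warm_up, min(timings)
```

For example, `warm_up, best = time_call(faster_sample, seed)`. On a GPU, TensorFlow may return before the computation has finished, so converting the result inside `fn` (e.g. with `.numpy()`) can give more faithful timings.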

Acknowledgements

The Gemlib team is indebted to the TensorFlow Probability and BlackJAX teams, whose wonderful ideas have inspired our library architecture. If you haven’t already, go and check out these fantastic libraries and spot the similarities with Gemlib.