Gemlib
======

.. image:: ./images/gem-logo.png
   :alt: Gemlib
   :width: 200px
   :align: center

Gemlib is a library providing Python classes for infectious disease modelling.
It provides building blocks to assemble complex models and fit them to data.

Features
--------

* programmable classes for deterministic, continuous-time and discrete-time
  state transition models
* the library is compatible with `TensorFlow Probability `_, allowing complex
  hierarchical Bayesian models to be built around infectious disease models
* a suite of MCMC samplers caters for parameter inference. Gemlib provides
  random walk Metropolis-Hastings, Hamiltonian Monte Carlo, and specialised
  samplers for integrating out censored epidemiological event data (e.g.
  infection times)

Installation
------------

The latest release of Gemlib can be installed from PyPI using ``pip``::

   $ pip install gemlib

The current development version of the library can be installed at any time
from our GitLab repository with::

   $ pip install git+https://gitlab.com/gem-epidemics/gemlib

System requirements:

* a computer running Linux, Windows 7 (or later), or macOS
* Python >=3.10, <3.13
* an NVIDIA GPU compatible with the latest version of TensorFlow, if maximum
  performance is required

Quick example
-------------

Gemlib presents a powerful API for constructing Markov state transition
models, such as those used in infectious disease modelling.
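
Conceptually, a discrete-time model of this kind moves individuals between
compartments in fixed time steps. The NumPy sketch below illustrates one way
such a step can work for a single population; it is our own assumption, not
Gemlib's implementation. We assume a chain-binomial scheme in which each
individual makes a transition within a step of length ``dt`` with probability
``1 - exp(-rate * dt)``, and an incidence matrix with one row per state and
one column per transition that turns the resulting event counts into changes
in the S, I and R counts. The functions ``sir_rates`` and
``simulate_chain_binomial`` are hypothetical helpers written only for this
illustration::

   import numpy as np

   # Incidence matrix: rows are states (S, I, R), columns are transitions
   # (S->I, I->R). -1 removes an individual from a state, +1 adds one.
   incidence_matrix = np.array([[-1, 0],
                                [1, -1],
                                [0, 1]])

   # Source state of each transition (the row holding its -1 entry): [0, 1]
   source_state = np.argmax(incidence_matrix == -1, axis=0)


   def sir_rates(state, beta=0.2, gamma=0.14):
       """Hypothetical helper: per-individual rates for S->I and I->R."""
       return np.array([beta * state[1] / state.sum(), gamma])


   def simulate_chain_binomial(initial_state, num_steps, dt=1.0, seed=0):
       """Hypothetical helper: chain-binomial SIR simulation.

       This is an illustrative assumption, not Gemlib's implementation.
       """
       rng = np.random.default_rng(seed)
       state = np.asarray(initial_state, dtype=np.int64)
       trajectory = [state]
       for _ in range(num_steps):
           # Probability that an individual makes each transition in one step
           probs = 1.0 - np.exp(-sir_rates(state) * dt)
           # Event counts, drawn from the individuals in each source state
           events = rng.binomial(state[source_state], probs)
           # The incidence matrix maps event counts to changes in state counts
           state = state + incidence_matrix @ events
           trajectory.append(state)
       return np.stack(trajectory)


   # Rows of `trajectory` are time steps 0..50; columns are S, I, R counts
   trajectory = simulate_chain_binomial([99, 1, 0], num_steps=50)
   print(trajectory[-1])

In the Gemlib example that follows, the same two ingredients are supplied as
``transition_rate_fn`` and ``incidence_matrix``, and the model object takes
care of simulation (``sample``) and likelihood evaluation (``log_prob``).
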

Here's a quick example of how to implement a stochastic, homogeneously mixing
`SIR model `_ in discrete time::

   import numpy as np
   import tensorflow as tf
   import matplotlib.pyplot as plt

   from gemlib.distributions import DiscreteTimeStateTransitionModel

   # Represent the S -> I -> R model as a graph incidence matrix:
   # rows are states (S, I, R), columns are transitions (S->I, I->R)
   incidence_matrix = np.array([[-1,  0],
                                [ 1, -1],
                                [ 0,  1]], dtype=np.float32)

   # Initial S, I, and R counts for a single population
   initial_state = np.array([[99, 1, 0]], dtype=np.float32)

   # Define the transition rates, one for each transition
   def transition_rate_fn(t, state):
       si_rate = 0.2 * state[:, 1] / tf.reduce_sum(state, axis=-1)
       ir_rate = tf.fill((state.shape[0],), 0.14)
       return si_rate, ir_rate

   # Instantiate the model
   model = DiscreteTimeStateTransitionModel(
       transition_rate_fn, incidence_matrix, initial_state, num_steps=50
   )

   # Draw a realisation of the epidemic process
   sample = model.sample(seed=[0, 0])

   # Compute the log-probability of observing `sample` under the model
   log_prob = model.log_prob(sample)

   # Convert the transition event tensor to numbers in each state over time
   sample_state = model.compute_state(sample)

   # Plot the simulation, summed over population units
   plt.plot(np.sum(sample_state, axis=1), label=["S", "I", "R"])
   plt.xlabel("Time")
   plt.ylabel("Number of individuals")
   _ = plt.legend()

.. image:: ./images/sir_example_plot.png
   :width: 400
   :alt: SIR model simulation

.. tip::

   Since Gemlib is based on `TensorFlow `_, its functions and methods can be
   optimised by wrapping them in `tf.function `_.
   For the complex models Gemlib is designed for, this can give substantial
   speedups compared to unoptimised (eager-mode) code::

      # A tensor seed lets new seed values be passed without retracing
      seed = tf.constant([0, 1])

      @tf.function  # TF graph mode
      def fast_sample(seed):
          return model.sample(seed=seed)

      @tf.function(jit_compile=True)  # TF graph mode, compiled with XLA
      def faster_sample(seed):
          return model.sample(seed=seed)

      # The first call traces (and, with XLA, compiles) each function,
      # so warm up before timing. %time is an IPython/Jupyter magic.
      sample = fast_sample(seed)
      %time sample = fast_sample(seed)

      sample = faster_sample(seed)
      %time sample = faster_sample(seed)

Acknowledgements
----------------

The Gemlib team is indebted to the `TensorFlow Probability `_ and
`BlackJAX `_ teams, whose wonderful ideas have inspired our library
architecture. If you haven't already, go and check out these fantastic
libraries and spot the similarities with Gemlib.

.. toctree::
   :maxdepth: 2
   :caption: Contents
   :hidden:

   api