One of the cornerstones of my PhD research has been the development of KOH-GPJax, a Python package for Bayesian calibration of computer models. In this post, I want to walk through what KOH-GPJax is, why I built it, and how it can be used to dramatically accelerate the calibration process.
The Challenge: Calibrating Expensive Computer Models
Computer models are used everywhere, from climate science to engineering, to simulate complex real-world systems. However, these models have parameters that need to be "tuned" or "calibrated" so that the model's output matches real-world observations. The Kennedy and O'Hagan (2001) framework is a powerful statistical approach for this, but it can be computationally expensive, especially for models that have many parameters or take a long time to run.
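For reference, Kennedy and O'Hagan's original formulation models each observation $z_i$, taken at known inputs $x_i$, in terms of the simulator output $\eta(x_i, \theta)$ at the unknown calibration parameters $\theta$:

$$
z_i = \rho\,\eta(x_i, \theta) + \delta(x_i) + \varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, \sigma_\varepsilon^2),
$$

where $\rho$ is a scale parameter, $\delta(\cdot)$ is a model-discrepancy function, and both $\eta(\cdot,\cdot)$ and $\delta(\cdot)$ are given Gaussian process priors. Because $\eta$ is expensive to run, it is represented by a GP emulator fitted to a set of simulator runs, so each likelihood evaluation factorises a covariance matrix over the $n$ observations and $m$ simulator runs together, at a cost that grows as $O((n+m)^3)$.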
The Solution: KOH-GPJax
I developed KOH-GPJax to address this challenge. It is a pure JAX implementation of the Kennedy and O’Hagan calibration framework. By leveraging JAX, KOH-GPJax can:
- Run on GPUs and TPUs: This can lead to massive speedups (10-100x or more) compared to traditional CPU-based implementations.
- Use modern MCMC methods: KOH-GPJax uses state-of-the-art Hamiltonian Monte Carlo (HMC) samplers, such as those in blackjax, for efficient inference.
- Scale efficiently: This includes scaling to large volumes of data and to calibrating many parameters at once, representing a step-change in capability over previous implementations.
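To make the HMC point concrete, here is a minimal sketch in plain JAX. It does not use KOH-GPJax's API or blackjax, and it assumes a deliberately simplified KOH-style model: the simulator is a cheap, hypothetical toy function `simulator` (so no emulator is needed), the discrepancy is a squared-exponential GP with fixed hyperparameters, and θ has a standard-normal prior. `jax.grad` supplies the gradients HMC needs, and `jax.jit` compiles the whole chain.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# Double precision keeps the GP Cholesky factorisation numerically stable.
jax.config.update("jax_enable_x64", True)


def simulator(x, theta):
    """Hypothetical cheap toy simulator eta(x, theta), standing in for a real model."""
    return theta[0] * jnp.sin(x) + theta[1] * x


def rbf_kernel(x1, x2, variance, lengthscale):
    """Squared-exponential covariance for the discrepancy GP delta(x)."""
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sqdist / lengthscale**2)


def log_posterior(theta, x_obs, z_obs, delta_var=0.1, delta_ls=1.0, noise_var=0.01):
    """Unnormalised log p(theta | z), assuming z ~ N(eta(x, theta), K_delta + noise_var * I)
    and a standard-normal prior on theta (both assumptions of this sketch)."""
    n = x_obs.shape[0]
    resid = z_obs - simulator(x_obs, theta)
    cov = rbf_kernel(x_obs, x_obs, delta_var, delta_ls) + noise_var * jnp.eye(n)
    chol = jnp.linalg.cholesky(cov)  # lower-triangular factor, O(n^3)
    alpha = cho_solve((chol, True), resid)  # cov^{-1} @ resid
    log_lik = (
        -0.5 * resid @ alpha
        - jnp.sum(jnp.log(jnp.diag(chol)))
        - 0.5 * n * jnp.log(2.0 * jnp.pi)
    )
    log_prior = -0.5 * jnp.sum(theta**2)
    return log_lik + log_prior


def hmc_step(key, theta, log_prob, step_size=0.01, n_leapfrog=50):
    """One HMC transition: leapfrog integration followed by a Metropolis correction."""
    grad_fn = jax.grad(log_prob)  # gradients come from autodiff, not hand derivation
    key_mom, key_acc = jax.random.split(key)
    p0 = jax.random.normal(key_mom, theta.shape)

    def leapfrog(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * grad_fn(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad_fn(q)
        return (q, p), None

    (q_new, p_new), _ = jax.lax.scan(leapfrog, (theta, p0), None, length=n_leapfrog)
    # Hamiltonian H(q, p) = -log_prob(q) + |p|^2 / 2; accept with prob min(1, exp(H0 - H1)).
    h0 = -log_prob(theta) + 0.5 * jnp.sum(p0**2)
    h1 = -log_prob(q_new) + 0.5 * jnp.sum(p_new**2)
    accept = jnp.log(jax.random.uniform(key_acc)) < h0 - h1
    return jnp.where(accept, q_new, theta)


# Synthetic "field observations" from the toy simulator plus noise.
x_obs = jnp.linspace(0.0, 5.0, 30)
theta_true = jnp.array([1.5, -0.5])
z_obs = simulator(x_obs, theta_true) + 0.1 * jax.random.normal(jax.random.PRNGKey(0), x_obs.shape)


def log_prob(theta):
    return log_posterior(theta, x_obs, z_obs)


def run_chain(key, theta0, n_samples):
    """Run n_samples HMC steps inside lax.scan so the whole chain compiles once."""
    def body(theta, k):
        theta = hmc_step(k, theta, log_prob)
        return theta, theta

    _, samples = jax.lax.scan(body, theta0, jax.random.split(key, n_samples))
    return samples


samples = jax.jit(run_chain, static_argnums=2)(jax.random.PRNGKey(1), jnp.zeros(2), 1000)
print("posterior mean:", samples[200:].mean(axis=0))
```

Because the chain is compiled with `jax.jit`, the same code runs unchanged on whichever backend JAX finds (CPU, GPU or TPU). In a full KOH calibration the GP hyperparameters would also be inferred, the simulator would be replaced by an emulator, and a tuned sampler such as blackjax's NUTS, which likewise takes a log-density function, would replace this hand-rolled HMC.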
A Use Case: Calibrating Climate Models
A key application of KOH-GPJax has been in my work with the UK Met Office and Stanford University, where we have used it to calibrate parameters in weather and climate models.
To see how it works in practice, you can look at the examples in the BayesianCalibrationExamples repository.
- theta-dim/ example: This example showcases the calibration of a model with a higher-dimensional parameter space. It demonstrates how KOH-GPJax can efficiently explore this space and identify the parameter values best supported by the data. This is crucial for real-world climate models, which can have dozens of parameters.
- MiMA/ example: The Model of an idealized Moist Atmosphere (MiMA) is an intermediate-complexity climate model that is often used as a testbed for new methods. In this example, we use KOH-GPJax to calibrate a Betts-Miller convection scheme against observational data. This demonstrates the end-to-end workflow of using KOH-GPJax on a realistic climate modelling problem.
Why This Matters for RSE and Quant Roles
Developing and applying tools like KOH-GPJax requires a unique combination of skills:
- Deep understanding of Bayesian statistics and Gaussian Processes: Essential for developing and implementing the core calibration framework.
- Strong software engineering skills: Building a robust, well-tested, and easy-to-use Python package.
- Expertise in high-performance computing: Leveraging modern tools like JAX to accelerate scientific workflows.
These are the skills I am excited to bring to a Research Software Engineer (RSE) or Machine Learning Engineer role. If you are working on problems that involve calibrating complex models, I would love to chat about how my experience could be of help. Check out the KOH-GPJax GitHub page to learn more.