Welcome to KOH-GPJax
KOH-GPJax is a modern Python implementation of the Bayesian calibration of computer models framework first outlined by Kennedy & O'Hagan (2001)¹. The package inherits the GPU acceleration and just-in-time (JIT) compilation features of JAX and builds upon the flexible yet elegant Gaussian process package GPJax.
Basic Example
Bayesian calibration procedures require three ingredients: Gaussian process kernel definitions, parameter prior definitions, and a posterior sampler. KOH-GPJax provides the infrastructure to build the Bayesian model from the first two components, which the user supplies, and exposes a log posterior density function for the MCMC sampler of your choosing.
The data for this problem can be found in examples/data/. The calibration script below builds the model from three user-defined modules, model.py, priors.py, and dataloader.py (each shown in turn after the script), and exposes the NLPD and its gradient:
import jax.numpy as jnp
from jax import config, grad, jit
from kohgpjax.parameters import ModelParameters
from dataloader import kohdataset
from model import Model
from priors import prior_dict
config.update("jax_enable_x64", True)
model_parameters = ModelParameters(prior_dict=prior_dict)
model = Model(
    model_parameters=model_parameters,
    kohdataset=kohdataset,
)
nlpd_func = model.get_KOH_neg_log_pos_dens_func()  # negative log posterior density (NLPD)
# JIT-compile the NLPD function
nlpd_jitted = jit(nlpd_func)
# Compute the gradient of the NLPD
grad_nlpd_jitted = jit(grad(nlpd_func))
# Example usage with arbitrary parameter values; alternatively, take the mean of each parameter's prior distribution
example_params = jnp.array([0.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
nlpd_value = nlpd_jitted(example_params)
nlpd_gradient = grad_nlpd_jitted(example_params)
print("NLPD Value:", nlpd_value)
print("NLPD Gradient:", nlpd_gradient)
model.py defines the GP kernels for the emulator, discrepancy, and observation-noise components by subclassing KOHModel:

import gpjax as gpx
import jax.numpy as jnp
from kohgpjax.kohmodel import KOHModel

class Model(KOHModel):
    def k_eta(self, params_constrained):
        """Kernel of the emulator GP over (x, theta)."""
        params = params_constrained["eta"]
        return gpx.kernels.ProductKernel(
            kernels=[
                gpx.kernels.RBF(
                    active_dims=[0],
                    lengthscale=jnp.array(params["lengthscales"]["x_0"]),
                    # Parameterised by precision, consistent with Higdon et al. (2004)
                    variance=jnp.array(1 / params["variances"]["precision"]),
                ),
                gpx.kernels.RBF(
                    active_dims=[1],
                    lengthscale=jnp.array(params["lengthscales"]["theta_0"]),
                ),
            ]
        )

    def k_delta(self, params_constrained):
        """Kernel of the model-discrepancy GP over x."""
        params = params_constrained["delta"]
        return gpx.kernels.RBF(
            active_dims=[0],
            lengthscale=jnp.array(params["lengthscales"]["x_0"]),
            variance=jnp.array(1 / params["variances"]["precision"]),
        )

    # Defining k_epsilon is optional; the default behaviour is to use a White kernel.
    # Alternative observation-noise kernels should be defined here if needed.
    # A prior for ["epsilon"]["variances"]["precision"] is still required in priors.py
    # even if a user-defined k_epsilon is not provided.
    def k_epsilon(self, params_constrained):
        params = params_constrained["epsilon"]
        return gpx.kernels.White(
            active_dims=[0],
            variance=jnp.array(1 / params["variances"]["precision"]),
        )

    # k_epsilon_eta is completely optional. The default behaviour is a White kernel
    # with variance=0, effectively turning off this component.
    # def k_epsilon_eta(self, params_constrained) -> gpx.kernels.AbstractKernel:
    #     params = params_constrained["epsilon_eta"]
    #     return gpx.kernels.White(
    #         active_dims=[0],
    #         variance=jnp.array(1 / params["variances"]["precision"]),
    #     )
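To sanity-check a kernel definition in isolation, it can be constructed with fixed values and its Gram matrix evaluated through GPJax. The snippet below mirrors the structure of k_eta with arbitrary, hand-picked hyperparameters; .gram(...).to_dense() follows GPJax's kernel API, though the exact return type can differ between GPJax versions:

import gpjax as gpx
import jax.numpy as jnp

# Toy inputs: column 0 is the controllable input x, column 1 the calibration input theta
X = jnp.array([[0.0, 0.4], [0.5, 0.4], [1.0, 0.4]])

eta_kernel = gpx.kernels.ProductKernel(
    kernels=[
        gpx.kernels.RBF(active_dims=[0], lengthscale=jnp.array(1.0), variance=jnp.array(2.0)),
        gpx.kernels.RBF(active_dims=[1], lengthscale=jnp.array(0.3)),
    ]
)
print(eta_kernel.gram(X).to_dense())  # 3x3 covariance matrix of the emulator GP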
priors.py assigns a prior distribution to every model parameter:

import numpyro.distributions as dist
from kohgpjax.parameters import ModelParameterPriorDict, ParameterPrior
prior_dict: ModelParameterPriorDict = {
    "thetas": {
        "theta_0": ParameterPrior(
            dist.Uniform(low=0.3, high=0.5),
            name="theta_0",
        ),
    },
    "eta": {
        "variances": {
            # Precision parameterisation consistent with Higdon et al. (2004)
            "precision": ParameterPrior(
                dist.Gamma(concentration=2.0, rate=4.0),
                name="eta_precision",
            ),
        },
        "lengthscales": {
            "x_0": ParameterPrior(
                dist.Gamma(concentration=4.0, rate=1.4),
                name="eta_lengthscale_x_0",
            ),
            "theta_0": ParameterPrior(
                dist.Gamma(concentration=2.0, rate=3.5),
                name="eta_lengthscale_theta_0",
            ),
        },
    },
    "delta": {
        "variances": {
            "precision": ParameterPrior(  # Precision consistent with Higdon et al. (2004)
                dist.Gamma(concentration=2.0, rate=0.1),
                name="delta_precision",
            ),
        },
        "lengthscales": {
            "x_0": ParameterPrior(
                # Encourage a long lengthscale => near-linear discrepancy
                dist.Gamma(concentration=5.0, rate=0.3),
                name="delta_lengthscale_x_0",
            ),
        },
    },
    "epsilon": {
        "variances": {
            "precision": ParameterPrior(  # Precision consistent with Higdon et al. (2004)
                dist.Gamma(concentration=800.0, rate=2.0),
                name="epsilon_precision",
            ),
        },
    },
}
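The example_params vector in the script above can alternatively be built from the prior means. The sketch below walks prior_dict by hand; the flattened ordering (theta first, then the eta, delta, and epsilon hyperparameters) and the .distribution attribute on ParameterPrior are both assumptions here and should be checked against how ModelParameters flattens the dictionary in your installation:

import jax.numpy as jnp

# Assumed flattening order; verify against ModelParameters before use.
ordered_priors = [
    prior_dict["thetas"]["theta_0"],
    prior_dict["eta"]["variances"]["precision"],
    prior_dict["eta"]["lengthscales"]["x_0"],
    prior_dict["eta"]["lengthscales"]["theta_0"],
    prior_dict["delta"]["variances"]["precision"],
    prior_dict["delta"]["lengthscales"]["x_0"],
    prior_dict["epsilon"]["variances"]["precision"],
]
prior_means = jnp.array([p.distribution.mean for p in ordered_priors])  # .distribution is assumed, not confirmed API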
dataloader.py reads the field observations and simulator runs from CSV files and assembles them into a KOHDataset:

import gpjax as gpx
import jax.numpy as jnp
import numpy as np
from jax import config
from kohgpjax.dataset import KOHDataset

config.update("jax_enable_x64", True)
# Load as float64 to match jax_enable_x64; loading as float32 first would discard precision
DATAFIELD = np.loadtxt("field.csv", delimiter=",", dtype=np.float64)
DATASIM = np.loadtxt("sim.csv", delimiter=",", dtype=np.float64)
xf = jnp.reshape(DATAFIELD[:, 0], (-1, 1)).astype(jnp.float64)  # field inputs x
xc = jnp.reshape(DATASIM[:, 0], (-1, 1)).astype(jnp.float64)  # simulator inputs x
tc = jnp.reshape(DATASIM[:, 1], (-1, 1)).astype(jnp.float64)  # simulator calibration inputs t
yf = jnp.reshape(DATAFIELD[:, 1], (-1, 1)).astype(jnp.float64)  # field observations
yc = jnp.reshape(DATASIM[:, 2], (-1, 1)).astype(jnp.float64)  # simulator outputs
field_dataset = gpx.Dataset(xf, yf)
sim_dataset = gpx.Dataset(jnp.hstack((xc, tc)), yc)
kohdataset = KOHDataset(field_dataset, sim_dataset)
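As a quick sanity check on the assembled data, the standard gpx.Dataset attributes expose the counts and shapes KOHDataset receives; field inputs should be N_field x 1 and simulator inputs N_sim x 2 (the controllable input stacked with the calibration input):

print(field_dataset.n, field_dataset.X.shape)  # N_field and (N_field, 1)
print(sim_dataset.n, sim_dataset.X.shape)  # N_sim and (N_sim, 2)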
¹ Kennedy, M. C. and O'Hagan, A. (2001). Bayesian calibration of computer models. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63: 425-464. https://doi.org/10.1111/1467-9868.00294