Skip to content

Parameters

ModelParameters(prior_dict)

Initialize the calibration model parameters. :param prior_dict: A (possibly nested) dictionary of parameter priors.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def __init__(self, prior_dict: ModelParameterPriorDict):
    """
    Initialize the calibration model parameters.
    :param prior_dict: A (possibly nested) dictionary of parameter priors.
        It is validated and then flattened into a list of priors. The
        flattened order defines the order of entries in flat parameter
        vectors/lists (see ``constrain_sample`` and ``unflatten_sample``).
    """
    # Validate the prior dictionary before anything else reads it.
    _check_prior_dict(prior_dict)
    # TODO: Check prior_dict structure agrees with the KOHDataset. Add this check to KOHPosterior?
    # Things to check:
    # - Dimensions of the kernel match the dimensions of the dataset
    # - There are enough theta priors

    self.priors = prior_dict

    # priors_flat: the leaf priors in pytree order.
    # priors_tree: the treedef needed to rebuild the original nesting.
    self.priors_flat, self.priors_tree = jax.tree.flatten(self.priors)
    self.n_params = len(self.priors_flat)

    # One jit-compiled ``log_prob`` per prior, aligned index-for-index with
    # ``priors_flat``. The flattened priors are pytree leaves, so a plain
    # comprehension is equivalent to ``jax.tree.map`` here.
    self.prior_log_prob_funcs = [
        jax.jit(prior.log_prob) for prior in self.priors_flat
    ]

constrain_and_unflatten_sample(samples_flat)

Transform samples to the constrained space and unflatten them to the original prior tree structure. Args: samples_flat: A flat array of samples. Returns: A tree of samples in the constrained space with the same structure as the priors.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def constrain_and_unflatten_sample(
    self,
    samples_flat,  #: ParameterArray
) -> ModelParameterDict:
    """
    Map a flat unconstrained sample into the constrained space and rebuild
    the original prior tree structure around it.
    Args:
        samples_flat: A flat array of samples, ordered like the flattened priors.
    Returns:
        A tree of constrained samples with the same structure as the priors.
    """
    return self.unflatten_sample(self.constrain_sample(samples_flat))

constrain_sample(samples_flat)

Transform samples to the constrained space. Args: samples_flat: A flat JAX array of samples. Returns: A list of samples transformed to the constrained space.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def constrain_sample(
    self,
    samples_flat,  #: ParameterArray
):  # -> ParameterValueList:
    """
    Map each entry of a flat unconstrained sample into its constrained space.
    Args:
        samples_flat: A flat JAX array of samples, ordered like the flattened priors.
    Returns:
        A list with one constrained value per flattened prior.
    """
    constrained = []
    for idx, prior in enumerate(self.priors_flat):
        # Entry ``idx`` of the flat vector belongs to the idx-th flattened prior.
        constrained.append(prior.forward(samples_flat[idx]))
    return constrained

get_log_prior_func()

Return a jit-compiled function that computes the joint log prior probability.

Returns:

  • Callable[[ParameterValueList], Scalar]

    A function that computes the joint log prior probability.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def get_log_prior_func(
    self,
) -> Callable[[ParameterValueList], jaxtyping.Scalar]:
    """Build a jit-compiled function returning the joint log prior probability.

    The returned function evaluates each entry of ``self.prior_log_prob_funcs``
    on the matching entry of its argument and sums every resulting value.

    Returns:
        A function mapping a list of unconstrained parameter values (one per
        flattened prior, in the same order) to the joint log prior as a scalar.
    """

    @jax.jit
    def log_prior_func(
        unconstrained_params_flat: ParameterValueList,
    ) -> jaxtyping.Scalar:
        """
        Compute the joint log prior probability.
        Args:
            unconstrained_params_flat: A list of unconstrained parameter
                values, one per flattened prior.
        Returns:
            The joint log prior probability (0.0 when there are no priors).
        """
        # jax.tree.map pairs each log-prob function with its value and
        # requires both lists to share the same structure.
        log_probs = jax.tree.map(
            lambda log_prob_func, x: log_prob_func(x),
            self.prior_log_prob_funcs,
            unconstrained_params_flat,
        )
        if not log_probs:
            # Empty sum of log densities; jnp.concatenate([]) would raise.
            return jnp.zeros(())
        # atleast_1d lets scalar- and vector-valued log probs be concatenated
        # and summed together.
        return jnp.sum(jnp.concatenate([jnp.atleast_1d(x) for x in log_probs]))

    return log_prior_func

unflatten_sample(samples_flat)

Unflatten the samples to the original prior tree structure. Args: samples_flat: A flat array of samples. Returns: A tree of samples with the same structure as the priors.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def unflatten_sample(
    self,
    samples_flat,  #: ParameterValueList
) -> ModelParameterDict:
    """
    Rebuild the prior tree structure from a flat sequence of values.
    Args:
        samples_flat: A flat sequence of values, one per flattened prior.
    Returns:
        A tree with the same structure as the priors, holding the given values.
    """
    treedef = self.priors_tree
    # PyTreeDef.unflatten is the method jax.tree.unflatten delegates to.
    return treedef.unflatten(samples_flat)

ParameterPrior(distribution, name=None)

Distribution on the constrained parameter. Args: distribution: A numpyro Distribution object representing the prior distribution. name: Optional name for the prior.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def __init__(
    self,
    distribution: npd.Distribution,
    name: Optional[str] = None,
):
    """Prior placed on a parameter, defined in its constrained space.
    Args:
        distribution: A numpyro Distribution object representing the prior distribution.
        name: Optional name for the prior.
    Raises:
        ValueError: If ``distribution`` is not a numpyro Distribution.
    """
    if isinstance(distribution, npd.Distribution):
        self.distribution = distribution
    else:
        raise ValueError("distribution must be a numpyro Distribution object.")

    # biject_to(support) returns a transform that maps the unconstrained reals
    # onto the distribution's support.
    self.bijector = npt.biject_to(distribution.support)
    self.name = name

forward(y)

Transform the input to the constrained space. Args: y: The unconstrained input value.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def forward(self, y):
    """Map an unconstrained value into the constrained space.
    Args:
        y: The unconstrained input value.
    Returns:
        The result of applying ``self.bijector`` to ``y``.
    """
    transform = self.bijector
    return transform(y)

inverse(x)

Transform the input to the unconstrained space. Args: x: The constrained input value.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def inverse(self, x):
    """Transform the input to the unconstrained space.
    Args:
        x: The constrained input value.
    Returns:
        The unconstrained value ``y`` for which ``forward(y)`` gives ``x``.
    """
    # Use the transform's public inverse (``Transform.inv``) rather than the
    # private ``_inverse`` hook; ``inv(x)`` dispatches to ``_inverse(x)``.
    return self.bijector.inv(x)

log_prob(y)

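Compute the log probability density of the prior at the unconstrained value y. This is the distribution's log density at the constrained point x = forward(y), plus the log absolute determinant of the transform's Jacobian. Args: y: The unconstrained input value.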

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def log_prob(self, y):
    """Log density of the prior expressed in the unconstrained space.

    Applies the change-of-variables formula. The result is the distribution's
    log density at the constrained point ``x = bijector(y)``, plus the log
    absolute determinant of the Jacobian of ``y -> x``.
    Args:
        y: The unconstrained input value.
    """
    bijector = self.bijector
    x = bijector(y)  # real-valued y -> point in the distribution's support
    log_jacobian = bijector.log_abs_det_jacobian(y, x)
    return self.distribution.log_prob(x) + log_jacobian

prob(y)

Compute the probability density of the prior at the unconstrained value y, i.e. exp(log_prob(y)). Args: y: The unconstrained input value.

Source code in .tox/docs/lib/python3.12/site-packages/kohgpjax/parameters.py
def prob(self, y):
    """Probability density of the prior in the unconstrained space.
    Args:
        y: The unconstrained input value.
    Returns:
        ``exp`` of ``self.log_prob(y)``.
    """
    log_density = self.log_prob(y)
    return jnp.exp(log_density)