Bayesian Optimization¶
Bayesian optimization is a powerful strategy for minimizing (or maximizing) objective functions that are costly to evaluate. It is an important component of automated machine learning toolboxes such as auto-sklearn, Auto-WEKA, and scikit-optimize, where it is used to select model hyperparameters. Bayesian optimization is used for a wide range of other applications as well; as cataloged in the review [2], these include interactive user interfaces, robotics, environmental monitoring, information extraction, combinatorial optimization, sensor networks, adaptive Monte Carlo, experimental design, and reinforcement learning.
Problem Setup¶
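We are given a minimization problem

$$x^* = \arg\min_x f(x),$$

where $f$ is a fixed objective function that we can evaluate pointwise. Here we assume that we do not have access to the gradient of $f$. We also allow for the possibility that evaluations of $f$ are noisy.
To solve the minimization problem, we will construct a sequence of points $\{x_n\}$ that converge to $x^*$. Since we implicitly assume a fixed budget of evaluations, we do not expect to find the exact minimum $x^*$: the goal is to get the best approximate solution we can within the allocated budget.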
The Bayesian optimization strategy works as follows:
1. Place a prior on the objective function $f$. Each time we evaluate $f$ at a new point $x_n$, we update our model for $f(x)$. This model serves as a surrogate objective function and reflects our beliefs about $f$ (in particular it reflects our beliefs about where we expect $f(x)$ to be close to $f(x^*)$). Since we are being Bayesian, our beliefs are encoded in a posterior that allows us to systematically reason about the uncertainty of our model predictions.
2. Use the posterior to derive an “acquisition” function $\alpha(x)$ that is easy to evaluate and differentiate (so that optimizing $\alpha(x)$ is easy). In contrast to $f(x)$, we will generally evaluate $\alpha(x)$ at many points $x$, since doing so will be cheap.
3. Repeat until convergence:
   - Use the acquisition function to derive the next query point according to $$x_{n+1} = \arg\min_x \alpha(x).$$
   - Evaluate $f(x_{n+1})$ and update the posterior.
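The steps above can be sketched as a short, dependency-free loop. This is only an illustration of the control flow, not the Gaussian process procedure built in this tutorial: `toy_surrogate` is a hypothetical stand-in model (it predicts the nearest observed value and treats the distance to that observation as its uncertainty), the acquisition function is minimized by brute force over a fixed grid, and all names are assumptions made for this example.

```python
def toy_surrogate(xs, ys):
    """Hypothetical stand-in for a posterior: predict the value of the nearest
    observation and use the distance to it as a crude uncertainty."""
    def predict(x):
        dist, y_near = min((abs(x - xi), yi) for xi, yi in zip(xs, ys))
        return y_near, dist
    return predict


def acquisition(model, x, kappa=2.0):
    """Low predicted value (exploitation) or high uncertainty (exploration)."""
    mean, spread = model(x)
    return mean - kappa * spread


def bayes_opt_loop(objective, x_init, candidates, n_steps):
    xs = list(x_init)
    ys = [objective(x) for x in xs]           # the expensive evaluations
    for _ in range(n_steps):
        model = toy_surrogate(xs, ys)         # update the surrogate
        # the acquisition is cheap, so we can afford to scan every candidate
        x_next = min(candidates, key=lambda x: acquisition(model, x))
        xs.append(x_next)
        ys.append(objective(x_next))          # one new expensive evaluation
    best = min(range(len(xs)), key=lambda i: ys[i])
    return xs[best], ys[best], xs, ys


def objective(x):
    return (x - 0.3) ** 2


grid = [i / 100 for i in range(101)]
x_best, y_best, xs, ys = bayes_opt_loop(objective, [0.0, 0.5, 1.0], grid, n_steps=10)
print(x_best, y_best)
```

The objective is called once per iteration, while the acquisition function is evaluated at every grid point; that asymmetry is what makes the approach attractive when the objective is costly.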
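A good acquisition function should make use of the uncertainty encoded in the posterior to encourage a balance between exploration (querying points where we know little about $f$) and exploitation (querying points in regions where we have good reason to think $x^*$ may lie). As the iterative procedure progresses, our model for $f$ evolves and so does the acquisition function. If our model is good and we have chosen a reasonable acquisition function, we expect the acquisition function to guide the query points $x_n$ towards $x^*$.
In this tutorial, our model for $f$ will be a Gaussian process. In particular, we will see how to use the Gaussian Process module in Pyro to implement a simple Bayesian optimization procedure.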
[1]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import torch
import torch.autograd as autograd
import torch.optim as optim
from torch.distributions import constraints, transform_to
import pyro
import pyro.contrib.gp as gp
assert pyro.__version__.startswith('1.9.1')
pyro.set_rng_seed(1)
Define an objective function¶
For the purposes of demonstration, the objective function we are going to consider is the Forrester et al. (2008) function:

$$f(x) = (6x - 2)^2 \sin(12x - 4), \quad x \in [0, 1].$$

This function has both a local minimum and a global minimum. The global minimum is at $x^* = 0.75725$.
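As a quick sanity check on the location of the global minimum, the function can be scanned on a fine grid using only the standard library. The name `forrester` and the grid resolution are choices made for this sketch; the tutorial itself defines the objective with torch.

```python
import math


def forrester(x):
    """Forrester et al. (2008) test function on [0, 1]."""
    return (6 * x - 2) ** 2 * math.sin(12 * x - 4)


# brute-force scan of [0, 1] with spacing 1e-5
grid = [i / 100000 for i in range(100001)]
x_best = min(grid, key=forrester)
print(round(x_best, 4), round(forrester(x_best), 3))
```

A brute-force scan like this is only affordable because the demonstration function is cheap; for a genuinely expensive objective it is exactly what Bayesian optimization is meant to avoid.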
[2]:
def f(x):
    return (6 * x - 2)**2 * torch.sin(12 * x - 4)
Let’s begin by plotting $f$.
[3]:
x = torch.linspace(0, 1, 100)
plt.figure(figsize=(8, 4))
plt.plot(x.numpy(), f(x).numpy())
plt.show()

Setting a Gaussian Process prior¶
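Gaussian processes are a popular choice for function priors due to their power and flexibility. The core of a Gaussian Process is its covariance function $k$, which governs the similarity of $f(x)$ for pairs of input points. Here we will use a Gaussian Process as our prior for the objective function $f$. Given inputs $X$ and the corresponding noisy observations $y$, the model takes the form

$$f \sim \mathrm{MultivariateNormal}(0, k(X, X)),$$
$$y = f + \epsilon,$$

where $\epsilon$ is i.i.d. Gaussian noise and $k(X, X)$ is a covariance matrix whose entries are given by $k(x, x^\prime)$ for each pair of inputs $(x, x^\prime)$.
We choose the Matern kernel with $\nu = \frac{5}{2}$ (as suggested in reference [1]). Note that the popular RBF kernel, which is used in many regression tasks, results in a function prior whose samples are infinitely differentiable; this is probably an unrealistic assumption for most ‘black-box’ objective functions.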
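To make the role of the covariance function concrete, here is a minimal pure-Python sketch of Gaussian process regression in one dimension with a Matern 5/2 kernel. It is not how Pyro implements `GPRegression`; the helper names (`matern52`, `solve`, `gp_posterior`), the fixed hyperparameters, and the use of plain Gaussian elimination are all assumptions made for illustration. The returned variance is that of the latent function, without observation noise added.

```python
import math


def matern52(x1, x2, lengthscale=1.0, variance=1.0):
    """Matern kernel with nu = 5/2 for scalar inputs."""
    s = math.sqrt(5.0) * abs(x1 - x2) / lengthscale
    return variance * (1.0 + s + s * s / 3.0) * math.exp(-s)


def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    z = [0.0] * n
    for r in reversed(range(n)):
        tail = sum(M[r][c] * z[c] for c in range(r + 1, n))
        z[r] = (M[r][n] - tail) / M[r][r]
    return z


def gp_posterior(X, y, x_star, noise=0.1, lengthscale=1.0, variance=1.0):
    """Posterior mean and latent variance at x_star; `noise` is a variance."""
    K = [[matern52(a, b, lengthscale, variance) + (noise if i == j else 0.0)
          for j, b in enumerate(X)] for i, a in enumerate(X)]
    k_star = [matern52(a, x_star, lengthscale, variance) for a in X]
    alpha = solve(K, y)        # (K + noise * I)^{-1} y
    v = solve(K, k_star)       # (K + noise * I)^{-1} k_star
    mean = sum(k * a for k, a in zip(k_star, alpha))
    var = variance - sum(k * w for k, w in zip(k_star, v))
    return mean, var


X = [0.0, 0.33, 0.66, 1.0]
y = [(6 * x - 2) ** 2 * math.sin(12 * x - 4) for x in X]
print(gp_posterior(X, y, 0.5))     # between observations
print(gp_posterior(X, y, 100.0))   # far from the data: reverts to the prior
```

Far from the observations the posterior reverts to the prior (zero mean, full prior variance), while near the data the variance shrinks; the acquisition function exploits exactly this behaviour.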
[4]:
# initialize the model with four input points: 0.0, 0.33, 0.66, 1.0
X = torch.tensor([0.0, 0.33, 0.66, 1.0])
y = f(X)
gpmodel = gp.models.GPRegression(X, y, gp.kernels.Matern52(input_dim=1),
                                 noise=torch.tensor(0.1), jitter=1.0e-4)
The following helper function update_posterior will take care of updating our gpmodel each time we evaluate $f$ at a new value $x$.
[5]:
def update_posterior(x_new):
    y = f(x_new)  # evaluate f at new point.
    X = torch.cat([gpmodel.X, x_new])  # incorporate new evaluation
    y = torch.cat([gpmodel.y, y])
    gpmodel.set_data(X, y)
    # optimize the GP hyperparameters using Adam with lr=0.001
    optimizer = torch.optim.Adam(gpmodel.parameters(), lr=0.001)
    gp.util.train(gpmodel, optimizer)
Define an acquisition function¶
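There are many reasonable options for the acquisition function (see references [1] and [2] for a list of popular choices and a discussion of their properties). Here we will use one that is ‘simple to implement and interpret,’ namely the ‘Lower Confidence Bound’ acquisition function. It is given by

$$\alpha(x) = \mu(x) - \kappa \sigma(x),$$

where $\mu(x)$ and $\sigma(x)$ are the mean and standard deviation of the posterior at the point $x$, and the constant $\kappa > 0$ controls the trade-off between exploitation and exploration. This acquisition function is minimized by choices of $x$ where either i) $\mu(x)$ is small (exploitation) or ii) $\sigma(x)$ is large (exploration). A large value of $\kappa$ places more weight on exploration, because it favors candidates $x$ in areas of high uncertainty. A small value of $\kappa$ encourages exploitation, because it favors candidates $x$ that minimize $\mu(x)$, the mean of our surrogate objective function. We will use $\kappa = 2$.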
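A small worked example may help show how the weight on the standard deviation shifts the choice between two candidates. The Lower Confidence Bound is the posterior mean minus kappa times the posterior standard deviation; the two candidate points and their (mean, standard deviation) values below are made up for illustration, and `lcb` and `pick` are hypothetical helper names.

```python
def lcb(mu, sigma, kappa=2.0):
    """Lower Confidence Bound: posterior mean minus kappa standard deviations."""
    return mu - kappa * sigma


# hypothetical posterior summaries (mean, standard deviation) at two points
candidates = {
    "near_known_minimum": (-1.0, 0.1),   # low mean, little uncertainty
    "unexplored_region": (0.0, 1.0),     # higher mean, large uncertainty
}


def pick(kappa):
    """Return the candidate name with the smallest LCB for a given kappa."""
    return min(candidates, key=lambda name: lcb(*candidates[name], kappa=kappa))


print(pick(0.5))   # small kappa favors exploitation
print(pick(2.0))   # larger kappa favors exploration
```

With kappa = 0.5 the bounds are -1.05 and -0.5, so the well-understood point wins; with kappa = 2 they are -1.2 and -2.0, so the uncertain region is queried instead.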
[6]:
def lower_confidence_bound(x, kappa=2):
    # posterior predictive mean and variance (including observation noise)
    mu, variance = gpmodel(x, full_cov=False, noiseless=False)
    sigma = variance.sqrt()
    return mu - kappa * sigma
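The final component we need is a way to find (approximate) minimizing points $x_{\rm min}$ of the acquisition function. There are several ways to proceed, including gradient-based and non-gradient-based techniques; here we follow the gradient-based approach. One drawback of gradient-based methods is that the minimization algorithm can get stuck at a local minimum. In this tutorial, we adopt a (very) simple approach to address this issue:
- First, we seed our minimization algorithm with 5 different values: i) one is chosen to be $x_{n-1}$, i.e. the candidate $x$ used in the previous step; and ii) four are chosen uniformly at random from the domain of the objective function.
- We then run the minimization algorithm to approximate convergence for each seed value.
- Finally, from the five candidates $x$ identified by the minimization algorithm, we select the one that minimizes the acquisition function.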
Please refer to reference [2] for a more detailed discussion of this problem in Bayesian Optimization.
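The multi-start idea can be illustrated independently of the Gaussian process. In the sketch below, a simple derivative-free step-halving search stands in for the gradient-based minimizer used in this tutorial, and the Forrester function stands in for the acquisition function; `local_search` and `multi_start_minimize` are hypothetical names introduced only for this example.

```python
import math
import random


def forrester(x):
    return (6 * x - 2) ** 2 * math.sin(12 * x - 4)


def local_search(fn, x0, lower=0.0, upper=1.0, step=0.05, tol=1e-6):
    """Move to the better neighbor while one exists, else halve the step."""
    x = min(max(x0, lower), upper)
    while step > tol:
        neighbors = [min(max(x + d, lower), upper) for d in (-step, step)]
        best = min(neighbors, key=fn)
        if fn(best) < fn(x):
            x = best
        else:
            step /= 2.0
    return x


def multi_start_minimize(fn, starts, lower=0.0, upper=1.0):
    """Run the local search from every start and keep the best end point."""
    results = [local_search(fn, s, lower, upper) for s in starts]
    return min(results, key=fn)


# a single start can end in the local minimum ...
x_single = local_search(forrester, 0.1)
# ... while adding a second start recovers the global one
x_multi = multi_start_minimize(forrester, [0.1, 0.9])
print(x_single, x_multi)

# in the spirit of the tutorial: the previous candidate plus four uniform draws
rng = random.Random(0)
starts = [0.1] + [rng.uniform(0.0, 1.0) for _ in range(4)]
x_seeded = multi_start_minimize(forrester, starts)
```

Random restarts give no guarantee of reaching the global minimum; they only make it more likely, at the cost of extra (cheap) acquisition evaluations.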
[7]:
def find_a_candidate(x_init, lower_bound=0, upper_bound=1):
    # transform x to an unconstrained domain
    constraint = constraints.interval(lower_bound, upper_bound)
    unconstrained_x_init = transform_to(constraint).inv(x_init)
    unconstrained_x = unconstrained_x_init.clone().detach().requires_grad_(True)
    minimizer = optim.LBFGS([unconstrained_x], line_search_fn='strong_wolfe')

    def closure():
        minimizer.zero_grad()
        x = transform_to(constraint)(unconstrained_x)
        y = lower_confidence_bound(x)
        # store dy/d(unconstrained_x) in unconstrained_x.grad
        autograd.backward(unconstrained_x, autograd.grad(y, unconstrained_x))
        return y

    minimizer.step(closure)
    # after finding a candidate in the unconstrained domain,
    # convert it back to original domain.
    x = transform_to(constraint)(unconstrained_x)
    return x.detach()
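The change of variables in find_a_candidate lets an unconstrained optimizer work on a bounded domain. As a rough illustration, the sketch below maps the real line onto an interval with a scaled sigmoid and back with its inverse. The assumption that this mirrors what `transform_to(constraints.interval(...))` does internally is ours, and `to_interval` and `from_interval` are hypothetical names.

```python
import math


def to_interval(u, lower=0.0, upper=1.0):
    """Map an unconstrained real u into (lower, upper) with a scaled sigmoid.
    Note: math.exp overflows for very negative u (below about -709)."""
    return lower + (upper - lower) / (1.0 + math.exp(-u))


def from_interval(x, lower=0.0, upper=1.0, eps=1e-12):
    """Inverse map (a logit); inputs at the boundary are clamped by eps."""
    p = (x - lower) / (upper - lower)
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p) - math.log1p(-p)


for x in (0.1, 0.5, 0.9):
    u = from_interval(x)
    print(x, u, to_interval(u))   # the round trip recovers x
```

Because every real number maps back inside the interval, the optimizer can take arbitrary steps without ever proposing a point outside the domain of the objective function.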
The inner loop of Bayesian Optimization¶
With the various helper functions defined above, we can now encapsulate the main logic of a single step of Bayesian Optimization in the function next_x:
[8]:
def next_x(lower_bound=0, upper_bound=1, num_candidates=5):
    candidates = []
    values = []

    # the first seed is the most recently evaluated point
    x_init = gpmodel.X[-1:]
    for i in range(num_candidates):
        x = find_a_candidate(x_init, lower_bound, upper_bound)
        y = lower_confidence_bound(x)
        candidates.append(x)
        values.append(y)
        # subsequent seeds are drawn uniformly at random from the domain
        x_init = x.new_empty(1).uniform_(lower_bound, upper_bound)

    # return the candidate with the smallest acquisition value
    argmin = torch.min(torch.cat(values), dim=0)[1].item()
    return candidates[argmin]
Running the algorithm¶
To illustrate how Bayesian Optimization works, we make a convenient plotting function that will help us visualize our algorithm’s progress.
[9]:
def plot(gs, xmin, xlabel=None, with_title=True):
    xlabel = "xmin" if xlabel is None else "x{}".format(xlabel)
    Xnew = torch.linspace(-0.1, 1.1, 100)
    ax1 = plt.subplot(gs[0])
    ax1.plot(gpmodel.X.numpy(), gpmodel.y.numpy(), "kx")  # plot all observed data
    with torch.no_grad():
        loc, var = gpmodel(Xnew, full_cov=False, noiseless=False)
        sd = var.sqrt()
        ax1.plot(Xnew.numpy(), loc.numpy(), "r", lw=2)  # plot predictive mean
        ax1.fill_between(Xnew.numpy(), loc.numpy() - 2*sd.numpy(), loc.numpy() + 2*sd.numpy(),
                         color="C0", alpha=0.3)  # plot uncertainty intervals
    ax1.set_xlim(-0.1, 1.1)
    ax1.set_title("Find {}".format(xlabel))
    if with_title:
        ax1.set_ylabel("Gaussian Process Regression")

    ax2 = plt.subplot(gs[1])
    with torch.no_grad():
        # plot the acquisition function
        ax2.plot(Xnew.numpy(), lower_confidence_bound(Xnew).numpy())
        # plot the new candidate point
        ax2.plot(xmin.numpy(), lower_confidence_bound(xmin).numpy(), "^", markersize=10,
                 label="{} = {:.5f}".format(xlabel, xmin.item()))
    ax2.set_xlim(-0.1, 1.1)
    if with_title:
        ax2.set_ylabel("Acquisition Function")
    ax2.legend(loc=1)
Our surrogate model gpmodel already has 4 function evaluations at its disposal; however, we have yet to optimize the GP hyperparameters, so we do that first. Then, in a loop, we call the next_x and update_posterior functions repeatedly. The following plot illustrates how the Gaussian Process posteriors and the corresponding acquisition functions change at each step of the algorithm. Note how query points are chosen both for exploration and exploitation.
[10]:
plt.figure(figsize=(12, 30))
outer_gs = gridspec.GridSpec(5, 2)
# first fit the GP hyperparameters to the four initial observations
optimizer = torch.optim.Adam(gpmodel.parameters(), lr=0.001)
gp.util.train(gpmodel, optimizer)
for i in range(8):
    xmin = next_x()
    gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer_gs[i])
    plot(gs, xmin, xlabel=i+1, with_title=(i % 2 == 0))
    update_posterior(xmin)
plt.show()

Because we have assumed that our observations contain noise, it is improbable that we will find the exact minimizer of the function $f$.
While this tutorial is only intended to be a brief introduction to Bayesian Optimization, we hope that we have been able to convey the basic underlying ideas. Consider watching the lecture by Nando de Freitas [3] for an excellent exposition of the basic theory. Finally, the reference paper [2] gives a review of recent research on Bayesian Optimization, together with many discussions about important technical details.
References¶
[1] Practical Bayesian optimization of machine learning algorithms, Jasper Snoek, Hugo Larochelle, and Ryan P. Adams
[2] Taking the human out of the loop: A review of Bayesian optimization, Bobak Shahriari, Kevin Swersky, Ziyu Wang, Ryan P. Adams, and Nando de Freitas
[3] Machine learning - Bayesian optimization and multi-armed bandits, Nando de Freitas